Source code for herbie.accessors

## Brian Blaylock
## April 23, 2021

"""
==================================
Herbie Extension: xarray accessors
==================================

Extend the xarray capabilities with a custom accessor.
http://xarray.pydata.org/en/stable/internals.html#extending-xarray

To use the herbie xarray accessor, do this...

.. code-block:: python

    H = Herbie('2021-01-01', model='hrrr')
    ds = H.xarray('TMP:2 m')
    ds.herbie.crs
    ds.herbie.plot()

"""
import functools
import cartopy.crs as ccrs
import metpy  # * Needed for metpy accessor
import numpy as np
import pandas as pd
import xarray as xr
from shapely.geometry import Polygon


_level_units = dict(
    adiabaticCondensation="adiabatic condensation",
    atmosphere="atmosphere",
    atmosphereSingleLayer="atmosphere single layer",
    boundaryLayerCloudLayer="boundary layer cloud layer",
    cloudBase="cloud base",
    cloudCeiling="cloud ceiling",
    cloudTop="cloud top",
    depthBelowLand="m",
    equilibrium="equilibrium",
    heightAboveGround="m",
    heightAboveGroundLayer="m",
    highCloudLayer="high cloud layer",
    highestTroposphericFreezing="highest tropospheric freezing",
    isobaricInhPa="hPa",
    isobaricLayer="hPa",
    isothermZero="0 C",
    isothermal="K",
    level="m",
    lowCloudLayer="low cloud layer",
    meanSea="MSLP",
    middleCloudLayer="middle cloud layer",
    nominalTop="nominal top",
    pressureFromGroundLayer="Pa",
    sigma="sigma",
    sigmaLayer="sigmaLayer",
    surface="surface",
)


[docs]@xr.register_dataset_accessor("herbie") class HerbieAccessor: """Accessor for xarray Datasets opened with Herbie"""
[docs] def __init__(self, xarray_obj): self._obj = xarray_obj self._center = None
@property def center(self): """Return the geographic center point of this dataset.""" if self._center is None: # we can use a cache on our accessor objects, because accessors # themselves are cached on instances that access them. lon = self._obj.latitude lat = self._obj.longitude self._center = (float(lon.mean()), float(lat.mean())) return self._center @functools.cached_property def crs(self): """ Cartopy coordinate reference system (crs) from a cfgrib Dataset. Projection information is from the grib2 message for each variable. Parameters ---------- ds : xarray.Dataset An xarray.Dataset from a GRIB2 file opened by the cfgrib engine. """ ds = self._obj # Get variables that have dimensions # (this filters out the gribfile_projection variable) variables = [i for i in list(ds) if len(ds[i].dims) > 0] ds = ds.metpy.parse_cf(varname=variables) crs = ds.metpy_crs.item().to_cartopy() return crs @functools.cached_property def polygon(self): """ Get a polygon of the domain boundary. """ ds = self._obj LON = ds.longitude.data LAT = ds.latitude.data # Path of array outside border starting from the lower left corner # and going around the array counter-clockwise. outside = ( list(zip(LON[0, :], LAT[0, :])) + list(zip(LON[:, -1], LAT[:, -1])) + list(zip(LON[-1, ::-1], LAT[-1, ::-1])) + list(zip(LON[::-1, 0], LAT[::-1, 0])) ) outside = np.array(outside) ############################### # Polygon in Lat/Lon coordinates x = outside[:, 0] y = outside[:, 1] domain_polygon_latlon = Polygon(zip(x, y)) ################################### # Polygon in projection coordinates transform = self.crs.transform_points(ccrs.PlateCarree(), x, y) # Remove any points that run off the projection map (i.e., point's value is `inf`). transform = transform[~np.isinf(transform).any(axis=1)] x = transform[:, 0] y = transform[:, 1] domain_polygon = Polygon(zip(x, y)) return domain_polygon, domain_polygon_latlon
[docs] def nearest_points(self, points, names=None, verbose=True): """ Get the nearest latitude/longitude points from a xarray Dataset. Info ---- - Stack Overflow: https://stackoverflow.com/questions/58758480/xarray-select-nearest-lat-lon-with-multi-dimension-coordinates - MetPy Details: https://unidata.github.io/MetPy/latest/tutorials/xarray_tutorial.html?highlight=assign_y_x Parameters ---------- ds : xr.Dataset A Herbie-friendly xarray Dataset points : tuple (lon, lat) or list of tuples The longitude and latitude (lon, lat) coordinate pair (as a tuple) for the points you want to pluck from the gridded Dataset. A list of tuples may be given to return the values from multiple points. names : list A list of names for each point location (i.e., station name). None will not append any names. names should be the same length as points. Benchmark --------- This is **much** faster than my old "pluck_points" method. For matchign 1,948 points: - `nearest_points` completed in 7.5 seconds. - `pluck_points` completed in 2 minutes. """ ds = self._obj # Check if MetPy has already parsed the CF metadata grid projection. # Do that if it hasn't been done yet. if "metpy_crs" not in ds: ds = ds.metpy.parse_cf() # Apply the MetPy method `assign_y_x` to the dataset # https://unidata.github.io/MetPy/latest/api/generated/metpy.xarray.html?highlight=assign_y_x#metpy.xarray.MetPyDataArrayAccessor.assign_y_x ds = ds.metpy.assign_y_x() # Convert the requested [(lon,lat), (lon,lat)] points to map projection. # Accept a list of point tuples, or Shapely Points object. # We want to index the dataset at a single point. # We can do this by transforming a lat/lon point to the grid location crs = ds.metpy_crs.item().to_cartopy() # lat/lon input must be a numpy array, not a list or polygon if isinstance(points, tuple): # If a tuple is give, turn into a one-item list. points = np.array([points]) if not isinstance(points, np.ndarray): # Points must be a 2D numpy array points = np.array(points) lons = points[:, 0] lats = points[:, 1] transformed_data = crs.transform_points(ccrs.PlateCarree(), lons, lats) xs = transformed_data[:, 0] ys = transformed_data[:, 1] # Select the nearest points from the projection coordinates. # TODO: Is there a better way? # There doesn't seem to be a way to get just the points like this # ds = ds.sel(x=xs, y=ys, method='nearest') # because it gives a 2D array, and not a point-by-point index. # Instead, I have too loop the ds.sel method new_ds = xr.concat( [ds.sel(x=xi, y=yi, method="nearest") for xi, yi in zip(xs, ys)], dim="point", ) # Add list of names as a coordinate if names is not None: # Assign the point dimension as the names. assert len(points) == len( names ), "`points` and `names` must be same length." new_ds["point"] = names return new_ds
[docs] def plot(self, ax=None, common_features_kw={}, vars=None, **kwargs): """Plot data on a map. Parameters ---------- vars : list List of variables to plot. Default None will plot all variables in the DataSet. """ # From Carpenter_Workshop: # https://github.com/blaylockbk/Carpenter_Workshop import matplotlib.pyplot as plt try: from toolbox.cartopy_tools import common_features, pc from paint.radar import cm_reflectivity from paint.radar2 import cm_reflectivity from paint.terrain2 import cm_terrain from paint.standard2 import cm_dpt, cm_pcp, cm_rh, cm_tmp, cm_wind except: print("The plotting accessor requires my Carpenter Workshop. Try:") print( "`pip install git+https://github.com/blaylockbk/Carpenter_Workshop.git`" ) ds = self._obj if isinstance(vars, str): vars = [vars] if vars is None: vars = ds.data_vars for var in vars: if "longitude" not in ds[var].coords: # This is the case for the gribfile_projection variable continue print("cfgrib variable:", var) print("GRIB_cfName", ds[var].attrs.get("GRIB_cfName")) print("GRIB_cfVarName", ds[var].attrs.get("GRIB_cfVarName")) print("GRIB_name", ds[var].attrs.get("GRIB_name")) print("GRIB_units", ds[var].attrs.get("GRIB_units")) print("GRIB_typeOfLevel", ds[var].attrs.get("GRIB_typeOfLevel")) print() ds[var].attrs["units"] = ( ds[var] .attrs["units"] .replace("**-1", "$^{-1}$") .replace("**-2", "$^{-2}$") ) defaults = dict( scale="50m", dpi=150, figsize=(10, 5), crs=ds.herbie.crs, ax=ax, ) common_features_kw = {**defaults, **common_features_kw} ax = common_features(**common_features_kw).STATES().ax title = "" kwargs.setdefault("shading", "auto") cbar_kwargs = dict(pad=0.01) if ds[var].GRIB_cfVarName in ["d2m", "dpt"]: ds[var].attrs["GRIB_cfName"] = "dew_point_temperature" ## Wind wind_pair = {"u10": "v10", "u80": "v80", "u": "v"} if ds[var].GRIB_cfName == "air_temperature": kwargs = {**cm_tmp().cmap_kwargs, **kwargs} cbar_kwargs = {**cm_tmp().cbar_kwargs, **cbar_kwargs} if ds[var].GRIB_units == "K": ds[var] -= 273.15 ds[var].attrs["GRIB_units"] = "C" ds[var].attrs["units"] = "C" elif ds[var].GRIB_cfName == "dew_point_temperature": kwargs = {**cm_dpt().cmap_kwargs, **kwargs} cbar_kwargs = {**cm_dpt().cbar_kwargs, **cbar_kwargs} if ds[var].GRIB_units == "K": ds[var] -= 273.15 ds[var].attrs["GRIB_units"] = "C" ds[var].attrs["units"] = "C" elif ds[var].GRIB_name == "Total Precipitation": title = "-".join( [f"F{int(i):02d}" for i in ds[var].GRIB_stepRange.split("-")] ) ds[var] = ds[var].where(ds[var] != 0) kwargs = {**cm_pcp().cmap_kwargs, **kwargs} cbar_kwargs = {**cm_pcp().cbar_kwargs, **cbar_kwargs} elif ds[var].GRIB_name == "Maximum/Composite radar reflectivity": ds[var] = ds[var].where(ds[var] >= 0) cbar_kwargs = {**cm_reflectivity().cbar_kwargs, **cbar_kwargs} kwargs = {**cm_reflectivity().cmap_kwargs, **kwargs} elif ds[var].GRIB_cfName == "relative_humidity": cbar_kwargs = {**cm_rh().cbar_kwargs, **cbar_kwargs} kwargs = {**cm_rh().cmap_kwargs, **kwargs} elif ds[var].GRIB_name == "Orography": if "lsm" in ds: ds["orog"] = ds.orog.where(ds.lsm == 1, -100) cbar_kwargs = {**cm_terrain().cbar_kwargs, **cbar_kwargs} kwargs = {**cm_terrain().cmap_kwargs, **kwargs} elif "wind" in ds[var].GRIB_cfName or "wind" in ds[var].GRIB_name: cbar_kwargs = {**cm_wind().cbar_kwargs, **cbar_kwargs} kwargs = {**cm_wind().cmap_kwargs, **kwargs} if ds[var].GRIB_cfName == "eastward_wind": cbar_kwargs["label"] = "U " + cbar_kwargs["label"] elif ds[var].GRIB_cfName == "northward_wind": cbar_kwargs["label"] = "V " + cbar_kwargs["label"] else: cbar_kwargs = { **dict( label=f"{ds[var].GRIB_parameterName.strip().title()} ({ds[var].units})" ), **cbar_kwargs, } p = ax.pcolormesh( ds.longitude, ds.latitude, ds[var], transform=pc, **kwargs ) plt.colorbar(p, ax=ax, **cbar_kwargs) VALID = pd.to_datetime(ds.valid_time.data).strftime("%H:%M UTC %d %b %Y") RUN = pd.to_datetime(ds.time.data).strftime("%H:%M UTC %d %b %Y") FXX = f"F{pd.to_timedelta(ds.step.data).total_seconds()/3600:02.0f}" level_type = ds[var].GRIB_typeOfLevel if level_type in _level_units: level_units = _level_units[level_type] else: level_units = "unknown" if level_units.lower() in ["surface", "atmosphere"]: level = f"{title} {level_units}" else: level = f"{ds[var][level_type].data:g} {level_units}" ax.set_title( f"Run: {RUN} {FXX}", loc="left", fontfamily="monospace", fontsize="x-small", ) ax.set_title( f"{ds.model.upper()} {level}\n", loc="center", fontweight="semibold" ) ax.set_title( f"Valid: {VALID}", loc="right", fontfamily="monospace", fontsize="x-small", ) # Set extent so no whitespace shows around pcolormesh area # TODO: Any better way to do this? With metpy.assign_y_x # !!!!: The `metpy.assign_y_x` method could be used for pluck_point :) try: if "x" in ds.dims: ds = ds.metpy.parse_cf() ds = ds.metpy.assign_y_x() ax.set_extent( [ ds.x.min().item(), ds.x.max().item(), ds.y.min().item(), ds.y.max().item(), ], crs=ds.herbie.crs, ) except: pass return ax