Regridding using xESMF and an existing weights file

A fairly common request is to use an existing ESMF weights file to regrid an Xarray Dataset (1, 2). In principle, applying weights is easy: read the weights, then apply them to the input dataset using dot or tensordot.
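To make the "read weights, then apply them" idea concrete, here is a minimal NumPy sketch. It assumes the ESMF weight-file conventions visible in the weights file shown later: `row` and `col` are 1-based destination and source indices, and `S` holds the matching weights. `apply_esmf_weights` is a hypothetical helper, not part of xESMF, and the tiny grid is made up. For real grids you would normally build a sparse matrix (for example with scipy.sparse) rather than the dense check used here.

```python
import numpy as np


def apply_esmf_weights(data, row, col, S, n_b):
    """Apply sparse ESMF-style weights along the last axis of ``data``.

    Hypothetical helper. Assumes 1-based ``row``/``col`` indices, so that
    dst[row[k] - 1] += S[k] * src[col[k] - 1] for every weight k.
    """
    src = np.moveaxis(np.asarray(data), -1, 0)  # (n_a, ...) with ncol first
    S = np.asarray(S)
    out = np.zeros((n_b,) + src.shape[1:], dtype=np.result_type(src, S))
    # Broadcast each weight over any trailing (e.g. time) axes
    w = S.reshape((-1,) + (1,) * (src.ndim - 1))
    # np.add.at accumulates correctly when a destination index repeats
    np.add.at(out, np.asarray(row) - 1, w * src[np.asarray(col) - 1])
    return np.moveaxis(out, 0, -1)  # back to (..., n_b)


# Tiny hypothetical example: 3 source points, 2 destination points
row = np.array([1, 1, 2, 2])
col = np.array([1, 2, 2, 3])
S = np.array([0.5, 0.5, 0.25, 0.75])
data = np.array([[1.0, 3.0, 5.0],
                 [2.0, 2.0, 2.0]])  # (time=2, ncol=3)

out = apply_esmf_weights(data, row, col, S, n_b=2)

# Same result as an explicit (dense) matrix product, fine only for tiny grids
W = np.zeros((2, 3))
np.add.at(W, (row - 1, col - 1), S)
assert np.allclose(out, data @ W.T)
print(out)
```

Working on the last axis keeps any leading dimensions such as time intact, which mirrors how ncol is the core dimension in the CAM-SE data.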

In the Xarray/Dask/Pangeo ecosystem, xESMF provides an interface to ESMF for convenient regridding, including parallelization with Dask. Here we demonstrate how to use an existing ESMF weights file with xESMF, specifically for CAM-SE output.

CAM-SE is the Community Atmosphere Model with the spectral element dynamical core (Dennis et al., 2011).

The spectral element dynamical core, CAM-SE, is an unstructured grid that supports uniform resolutions based on the equiangular gnomonic cubed-sphere grid as well as a mesh refinement capability with local increases in resolution through conformal mesh refinement. CAM-SE is the default resolution for CESM2 high resolution capabilities … (Lauritzen et al. 2018).

The main challenge is that the input dataset has one spatial dimension (ncol), while xESMF is hardcoded to expect two spatial dimensions (lat, lon). We solve this by adding a dummy dimension. At the end, we'll make this plot of the vertically integrated water vapour transport (IVT):

from IPython.display import Image

Image("../images/cam-se-ivt.png", width=600)
%load_ext watermark

import hvplot.xarray
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xesmf

%watermark -iv
sys       : 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:26:04) [GCC 10.4.0]
numpy     : 1.23.5
matplotlib: 3.6.2
xesmf     : 0.6.3
json      : 2.0.9
xarray    : 2022.12.0
hvplot    : 0.8.2

Read data

First, some file paths:

data_file = "/glade/campaign/collections/cmip/CMIP6/iHESP/BHIST/HR/b.e13.BHISTC5.ne120_t12.cesm-ihesp-hires1.0.30-1920-2100.002/atm/proc/tseries/hour_6I/b.e13.BHISTC5.ne120_t12.cesm-ihesp-hires1.0.30-1920-2100.002.cam.h2.IVT.192001-192912.nc"
weight_file = "/glade/work/shields/SE_grid/map_ne120_to_0.23x0.31_bilinear.nc"

We read in the input data, chunked along time:

data_in = xr.open_dataset(data_file, chunks={"time": 50})
data_in
<xarray.Dataset>
Dimensions:        (lev: 30, ilev: 31, cosp_prs: 7, nbnd: 2, cosp_tau: 7,
                    cosp_scol: 50, cosp_ht: 40, cosp_sr: 15, cosp_sza: 5,
                    ncol: 777602, time: 14600)
Coordinates:
  * lev            (lev) float64 3.643 7.595 14.36 24.61 ... 957.5 976.3 992.6
  * ilev           (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03
  * cosp_prs       (cosp_prs) float64 900.0 740.0 620.0 500.0 375.0 245.0 90.0
  * cosp_tau       (cosp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5
  * cosp_scol      (cosp_scol) int32 1 2 3 4 5 6 7 8 ... 43 44 45 46 47 48 49 50
  * cosp_ht        (cosp_ht) float64 240.0 720.0 1.2e+03 ... 1.848e+04 1.896e+04
  * cosp_sr        (cosp_sr) float64 0.605 2.1 4.0 6.0 ... 70.0 539.5 1.004e+03
  * cosp_sza       (cosp_sza) float64 0.0 15.0 30.0 45.0 60.0
  * time           (time) object 1920-01-01 00:00:00 ... 1929-12-31 18:00:00
Dimensions without coordinates: nbnd, ncol
Data variables: (12/35)
    hyam           (lev) float64 dask.array<chunksize=(30,), meta=np.ndarray>
    hybm           (lev) float64 dask.array<chunksize=(30,), meta=np.ndarray>
    P0             float64 ...
    hyai           (ilev) float64 dask.array<chunksize=(31,), meta=np.ndarray>
    hybi           (ilev) float64 dask.array<chunksize=(31,), meta=np.ndarray>
    cosp_prs_bnds  (cosp_prs, nbnd) float64 dask.array<chunksize=(7, 2), meta=np.ndarray>
    ...             ...
    n2ovmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    f11vmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    f12vmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    sol_tsi        (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    nsteph         (time) int32 dask.array<chunksize=(50,), meta=np.ndarray>
    IVT            (time, ncol) float32 dask.array<chunksize=(50, 777602), meta=np.ndarray>
Attributes:
    np:               4
    ne:               120
    Conventions:      CF-1.0
    source:           CAM
    case:             b.e13.BHISTC5.ne120_t12.cesm-ihesp-hires1.0.30-1920-210...
    title:            UNSET
    logname:          nanr
    host:             login1.frontera.
    Version:          $Name$
    revision_Id:      $Id$
    initial_file:     B.E.13.BHISTC5.ne120_t12.sehires38.003.sunway.cam.i.192...
    topography_file:  /work/02503/edwardsj/CESM/inputdata//atm/cam/topo/USGS-...

Here's the primary data variable, IVT (integrated vapour transport), which has a single spatial dimension: ncol.
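As a quick consistency check, the length of ncol follows from the ne and np attributes printed above. This sketch assumes the standard count of unique Gauss-Lobatto-Legendre (GLL) points on a cubed-sphere mesh with ne x ne elements per face and np points along each element edge; `cubed_sphere_ncol` is a hypothetical helper written for illustration.

```python
def cubed_sphere_ncol(ne, np_):
    """Number of unique GLL points on a cubed-sphere mesh (hypothetical helper).

    Each of the 6 * ne**2 elements contributes (np_ - 1)**2 points once shared
    edge and corner points are counted only once. The extra 2 comes from the
    cube's 8 corners, each shared by 3 elements instead of 4.
    """
    return 6 * ne**2 * (np_ - 1) ** 2 + 2


# ne = 120 and np = 4 are the values in the dataset attributes
print(cubed_sphere_ncol(120, 4))
```

With ne = 120 and np = 4 this gives 6 * 14400 * 9 + 2 = 777602, matching the size of ncol.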

data_in.IVT
<xarray.DataArray 'IVT' (time: 14600, ncol: 777602)>
dask.array<open_dataset-124517cf72a30883d5a3c70220985aeeIVT, shape=(14600, 777602), dtype=float32, chunksize=(50, 777602), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) object 1920-01-01 00:00:00 ... 1929-12-31 18:00:00
Dimensions without coordinates: ncol
Attributes:
    units:      kg/m/s
    long_name:  Total (vertically integrated) vapor transport

And here’s what an ESMF weights file looks like:

weights = xr.open_dataset(weight_file)
weights
<xarray.Dataset>
Dimensions:        (src_grid_rank: 1, dst_grid_rank: 2, n_a: 777602,
                    n_b: 884736, nv_a: 3, nv_b: 4, n_s: 2654208)
Dimensions without coordinates: src_grid_rank, dst_grid_rank, n_a, n_b, nv_a,
                                nv_b, n_s
Data variables: (12/19)
    src_grid_dims  (src_grid_rank) int32 ...
    dst_grid_dims  (dst_grid_rank) int32 ...
    yc_a           (n_a) float64 ...
    yc_b           (n_b) float64 ...
    xc_a           (n_a) float64 ...
    xc_b           (n_b) float64 ...
    ...             ...
    area_b         (n_b) float64 ...
    frac_a         (n_a) float64 ...
    frac_b         (n_b) float64 ...
    col            (n_s) int32 ...
    row            (n_s) int32 ...
    S              (n_s) float64 ...
Attributes:
    title:               ESMF Offline Regridding Weight Generator
    normalization:       destarea
    map_method:          Bilinear remapping
    ESMF_regrid_method:  Bilinear
    conventions:         NCAR-CSM
    domain_a:            /glade/scratch/shields/regridded/mapfiles/ne120.nc
    domain_b:            /glade/scratch/shields/regridded/mapfiles/0.23x0.31.nc
    grid_file_src:       /glade/scratch/shields/regridded/mapfiles/ne120.nc
    grid_file_dst:       /glade/scratch/shields/regridded/mapfiles/0.23x0.31.nc
    CVS_revision:        7.1.0r

Regridding with xESMF

The primary xESMF interface is the Regridder class:

regridder = xesmf.Regridder(
    ds_in, 
    ds_out,
    weights=weight_file,
    method="bilinear",
    reuse_weights=True,
    periodic=True,
)

It requires input grid information in ds_in and output grid information in ds_out. Both datasets need variables named lat and lon, or variables that can be identified as “latitude” and “longitude” using CF metadata.

We can construct this information from the weights file:

Read the weights file

weights = xr.open_dataset(weight_file)

# input variable shape
in_shape = weights.src_grid_dims.load().data

# Since xESMF expects 2D vars, we'll insert a dummy dimension of size-1
if len(in_shape) == 1:
    in_shape = [1, in_shape.item()]

# output variable shape: ESMF stores grid dims as (lon, lat), so reverse to get (lat, lon)
out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]

in_shape, out_shape
([1, 777602], [768, 1152])
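The same shape logic can be written as a small standalone function. This is a sketch that also handles a logically rectangular 2D source grid, a case the cell above does not cover. It assumes ESMF lists grid dims fastest-varying first, i.e. (nlon, nlat), which is consistent with the `[::-1]` above producing (768, 1152). The `[1152, 768]` input below is inferred from that output rather than read from the file, and `esmf_grid_shapes` is a hypothetical name.

```python
import numpy as np


def esmf_grid_shapes(src_grid_dims, dst_grid_dims):
    """Hypothetical helper: turn ESMF grid dims into (lat, lon)-ordered shapes.

    Assumes ESMF lists grid dims fastest-varying first, i.e. (nlon, nlat).
    """
    src = np.asarray(src_grid_dims).tolist()
    if len(src) == 1:
        # Unstructured source grid (e.g. ncol): prepend a size-1 dummy dimension
        in_shape = [1, src[0]]
    else:
        # Logically rectangular source grid: reverse (nlon, nlat) -> (nlat, nlon)
        in_shape = src[::-1]
    out_shape = np.asarray(dst_grid_dims).tolist()[::-1]
    return in_shape, out_shape


in_shape, out_shape = esmf_grid_shapes([777602], [1152, 768])
print(in_shape, out_shape)
```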

Construct a Regridder

First, we make a dummy input dataset with a lat dimension of size 1 and a lon dimension the same size as ncol, and an output dataset with the correct lat and lon. For the latter, we use xc_b and yc_b, the destination cell-center coordinates stored in the weights file.
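The slicing used for dummy_out can be checked on a toy grid. This sketch assumes, as the reshape to (768, 1152) implies, that the flattened xc_b and yc_b arrays list longitude fastest; the 3 x 4 grid values are made up for illustration.

```python
import numpy as np

# Hypothetical 3 x 4 rectilinear destination grid
lat_1d = np.array([-45.0, 0.0, 45.0])
lon_1d = np.array([0.0, 90.0, 180.0, 270.0])
out_shape = [lat_1d.size, lon_1d.size]  # (nlat, nlon)

# Flattened cell centers with longitude varying fastest (stand-ins for xc_b, yc_b)
lon_2d, lat_2d = np.meshgrid(lon_1d, lat_1d)  # both have shape (nlat, nlon)
xc_b, yc_b = lon_2d.ravel(), lat_2d.ravel()

# Same slicing as dummy_out: the first column gives lat, the first row gives lon
lat = yc_b.reshape(out_shape)[:, 0]
lon = xc_b.reshape(out_shape)[0, :]
print(lat, lon)  # recovers lat_1d and lon_1d
```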

Note

We are assuming a rectilinear output grid, but this could be modified for a curvilinear grid. (Question: Is there a way to identify this from the weights file?)
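One partial answer is a heuristic check on the cell centers: after reshaping to (nlat, nlon), a rectilinear grid has constant latitude along each row and constant longitude along each column. The sketch below assumes the same longitude-fastest ordering as above; `is_rectilinear` is a hypothetical helper, and because it only inspects centers it cannot say anything about cell corners.

```python
import numpy as np


def is_rectilinear(xc, yc, shape, atol=1e-8):
    """Hypothetical heuristic: True if flattened centers form a rectilinear grid.

    Reshapes to (nlat, nlon) and checks that latitude is constant along each
    row and longitude is constant along each column.
    """
    lat2d = np.asarray(yc).reshape(shape)
    lon2d = np.asarray(xc).reshape(shape)
    return bool(
        np.allclose(lat2d, lat2d[:, :1], atol=atol)
        and np.allclose(lon2d, lon2d[:1, :], atol=atol)
    )


# Toy 2 x 3 grid (made-up values)
lon_2d, lat_2d = np.meshgrid([0.0, 120.0, 240.0], [-30.0, 30.0])
rect = is_rectilinear(lon_2d.ravel(), lat_2d.ravel(), (2, 3))

lat_2d[1, 2] += 1.0  # perturb one center so that row is no longer constant
curvi = is_rectilinear(lon_2d.ravel(), lat_2d.ravel(), (2, 3))
print(rect, curvi)
```

On the real file this would be called as `is_rectilinear(weights.xc_b.data, weights.yc_b.data, out_shape)`.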

As a reminder, we use the lat and lon dimension names because they are hardcoded into xESMF.

dummy_in = xr.Dataset(
    {
        "lat": ("lat", np.empty((in_shape[0],))),
        "lon": ("lon", np.empty((in_shape[1],))),
    }
)
dummy_out = xr.Dataset(
    {
        "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
        "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
    }
)

regridder = xesmf.Regridder(
    dummy_in,
    dummy_out,
    weights=weight_file,
    method="bilinear",
    reuse_weights=True,
    periodic=True,
)
regridder
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_1x777602_768x1152_peri.nc 
Reuse pre-computed weights? True 
Input grid shape:           (1, 777602) 
Output grid shape:          (768, 1152) 
Periodic in longitude?      True

This works, but note that some of the metadata shown here is wrong: for example, “Weight filename” is an auto-generated name rather than the weights file we passed in.

Apply the Regridder

Next, to apply the regridder, we use Dataset.expand_dims with axis=-2 to insert a new "dummy" dimension immediately before the ncol dimension. To be safe, we first use transpose to force ncol to be the last dimension. We do this only for variables that have the ncol dimension.

We use the name dummy to make it clear that this dimension is fake.
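The transpose and expand_dims step can be tried on a small in-memory dataset first. The toy dataset below is made up for illustration: IVT is deliberately stored as (ncol, time) so that the transpose has something to do, area stands in for a 1D ncol variable, and sol_tsi has no ncol dimension and should be left alone.

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in for a CAM-SE dataset: 5 columns, 2 timesteps
ds = xr.Dataset(
    {
        "IVT": (("ncol", "time"), np.arange(10.0).reshape(5, 2)),
        "area": ("ncol", np.ones(5)),
        "sol_tsi": ("time", np.array([1361.0, 1361.5])),
    }
)

vars_with_ncol = [name for name in ds.variables if "ncol" in ds[name].dims]
# Put ncol last, then insert a size-1 "dummy" dimension just before it
expanded = ds[vars_with_ncol].transpose(..., "ncol").expand_dims("dummy", axis=-2)
updated = ds.copy().update(expanded)

print(updated.IVT.dims, updated.area.dims, updated.sol_tsi.dims)
```

Because axis=-2 counts from the end, the dummy dimension lands just before ncol for both the 2D IVT and the 1D area.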

vars_with_ncol = [name for name in data_in.variables if "ncol" in data_in[name].dims]
updated = data_in.copy().update(
    data_in[vars_with_ncol].transpose(..., "ncol").expand_dims("dummy", axis=-2)
)
updated
<xarray.Dataset>
Dimensions:        (lev: 30, ilev: 31, cosp_prs: 7, nbnd: 2, cosp_tau: 7,
                    cosp_scol: 50, cosp_ht: 40, cosp_sr: 15, cosp_sza: 5,
                    dummy: 1, ncol: 777602, time: 14600)
Coordinates:
  * lev            (lev) float64 3.643 7.595 14.36 24.61 ... 957.5 976.3 992.6
  * ilev           (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03
  * cosp_prs       (cosp_prs) float64 900.0 740.0 620.0 500.0 375.0 245.0 90.0
  * cosp_tau       (cosp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5
  * cosp_scol      (cosp_scol) int32 1 2 3 4 5 6 7 8 ... 43 44 45 46 47 48 49 50
  * cosp_ht        (cosp_ht) float64 240.0 720.0 1.2e+03 ... 1.848e+04 1.896e+04
  * cosp_sr        (cosp_sr) float64 0.605 2.1 4.0 6.0 ... 70.0 539.5 1.004e+03
  * cosp_sza       (cosp_sza) float64 0.0 15.0 30.0 45.0 60.0
  * time           (time) object 1920-01-01 00:00:00 ... 1929-12-31 18:00:00
Dimensions without coordinates: nbnd, dummy, ncol
Data variables: (12/35)
    hyam           (lev) float64 dask.array<chunksize=(30,), meta=np.ndarray>
    hybm           (lev) float64 dask.array<chunksize=(30,), meta=np.ndarray>
    P0             float64 ...
    hyai           (ilev) float64 dask.array<chunksize=(31,), meta=np.ndarray>
    hybi           (ilev) float64 dask.array<chunksize=(31,), meta=np.ndarray>
    cosp_prs_bnds  (cosp_prs, nbnd) float64 dask.array<chunksize=(7, 2), meta=np.ndarray>
    ...             ...
    n2ovmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    f11vmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    f12vmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    sol_tsi        (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    nsteph         (time) int32 dask.array<chunksize=(50,), meta=np.ndarray>
    IVT            (time, dummy, ncol) float32 dask.array<chunksize=(50, 1, 777602), meta=np.ndarray>
Attributes:
    np:               4
    ne:               120
    Conventions:      CF-1.0
    source:           CAM
    case:             b.e13.BHISTC5.ne120_t12.cesm-ihesp-hires1.0.30-1920-210...
    title:            UNSET
    logname:          nanr
    host:             login1.frontera.
    Version:          $Name$
    revision_Id:      $Id$
    initial_file:     B.E.13.BHISTC5.ne120_t12.sehires38.003.sunway.cam.i.192...
    topography_file:  /work/02503/edwardsj/CESM/inputdata//atm/cam/topo/USGS-...

To apply the regridder to updated, we rename dummy to lat (both have size 1 in updated and dummy_in) and ncol to lon (both have the same size in updated and dummy_in).
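Since the rename only works if the sizes line up, a cheap check before calling the regridder can catch a mismatched weights file early. This is a sketch on toy data; `matches_input_grid` is a hypothetical helper, and the toy `dummy_in` is built the same way as the one above but with only 5 columns.

```python
import numpy as np
import xarray as xr

ncol = 5
# Toy dataset after the expand_dims step: IVT has dims (time, dummy, ncol)
updated = xr.Dataset({"IVT": (("time", "dummy", "ncol"), np.zeros((2, 1, ncol)))})
# Toy input grid, built like dummy_in above
dummy_in = xr.Dataset({"lat": ("lat", np.empty(1)), "lon": ("lon", np.empty(ncol))})

renamed = updated.rename({"dummy": "lat", "ncol": "lon"})


def matches_input_grid(ds, grid):
    """Hypothetical helper: True if ds has the same lat/lon sizes as grid."""
    return all(ds.sizes[d] == grid.sizes[d] for d in ("lat", "lon"))


print(renamed.IVT.dims, matches_input_grid(renamed, dummy_in))
```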

regridded = regridder(updated.rename({"dummy": "lat", "ncol": "lon"}))
regridded
<xarray.Dataset>
Dimensions:    (lat: 768, lon: 1152, time: 14600, lev: 30, ilev: 31,
                cosp_prs: 7, cosp_tau: 7, cosp_scol: 50, cosp_ht: 40,
                cosp_sr: 15, cosp_sza: 5)
Coordinates:
  * lat        (lat) float64 -90.0 -89.77 -89.53 -89.3 ... 89.3 89.53 89.77 90.0
  * lon        (lon) float64 0.0 0.3125 0.625 0.9375 ... 358.8 359.1 359.4 359.7
  * lev        (lev) float64 3.643 7.595 14.36 24.61 ... 936.2 957.5 976.3 992.6
  * ilev       (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03
  * cosp_prs   (cosp_prs) float64 900.0 740.0 620.0 500.0 375.0 245.0 90.0
  * cosp_tau   (cosp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5
  * cosp_scol  (cosp_scol) int32 1 2 3 4 5 6 7 8 9 ... 43 44 45 46 47 48 49 50
  * cosp_ht    (cosp_ht) float64 240.0 720.0 1.2e+03 ... 1.848e+04 1.896e+04
  * cosp_sr    (cosp_sr) float64 0.605 2.1 4.0 6.0 ... 55.0 70.0 539.5 1.004e+03
  * cosp_sza   (cosp_sza) float64 0.0 15.0 30.0 45.0 60.0
  * time       (time) object 1920-01-01 00:00:00 ... 1929-12-31 18:00:00
Data variables:
    area       (lat, lon) float64 dask.array<chunksize=(768, 1152), meta=np.ndarray>
    IVT        (time, lat, lon) float32 dask.array<chunksize=(50, 768, 1152), meta=np.ndarray>
Attributes:
    regrid_method:  bilinear

Visualize

Here we visualize a single timestep, but note that reductions such as regridded.IVT.mean("time") will parallelize nicely with Dask.

figure = regridded.IVT.isel(time=100).hvplot(
    cmap="twilight", clim=(0, 1000), frame_width=500, geo=True, coastline=True
)
figure

Wrap it up

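def regrid_cam_se(dataset, weight_file):
    """
    Regrid CAM-SE output using an existing ESMF weights file.

    Parameters
    ----------
    dataset: xarray.Dataset
        Input dataset to be regridded. Must have the `ncol` dimension.
    weight_file: str or Path
        Path to existing ESMF weights file

    Returns
    -------
    regridded
        xarray.Dataset after regridding.
    """
    import numpy as np
    import xarray as xr
    import xesmf

    assert isinstance(dataset, xr.Dataset)
    weights = xr.open_dataset(weight_file)

    # input variable shape
    in_shape = weights.src_grid_dims.load().data

    # Since xESMF expects 2D vars, we'll insert a dummy dimension of size-1
    if len(in_shape) == 1:
        in_shape = [1, in_shape.item()]

    # output variable shape: ESMF stores grid dims as (lon, lat), so reverse
    out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]

    print(f"Regridding from {in_shape} to {out_shape}")

    # Dummy input grid (lat of size 1, lon of size ncol) and a rectilinear
    # output grid built from the destination cell centers
    dummy_in = xr.Dataset(
        {
            "lat": ("lat", np.empty((in_shape[0],))),
            "lon": ("lon", np.empty((in_shape[1],))),
        }
    )
    dummy_out = xr.Dataset(
        {
            "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
            "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
        }
    )

    # Build the regridder from the dummy grids; method and periodic
    # should match how the weights file was generated
    regridder = xesmf.Regridder(
        dummy_in,
        dummy_out,
        weights=weight_file,
        method="bilinear",
        reuse_weights=True,
        periodic=True,
    )

    # Insert dummy dimension
    vars_with_ncol = [name for name in dataset.variables if "ncol" in dataset[name].dims]
    updated = dataset.copy().update(
        dataset[vars_with_ncol].transpose(..., "ncol").expand_dims("dummy", axis=-2)
    )

    # Actually regrid, after renaming
    regridded = regridder(updated.rename({"dummy": "lat", "ncol": "lon"}))

    # Merge back any variables that didn't have the ncol dimension
    # and so were not regridded
    return xr.merge(
        [dataset.drop_vars(regridded.variables, errors="ignore"), regridded]
    )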


regrid_cam_se(data_in, weight_file)
Regridding from [1, 777602] to [768, 1152]
<xarray.Dataset>
Dimensions:        (lev: 30, ilev: 31, cosp_prs: 7, nbnd: 2, cosp_tau: 7,
                    cosp_ht: 40, cosp_sr: 15, time: 14600, lat: 768, lon: 1152,
                    cosp_scol: 50, cosp_sza: 5)
Coordinates:
  * lev            (lev) float64 3.643 7.595 14.36 24.61 ... 957.5 976.3 992.6
  * ilev           (ilev) float64 2.255 5.032 10.16 18.56 ... 967.5 985.1 1e+03
  * cosp_prs       (cosp_prs) float64 900.0 740.0 620.0 500.0 375.0 245.0 90.0
  * cosp_tau       (cosp_tau) float64 0.15 0.8 2.45 6.5 16.2 41.5 219.5
  * cosp_ht        (cosp_ht) float64 240.0 720.0 1.2e+03 ... 1.848e+04 1.896e+04
  * cosp_sr        (cosp_sr) float64 0.605 2.1 4.0 6.0 ... 70.0 539.5 1.004e+03
  * time           (time) object 1920-01-01 00:00:00 ... 1929-12-31 18:00:00
  * lat            (lat) float64 -90.0 -89.77 -89.53 -89.3 ... 89.53 89.77 90.0
  * lon            (lon) float64 0.0 0.3125 0.625 0.9375 ... 359.1 359.4 359.7
  * cosp_scol      (cosp_scol) int32 1 2 3 4 5 6 7 8 ... 43 44 45 46 47 48 49 50
  * cosp_sza       (cosp_sza) float64 0.0 15.0 30.0 45.0 60.0
Dimensions without coordinates: nbnd
Data variables: (12/33)
    hyam           (lev) float64 dask.array<chunksize=(30,), meta=np.ndarray>
    hybm           (lev) float64 dask.array<chunksize=(30,), meta=np.ndarray>
    P0             float64 ...
    hyai           (ilev) float64 dask.array<chunksize=(31,), meta=np.ndarray>
    hybi           (ilev) float64 dask.array<chunksize=(31,), meta=np.ndarray>
    cosp_prs_bnds  (cosp_prs, nbnd) float64 dask.array<chunksize=(7, 2), meta=np.ndarray>
    ...             ...
    f11vmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    f12vmr         (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    sol_tsi        (time) float64 dask.array<chunksize=(50,), meta=np.ndarray>
    nsteph         (time) int32 dask.array<chunksize=(50,), meta=np.ndarray>
    area           (lat, lon) float64 dask.array<chunksize=(768, 1152), meta=np.ndarray>
    IVT            (time, lat, lon) float32 dask.array<chunksize=(50, 768, 1152), meta=np.ndarray>
Attributes:
    np:               4
    ne:               120
    Conventions:      CF-1.0
    source:           CAM
    case:             b.e13.BHISTC5.ne120_t12.cesm-ihesp-hires1.0.30-1920-210...
    title:            UNSET
    logname:          nanr
    host:             login1.frontera.
    Version:          $Name$
    revision_Id:      $Id$
    initial_file:     B.E.13.BHISTC5.ne120_t12.sehires38.003.sunway.cam.i.192...
    topography_file:  /work/02503/edwardsj/CESM/inputdata//atm/cam/topo/USGS-...

Possible improvements

It should be possible for xESMF to do all of this internally and let the user create a Regridder with a proposed constructor along these lines:

xesmf.Regridder.from_weights_file(
    # path to weights file
    path=...,
    # input dimension names that will be removed while regridding
    # These are "core dimensions"
    dims_in=("ncol",),
    # Output dataset with lat, lon information
    ds_out=...,
)

The repr would also need to be updated to show the correct weights path and the input/output core dimensions and variables.