Skip to article frontmatterSkip to article content

Apply MDM and QDM to CMIP6 daily temperatures

  • Daily temperature data is loaded using the OSDF protocol and discovered using an intake-ESM catalog
  • We apply three different bias-correction techniques to pairs of CMIP6 models, treating one of the models in the pair as pseudo-observations
  • The bias-correction techniques are Moment Delta Mapping (MDM), Quantile Delta Mapping (QDM) + sort, and a shift in DMT
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask_jobqueue
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import cartopy.io.shapereader as shpreader
import cartopy.feature as cfeature
import intake
import fsspec
import xarray_regrid 
#import seaborn as sns
import s3fs
import cftime
import pandas as pd
/glade/derecho/scratch/harshah/tmp/ipykernel_66585/417568950.py:6: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
  from tqdm.autonotebook import tqdm
import dask 
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from dask.distributed import performance_report
# When True, recompute every intermediate product instead of reusing what is on disk.
RECALC = True

# Central years of the pre-industrial (1850-1879) and end-of-century (2071-2100)
# windows selected in extract_data; used as detrending pivots.
pi_year, eoc_year = 1865, 2085

# Point locations in degrees; longitudes use the 0-360 convention.
chic_lat, chic_lon = 41.8781, (360 - 87.6298) % 360   # Chicago
ben_lat, ben_lon = 12.9716, 77.5946                   # Bengaluru

# Storage locations.
lustre_scratch = "/lustre/desc1/scratch/harshah"
gdex_url = 'https://data.gdex.ucar.edu/'
catalog_url = f"{gdex_url}d850001/catalogs/osdf/cmip6-aws/cmip6-osdf-zarr.json"
# Alternative catalog: 'https://cmip6-pds.s3.amazonaws.com/pangeo-cmip6.json'
gdex_data = '/gdex/data/special_projects/harshah/osdf_data/'

# One output directory per temperature variable.
tmean_path, tmax_path, tmin_path = (f"{gdex_data}{sub}/" for sub in ("tmean", "tmax", "tmin"))
# Dask-on-PBS cluster: each PBS job hosts a single one-core, one-process
# worker with 8 GiB of memory on the 'casper' queue, for up to 5 hours.
cluster = PBSCluster(
    # Job identity and scheduling
    job_name='dask-wk25-mdm',
    queue='casper',
    walltime='5:00:00',
    # Per-job resources: Dask's view and the matching PBS request
    cores=1,
    processes=1,
    memory='8GiB',
    resource_spec='select=1:ncpus=1:mem=8GB',
    # Spill files and logs live on Lustre scratch
    local_directory=f"{lustre_scratch}/dask/spill",
    log_directory=f"{lustre_scratch}/dask/logs/",
    # Network interface used by the workers
    interface='ext',
)
/glade/u/home/harshah/venvs/osdf/lib/python3.10/site-packages/distributed/node.py:187: UserWarning: Port 8787 is already in use.
Perhaps you already have a cluster running?
Hosting the HTTP server on port 44233 instead
  warnings.warn(
# Number of single-worker PBS jobs to request.
n_workers = 8
# Attach a client to the cluster (this is also what serves the dashboard).
client = Client(cluster)
cluster.scale(n_workers)
# Block until every requested worker has joined before continuing.
client.wait_for_workers(n_workers=n_workers)
cluster
Loading...
# calculate global means
def get_lat_name(ds):
    """Return the name of the latitude coordinate of *ds*.

    ``'lat'`` is preferred over ``'latitude'`` when both are present.

    Raises
    ------
    RuntimeError
        If neither name appears in ``ds.coords``.
    """
    found = next((name for name in ('lat', 'latitude') if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def global_mean(ds):
    """Return the cos(latitude)-weighted mean of *ds* over every dim except ``quantile``.

    The weights are ``cos(lat)`` normalised by their own mean. The weighted
    field is then averaged with a plain ``.mean`` over all remaining
    dimensions, so a ``quantile`` dimension, if present, is preserved.
    """
    coslat = np.cos(np.deg2rad(ds[get_lat_name(ds)]))
    coslat = coslat / coslat.mean()
    reduce_dims = [dim for dim in ds.dims if dim != 'quantile']
    return (ds * coslat).mean(reduce_dims)

def detrend_data(ds, central_year):
    """Remove the linear trend along ``year`` from *ds*, pivoting at *central_year*.

    A degree-1 polynomial is fitted along ``year`` at every other point, and
    only its slope is removed: ``ds - slope * (year - central_year)``. Values
    at ``year == central_year`` are unchanged and the intercept is not
    subtracted.

    *ds* is expected to be a DataArray with a ``year`` dimension. For a
    Dataset, ``polyfit`` names its outputs ``<var>_polyfit_coefficients``,
    so the lookup below would fail.
    """
    fit = ds.polyfit(dim='year', deg=1)
    slope = fit['polyfit_coefficients'].sel(degree=1)
    return ds - slope * (ds['year'] - central_year)

Section 2: Load Data

# Open the intake-ESM catalog describing the CMIP6 Zarr stores and display it.
col = intake.open_esm_datastore(catalog_url)
col
Loading...
# Variable to analyse and the directory its outputs are written to.
# folder_path should match var_name (presumably 'tasmax' -> tmax_path,
# 'tasmin' -> tmin_path).
var_name    = 'tas'
folder_path = tmean_path
# Derived from var_name so the catalog search cannot drift from the variable
# that is later selected with ds[var_name].
variable    = [var_name]

# Search for daily data from both experiments for a single ensemble member.
expts = ['ssp370', 'historical']

query = dict(
    experiment_id=expts,
    table_id='day',
    variable_id=variable,
    member_id='r1i1p1f1',
    # activity_id='CMIP',
)

# require_all_on=["source_id"]: keep only models that have entries for every
# requested query value (e.g. both experiments).
col_subset = col.search(require_all_on=["source_id"], **query)

col_subset.df.groupby("source_id")[["experiment_id", "variable_id", "table_id", "member_id"]].nunique()
Loading...
# Catalog rows matching the search; show the first few.
df = col_subset.df
df.head(5)
Loading...
df.activity_id.unique()  # which CMIP6 activities contributed the matching rows
array(['CMIP', 'ScenarioMIP', 'AerChemMIP'], dtype=object)
# Restrict to CMIP (historical) and ScenarioMIP (ssp370) rows, dropping other
# activities (e.g. AerChemMIP) that also appear in the search results.
df_filtered = col_subset.df.loc[lambda rows: rows['activity_id'].isin(['CMIP', 'ScenarioMIP'])]

print("Filtered DataFrame shape:", df_filtered.shape)
Filtered DataFrame shape: (53, 11)
# Sanity check: distinct experiments/variables/tables/activities per model after filtering.
(
    df_filtered
    .groupby("source_id")[["experiment_id", "variable_id", "table_id", "activity_id"]]
    .nunique()
)
Loading...
%%time
# Open every matching store through intake-esm in one call and list the keys.
# NOTE(review): `dsets` is rebound to a fresh defaultdict before it is used
# further down, so this result is informational only.
# dsets = col_subset.to_dataset_dict(storage_options={'anon': 'True'})
dsets = col_subset.to_dataset_dict()
print(f"\nDataset dictionary keys:\n {dsets.keys()}")
Loading...
def drop_all_bounds(ds):
    """Return *ds* without the coordinates whose names contain ``_bounds`` or ``_bnds``.

    Only ``ds.coords`` is inspected, so bounds stored as data variables are kept.
    """
    markers = ('_bounds', '_bnds')
    bound_names = [name for name in ds.coords if any(m in name for m in markers)]
    return ds.drop_vars(bound_names)

def open_dset(df):
    """Open the single Zarr store referenced by a one-row catalog slice.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalog rows for one (source_id, experiment_id) group. It must hold
        exactly one row with a ``zstore`` column.

    Returns
    -------
    xarray.Dataset
        The lazily opened store, with ``*_bounds``/``*_bnds`` coordinates removed.

    Raises
    ------
    ValueError
        If *df* does not contain exactly one row.
    """
    # Raise instead of assert: assert statements are stripped under `python -O`.
    if len(df) != 1:
        raise ValueError(f"Expected exactly one catalog row, got {len(df)}")
    # Opened through fsspec (the original note says this path is used with
    # PelicanFS). For direct S3 access, pass anon=True to fsspec.get_mapper instead.
    mapper = fsspec.get_mapper(df.zstore.values[0])
    ds = xr.open_zarr(mapper, consolidated=True)
    return drop_all_bounds(ds)

def open_delayed(df):
    """Return a ``dask.delayed`` call of :func:`open_dset` on *df* (nothing is opened yet)."""
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict

# Build {source_id: {experiment_id: Delayed[xr.Dataset]}} from the
# activity-filtered catalog (CMIP + ScenarioMIP only). Using the unfiltered
# col_subset.df would also pick up other activities' rows (e.g. AerChemMIP),
# which could put two rows in one group and trip open_dset's one-row check.
dsets = defaultdict(dict)
for (source_id, experiment_id), group_df in df_filtered.groupby(by=['source_id', 'experiment_id']):
    dsets[source_id][experiment_id] = open_delayed(group_df)

# The activity filter may remove one of the experiments for a model; keep only
# models that still provide every experiment in `expts`.
dsets = defaultdict(dict, {src: by_expt for src, by_expt in dsets.items()
                           if set(expts) <= set(by_expt)})
%%time
# Trigger computation: evaluate all delayed open_dset calls in a single pass.
# dask.compute returns a tuple, hence [0]; dsets_ maps
# source_id -> {experiment_id: xarray.Dataset}.
dsets_ = dask.compute(dict(dsets))[0]
# Target grid for regridding: a regular 1.5 x 1.5 degree global grid.
# Latitudes run from -90 to 90 inclusive. Longitudes stop at 358.5 so that 360
# (the same meridian as 0) is not duplicated as an extra column.
ds_out = xr.Dataset({'lat': (['lat'], np.arange(-90, 91, 1.5)),
                     'lon': (['lon'], np.arange(0, 360, 1.5))})
ds_out
def drop_feb29(ds):
    """Convert *ds* to a 365-day calendar unless its time encoding says ``'360_day'``.

    ``convert_calendar('365_day')`` removes 29 February from calendars that
    have it. 360-day datasets are returned unchanged. The model's
    ``source_id`` and its calendar are printed as a diagnostic.
    """
    cal = ds.time.encoding.get('calendar', None)
    print(ds.attrs['source_id'], cal)
    return ds if cal == '360_day' else ds.convert_calendar('365_day')


def to_daily(ds):
    """Reshape the ``time`` axis of *ds* into ``year`` and ``dayofyear`` dimensions.

    Time values that are neither ``numpy.datetime64`` nor ``cftime.datetime``
    are first cast to ``datetime64[ns]``. ``year`` and ``dayofyear``
    coordinates are then attached along ``time``, combined into a MultiIndex
    and unstacked, so ``time`` is replaced by the two dimensions ``year`` and
    ``dayofyear``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` coordinate. It is presumably already on a
        fixed-length calendar (see drop_feb29) so every year has the same days.

    Returns
    -------
    xarray.Dataset
        A new dataset; *ds* itself is not modified.
    """
    first = ds['time'].values[0]
    if not isinstance(first, (np.datetime64, cftime.datetime)):
        # assign_coords returns a new object. The previous `ds['time'] = ...`
        # item assignment modified the caller's dataset in place.
        ds = ds.assign_coords(time=ds['time'].values.astype('datetime64[ns]'))

    year = ds.time.dt.year
    dayofyear = ds.time.dt.dayofyear
    ds = ds.assign_coords(year=("time", year.data), dayofyear=("time", dayofyear.data))

    # Turn 'time' into a (year, dayofyear) MultiIndex, then split it into two dims.
    return ds.set_index(time=("year", "dayofyear")).unstack("time")


def extract_data(ds, pi_years=(1850, 1879), eoc_years=(2071, 2100)):
    """Keep only the pre-industrial and end-of-century year windows of *ds*.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``year`` dimension (e.g. the output of to_daily).
    pi_years : tuple of int, optional
        Inclusive (first, last) years of the pre-industrial window.
    eoc_years : tuple of int, optional
        Inclusive (first, last) years of the end-of-century window.

    Returns
    -------
    xarray.Dataset
        The two windows concatenated along ``year``. A window outside the
        dataset's span contributes no years, so the result may be empty.
    """
    windows = [ds.sel(year=slice(first, last)) for first, last in (pi_years, eoc_years)]
    return xr.concat(windows, dim='year')

def is_leap(year):
    """Return whether *year* is a leap year in the proleptic Gregorian calendar."""
    if year % 4 != 0:
        return False
    if year % 100 != 0:
        return True
    return year % 400 == 0
quants = np.linspace(start=0, stop=1.0, num=30)  # 30 evenly spaced quantile levels, 0 and 1 included

def compute_quantiles(ds, quantiles=quants):
    """Return the requested *quantiles* of *ds* across ``year``.

    ``year`` is first merged into a single dask chunk, so the quantile sees
    the whole axis. ``skipna=False`` means any NaN along ``year`` yields NaN.
    """
    single_year_chunk = ds.chunk({'year': -1})
    return single_year_chunk.quantile(quantiles, dim='year', skipna=False)

def regrid(ds, ds_out):
    """Nearest-neighbour regrid *ds* onto the grid of *ds_out*.

    ``experiment_id`` and ``source_id`` are read from ``ds.attrs`` beforehand
    and written back onto the result. The original note says the regridder
    drops attributes. Other attributes are not restored.
    """
    keep = {key: ds.attrs[key] for key in ('experiment_id', 'source_id')}
    regridded = ds.regrid.nearest(ds_out)
    regridded.attrs.update(keep)
    return regridded

def process_data(ds, quantiles=quants):
    """Prepare one model/experiment dataset for comparison.

    The steps are: calendar fix (drop_feb29), reshape to (year, dayofyear)
    (to_daily), selection of the two year windows (extract_data), and
    regridding onto the module-level ``ds_out`` grid.

    Returns ``None``, after printing a message, when no years remain or when
    there are fewer than 365 days per year (e.g. 360-day calendars).
    ``quantiles`` is accepted for interface compatibility but is not used.
    """
    cleaned = extract_data(to_daily(drop_feb29(ds)))

    if len(cleaned['year']) == 0:
        print("The dataset is empty. Skipping...")
        return None
    if len(cleaned['dayofyear']) < 365:
        print('The dataset has less than 365 days. Skipping ..')
        return None

    return regrid(cleaned, ds_out=ds_out)

Section 3: Computations.

  • Evaluate these
%%time
if RECALC:
    with progress.ProgressBar():
    
        expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                               coords={'experiment_id': expts})
    
        # Initialize an Empty Dictionary for Aligned Datasets:
        dsets_aligned = {}
    
        # Iterate Over dsets_ Dictionary:
    
        for k, v in tqdm(dsets_.items()):
            # Initialize a dictionary for this source_id
            dsets_aligned[k] = {}
            
            skip_source_id = False
    
            for expt in expts:
                ds = v[expt].pipe(process_data)
    
                # Check if the dataset is empty and skip this source_id if so
                if ds is None:
                    print(f"Skipping {expt} for {k} because the dataset is empty")
                    skip_source_id = True
                    break
                
                # Store the dataset in the dictionary
                # dsets_aligned[k][expt] = ds
                # Compute the dataset and store it in the dictionary
                dsets_aligned[k][expt] = ds.compute()
                print(dsets_aligned[k][expt])
    
            if skip_source_id:
                del dsets_aligned[k]
                continue
# dsets_aligned.keys()
%%time
if RECALC:
    source_ids = list(dsets_aligned.keys())
    source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})
    final_ds_pi = xr.concat([ds['historical'].reset_coords(drop=True)
                                     for ds in dsets_aligned.values()],
                                    dim=source_da)
    
    final_ds_eoc = xr.concat([ds['ssp370'].reset_coords(drop=True)
                                 for ds in dsets_aligned.values()],
                                dim=source_da)
    final_ds_eoc
%%time
final_ds_pi.to_zarr(folder_path  +'cmip6_pi_daily.zarr',mode='w')
final_ds_eoc.to_zarr(folder_path +'cmip6_eoc_daily.zarr',mode='w')
final_ds_pi  = xr.open_zarr(folder_path+'cmip6_pi_daily.zarr')
final_ds_eoc = xr.open_zarr(folder_path+'cmip6_eoc_daily.zarr')
final_ds_pi  = final_ds_pi[var_name]
final_ds_eoc = final_ds_eoc[var_name]
final_ds_eoc

Detrend data and save

%%time
ds_pi_det  = detrend_data(final_ds_pi,pi_year)
ds_eoc_det = detrend_data(final_ds_eoc,eoc_year)
ds_eoc_det = ds_eoc_det.chunk({'year':30,'source_id':1})
ds_pi_det = ds_pi_det.chunk({'year':30,'source_id':1})
ds_eoc_det
# %%time
# ds_pi_det.rename(var_name).to_dataset().to_zarr(folder_path  +'cmip6_pi_ann_detrended.zarr',mode='w')
# %%time
# ds_eoc_det.rename(var_name).to_dataset().to_zarr(folder_path +'cmip6_eoc_ann_detrended.zarr',mode='w')

Check if detrending worked

# Read the detrended fields back from disk, keeping only the analysed variable.
ds_pi_det = xr.open_zarr(folder_path + 'cmip6_pi_ann_detrended.zarr')[var_name]
ds_eoc_det = xr.open_zarr(folder_path + 'cmip6_eoc_ann_detrended.zarr')[var_name]
%%time
# Detrending check at the grid point nearest Chicago, dayofyear=2, multi-model
# mean: detrended (first line) vs. raw (second line) pre-industrial values.
# The remaining dimension is presumably `year`.
ds_pi_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
final_ds_pi.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
%%time
# Same check for the end-of-century window: detrended (first line) vs. raw
# (second line) multi-model mean at the Chicago grid point, dayofyear=2.
ds_eoc_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
final_ds_eoc.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()