Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Apply MDM and QDM to CMIP6 daily temperatures

from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask_jobqueue
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import cartopy.io.shapereader as shpreader
import cartopy.feature as cfeature
import intake
import fsspec
import xarray_regrid 
#import seaborn as sns
import s3fs
import cftime
import pandas as pd
/glade/derecho/scratch/harshah/tmp/ipykernel_66585/417568950.py:6: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
  from tqdm.autonotebook import tqdm
import dask 
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from dask.distributed import performance_report
# --- Run configuration -------------------------------------------------------
# True: recompute the extraction/regridding cells below (they are gated on
# RECALC); False: rely on the zarr stores written by an earlier run.
RECALC = True

# Pivot years passed to detrend_data for the pre-industrial and end-of-century
# windows.
pi_year, eoc_year = 1865, 2085

# Point locations for sanity-check plots (longitudes on a 0-360 grid).
chic_lat, chic_lon = 41.8781, (360 - 87.6298) % 360   # Chicago
ben_lat, ben_lon = 12.9716, 77.5946                   # Bengaluru

# --- Storage locations -------------------------------------------------------
lustre_scratch = "/lustre/desc1/scratch/harshah"
gdex_url = 'https://data.gdex.ucar.edu/'
catalog_url = f"{gdex_url}d850001/catalogs/osdf/cmip6-aws/cmip6-osdf-zarr.json"
# Alternative catalog:
# catalog_url = 'https://cmip6-pds.s3.amazonaws.com/pangeo-cmip6.json'
gdex_data = '/gdex/data/special_projects/harshah/osdf_data/'

# One output folder per temperature variable.
tmean_path, tmax_path, tmin_path = (f"{gdex_data}{sub}/" for sub in ('tmean', 'tmax', 'tmin'))
# Create a PBS cluster object.
# Every PBS job runs one single-core dask worker process with 8 GiB of memory
# on the casper queue; worker spill files and job logs go to lustre scratch.
_worker_job_options = dict(
    job_name='dask-wk25-mdm',
    queue='casper',
    walltime='5:00:00',
    cores=1,
    processes=1,
    memory='8GiB',
    resource_spec='select=1:ncpus=1:mem=8GB',
    interface='ext',
    local_directory=f"{lustre_scratch}/dask/spill",
    log_directory=f"{lustre_scratch}/dask/logs/",
)
cluster = PBSCluster(**_worker_job_options)
/glade/u/home/harshah/venvs/osdf/lib/python3.10/site-packages/distributed/node.py:187: UserWarning: Port 8787 is already in use.
Perhaps you already have a cluster running?
Hosting the HTTP server on port 44233 instead
  warnings.warn(
# Create the client to load the Dashboard
client = Client(cluster)
# Ask the PBS cluster for n_workers workers and block until all of them have
# joined, so later computations do not start on a partially-sized cluster.
n_workers =8
cluster.scale(n_workers)
client.wait_for_workers(n_workers = n_workers)
# Bare expression: in a notebook this displays the cluster widget.
cluster
Loading...
# calculate global means
def get_lat_name(ds):
    """Return the name of the latitude coordinate of ``ds``.

    ``'lat'`` is preferred over ``'latitude'`` when both are present.

    Raises:
        RuntimeError: if neither name is among ``ds.coords``.
    """
    candidates = ('lat', 'latitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def global_mean(ds, keep_dims=('quantile',)):
    """Latitude-weighted mean of ``ds`` over every dimension not in ``keep_dims``.

    The weights are cos(latitude) normalised to mean 1 along the latitude
    coordinate. A plain ``mean`` of ``ds * weight`` is therefore an
    area-weighted average, assuming a regular lat/lon grid without missing
    values.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Field with a ``lat`` or ``latitude`` coordinate in degrees.
    keep_dims : str or iterable of str, default ``('quantile',)``
        Dimensions that are kept (not averaged over).

    Returns
    -------
    Same type as ``ds``, reduced over all remaining dimensions.

    Raises
    ------
    RuntimeError
        If no latitude coordinate is found (from ``get_lat_name``).
    """
    if isinstance(keep_dims, str):
        keep_dims = (keep_dims,)
    keep = set(keep_dims)

    lat = ds[get_lat_name(ds)]
    weight = np.cos(np.deg2rad(lat))
    weight = weight / weight.mean()

    # Preserve the dataset's own dimension order instead of set ordering.
    reduce_dims = [dim for dim in ds.dims if dim not in keep]
    return (ds * weight).mean(reduce_dims)

def detrend_data(ds, central_year, dim='year'):
    """Remove the linear trend along ``dim``, pivoting on ``central_year``.

    A degree-1 least-squares fit along ``dim`` gives a slope. The result is
    ``ds - slope * (ds[dim] - central_year)``. The subtracted trend is zero at
    ``central_year``, so values there are unchanged and the series is
    flattened around that year. Callers pass roughly the centre of each
    30-year window.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data with a numeric ``dim`` coordinate. For example, the output of
        ``to_daily`` has ``year`` and ``dayofyear`` dimensions; only ``dim``
        is used here.
    central_year : number
        Coordinate value at which the removed trend is zero.
    dim : str, default ``'year'``
        Dimension along which the trend is fitted.

    Returns
    -------
    Same type as ``ds``. For a Dataset, each data variable that has ``dim`` is
    detrended independently; variables without ``dim`` are returned unchanged.
    """
    if isinstance(ds, xr.Dataset):
        # Dataset.polyfit names its outputs '<var>_polyfit_coefficients', so
        # detrend variable by variable through the DataArray path instead.
        return ds.map(
            lambda da: detrend_data(da, central_year, dim=dim) if dim in da.dims else da,
            keep_attrs=True,
        )

    # Fit a linear function along `dim` and extract the slope
    coeffs = ds.polyfit(dim=dim, deg=1)
    slope = coeffs['polyfit_coefficients'].sel(degree=1)

    # Trend relative to the pivot year, subtracted from the data
    trend = slope * (ds[dim] - central_year)
    return ds - trend

Section 2: Load Data

# Open the intake-esm catalog of CMIP6 zarr stores at catalog_url.
col = intake.open_esm_datastore(catalog_url)
col
Loading...
# Variable to process and the folder its outputs are written to.
var_name    = 'tas'
folder_path = tmean_path
variable    = ['tas'] #Other variables of interest: 'tasmax', 'tasmin'
# 2. Search for daily temperature in both experiments
expts = ['ssp370','historical']

# Daily table, first ensemble member only.
query = dict(
    experiment_id=expts,
    table_id='day',
    variable_id= variable,
    member_id = 'r1i1p1f1',
    #activity_id = 'CMIP',
    
)

# require_all_on=["source_id"]: keep only models for which every requested
# value (e.g. both experiments) is available.
col_subset = col.search(require_all_on=["source_id"], **query)

# Sanity check: number of distinct experiments/variables/tables/members per model.
col_subset.df.groupby("source_id")[["experiment_id", "variable_id", "table_id","member_id"]].nunique()
Loading...
# Catalog rows matching the query (one row per zarr store).
df = col_subset.df
# model_counts = df.groupby('source_id').size()
# print(model_counts)
df.head()
Loading...
# Which MIP activities the matching stores belong to.
df['activity_id'].unique()
array(['CMIP', 'ScenarioMIP', 'AerChemMIP'], dtype=object)
# Keep only rows with CMIP (historical) or ScenarioMIP (ssp370) for consistency,
# i.e. drop stores registered under other activities (such as AerChemMIP).
df_filtered = col_subset.df[col_subset.df['activity_id'].isin(['CMIP', 'ScenarioMIP'])]

print("Filtered DataFrame shape:", df_filtered.shape)
# print("Filtered activity_id values:", df_filtered['activity_id'])
Filtered DataFrame shape: (53, 11)
# Per-model count of distinct values after the activity filter.
df_filtered.groupby("source_id")[["experiment_id", "variable_id", "table_id","activity_id"]].nunique()
Loading...
%%time
# Open every store in col_subset via intake-esm (keys are printed below).
# NOTE(review): `dsets` is rebound further down to a per-model dict of delayed
# opens, so this result is only used for inspecting the keys here.
# dsets = col_subset.to_dataset_dict(storage_options={'anon': 'True'})
dsets = col_subset.to_dataset_dict()
print(f"\nDataset dictionary keys:\n {dsets.keys()}")

--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'
Loading...
Loading...
def drop_all_bounds(ds):
    """Return ``ds`` without cell-boundary variables.

    Any variable whose name contains ``_bounds`` or ``_bnds`` is dropped. Both
    coordinates and data variables are checked, since a store may hold bounds
    such as ``time_bnds`` as either. Only the main variable is used
    downstream, so the bounds would just be carried through regridding and
    concatenation.
    """
    drop_vars = [name for name in ds.variables
                 if '_bounds' in str(name) or '_bnds' in str(name)]
    return ds.drop_vars(drop_vars)

def open_dset(df, storage_options=None):
    """Open the single zarr store described by a one-row catalog DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Exactly one catalog row; its ``zstore`` column holds the store URL.
    storage_options : dict, optional
        Keyword arguments forwarded to ``fsspec.get_mapper`` (for example
        ``{'anon': True}`` for public S3 URLs). By default none are passed,
        which is how the PelicanFS/OSDF URLs in this catalog were opened.

    Returns
    -------
    xarray.Dataset with bounds variables removed by ``drop_all_bounds``.

    Raises
    ------
    ValueError
        If ``df`` does not contain exactly one row.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(df) != 1:
        raise ValueError(f"Expected exactly one catalog row, got {len(df)}")
    mapper = fsspec.get_mapper(df.zstore.values[0], **(storage_options or {}))
    ds = xr.open_zarr(mapper, consolidated=True)
    return drop_all_bounds(ds)

def open_delayed(df):
    """Return a ``dask.delayed`` task that calls ``open_dset(df)`` when computed.

    Nothing is opened here. The store is read only when the task is computed
    (e.g. by ``dask.compute``), which lets many stores be opened in parallel.
    """
    delayed_open = dask.delayed(open_dset)
    return delayed_open(df)

from collections import defaultdict

# Build {source_id: {experiment_id: delayed Dataset}} from the activity-filtered
# catalog (CMIP + ScenarioMIP only), as intended by the filtering step above.
# The unfiltered col_subset.df also contains other activities (e.g. AerChemMIP),
# which can give a (source_id, experiment_id) group more than one row and trip
# the single-row check in open_dset.
dsets = defaultdict(dict)
for (source_id, experiment_id), df in df_filtered.groupby(by=['source_id', 'experiment_id']):
    dsets[source_id][experiment_id] = open_delayed(df)

# Filtering by activity can remove one experiment for a model; keep only models
# that still have every experiment in `expts`, so the later v[expt] lookups work.
dsets = defaultdict(dict, {src: by_expt for src, by_expt in dsets.items()
                           if all(expt in by_expt for expt in expts)})
%%time
# Trigger computation: open all stores in parallel on the dask cluster.
# dsets_ maps source_id -> experiment_id -> (lazy) xarray.Dataset.
dsets_ = dask.compute(dict(dsets))[0]
# Target grid for regridding: regular 1.5 x 1.5 degree latitude/longitude.
# Latitude includes both poles (-90 ... 90). Longitude stops before 360 so the
# 0/360 meridian is not duplicated; np.arange(0, 361, 1.5) would include both.
ds_out = xr.Dataset({'lat': (['lat'], np.arange(-90, 91, 1.5)),
                     'lon': (['lon'], np.arange(0, 360, 1.5))})
ds_out
def drop_feb29(ds):
    """Put ``ds`` on a 365-day (no-leap) calendar unless it uses a 360-day one.

    ``convert_calendar('365_day')`` drops 29 February from calendars that
    have it, so every remaining year has the same 365 days. 360-day calendars
    are left untouched, because converting them would require ``align_on``.
    Those models are later rejected by ``process_data`` for having fewer than
    365 days.

    The calendar is read from the time encoding. When the encoding does not
    record one, it is taken from the time values themselves, so a 360-day
    model is not accidentally converted (which would raise). The model name
    and detected calendar are printed for tracking.
    """
    calendar = ds.time.encoding.get('calendar') or ds.time.dt.calendar
    print(ds.attrs.get('source_id'), calendar)
    if calendar != '360_day':
        ds = ds.convert_calendar('365_day')
    return ds


def to_daily(ds):
    """Split the ``time`` axis of ``ds`` into ``year`` and ``dayofyear`` dimensions.

    Time values that are neither ``numpy.datetime64`` nor ``cftime.datetime``
    are first cast to ``datetime64[ns]``. The input dataset is not modified;
    a new dataset is returned.

    Returns
    -------
    Dataset in which ``time`` is replaced by separate ``year`` and
    ``dayofyear`` dimensions, obtained by setting a (year, dayofyear)
    MultiIndex on ``time`` and unstacking it.
    """
    first_time = ds['time'].values[0]
    if not isinstance(first_time, (np.datetime64, cftime.datetime)):
        # Rebind instead of `ds['time'] = ...` so the caller's dataset is not
        # mutated in place.
        ds = ds.assign_coords(time=ds['time'].astype('datetime64[ns]'))

    year = ds.time.dt.year
    dayofyear = ds.time.dt.dayofyear

    # Attach year/dayofyear along time, then reshape to (..., year, dayofyear)
    ds = ds.assign_coords(year=("time", year.data), dayofyear=("time", dayofyear.data))
    return ds.set_index(time=("year", "dayofyear")).unstack("time")


def extract_data(ds, periods=((1850, 1879), (2071, 2100))):
    """Keep only the years that fall inside the given inclusive year windows.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``year`` dimension (e.g. the output of ``to_daily``).
    periods : sequence of (first_year, last_year) pairs, non-empty
        Inclusive year windows, concatenated in this order along ``year``.
        Defaults to 1850-1879 and 2071-2100.

    Returns
    -------
    xarray.Dataset restricted to the selected years. No spatial subsetting is
    done. A window that the dataset does not cover contributes no years
    (e.g. a run ending before 2071 has nothing in the second window), so the
    result may be empty along ``year``.
    """
    subsets = [ds.sel(year=slice(first, last)) for first, last in periods]
    return xr.concat(subsets, dim='year')

def is_leap(year):
    """Return True if ``year`` is a Gregorian leap year.

    Years divisible by 4 are leap years, except century years, which must
    also be divisible by 400.
    """
    if year % 400 == 0:
        return True
    if year % 100 == 0:
        return False
    return year % 4 == 0
# 30 evenly spaced quantile levels from 0 (minimum) to 1 (maximum), inclusive.
quants = np.linspace(0.0, 1.0, num=30)

def compute_quantiles(ds, quantiles=quants):
    """Quantiles of ``ds`` across the ``year`` dimension.

    ``year`` is first rechunked into one dask chunk so it can be reduced in a
    single pass. With ``skipna=False``, any NaN along ``year`` makes the
    corresponding quantiles NaN.
    """
    single_year_chunk = ds.chunk({'year': -1})
    return single_year_chunk.quantile(quantiles, dim='year', skipna=False)

def regrid(ds, ds_out):
    """Nearest-neighbour regrid ``ds`` onto the grid of ``ds_out``.

    Uses the ``.regrid`` accessor from ``xarray_regrid``. The input's
    ``experiment_id`` and ``source_id`` global attributes are re-attached to
    the result explicitly, since the regridder may not keep them.
    """
    keep = {key: ds.attrs[key] for key in ('experiment_id', 'source_id')}
    regridded = ds.regrid.nearest(ds_out)
    regridded.attrs.update(keep)
    return regridded

def process_data(ds, quantiles=quants):
    """Prepare one model/experiment dataset for the downstream analysis.

    Applies ``drop_feb29``, then ``to_daily`` (time -> year x dayofyear),
    then ``extract_data`` (study windows). Finally regrids onto the
    module-level ``ds_out`` grid.

    ``quantiles`` is accepted for interface compatibility but is not used.

    Returns ``None`` (after printing a message) when no years remain, or when
    fewer than 365 days of year are present (e.g. a 360-day calendar that
    ``drop_feb29`` left unconverted).
    """
    prepared = drop_feb29(ds)
    prepared = to_daily(prepared)
    prepared = extract_data(prepared)

    if len(prepared['year']) == 0:
        print("The dataset is empty. Skipping...")
        return None
    if len(prepared['dayofyear']) < 365:
        print('The dataset has less than 365 days. Skipping ..')
        return None

    return regrid(prepared, ds_out=ds_out)

Section 3: Computations.

  • Evaluate these

%%time
if RECALC:
    # NOTE(review): dask.diagnostics.ProgressBar reports on the local scheduler;
    # with a distributed Client active it may show nothing — confirm.
    with progress.ProgressBar():
    
        # NOTE(review): expt_da is built but not used anywhere in this cell.
        expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                               coords={'experiment_id': expts})
    
        # Initialize an Empty Dictionary for Aligned Datasets:
        # source_id -> experiment_id -> processed, in-memory Dataset
        dsets_aligned = {}
    
        # Iterate Over dsets_ Dictionary:
    
        for k, v in tqdm(dsets_.items()):
            # Initialize a dictionary for this source_id
            dsets_aligned[k] = {}
            
            skip_source_id = False
    
            for expt in expts:
                ds = v[expt].pipe(process_data)
    
                # process_data returns None when no years remain or when fewer
                # than 365 days are present; drop the whole model in that case
                # so every kept model has all experiments.
                if ds is None:
                    print(f"Skipping {expt} for {k} because the dataset is empty")
                    skip_source_id = True
                    break
                
                # Store the dataset in the dictionary
                # dsets_aligned[k][expt] = ds
                # Compute the dataset and store it in the dictionary, so the
                # regridded data is held in memory for the concatenation below.
                dsets_aligned[k][expt] = ds.compute()
                print(dsets_aligned[k][expt])
    
            if skip_source_id:
                del dsets_aligned[k]
                continue
# dsets_aligned.keys()
%%time
if RECALC:
    # Stack all retained models along a new `source_id` dimension.
    source_ids = list(dsets_aligned.keys())
    source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})
    # Pre-industrial ensemble from the historical runs. reset_coords(drop=True)
    # removes non-index coordinates before concatenating.
    final_ds_pi = xr.concat([ds['historical'].reset_coords(drop=True)
                                     for ds in dsets_aligned.values()],
                                    dim=source_da)
    
    # End-of-century ensemble from the ssp370 runs.
    final_ds_eoc = xr.concat([ds['ssp370'].reset_coords(drop=True)
                                 for ds in dsets_aligned.values()],
                                dim=source_da)
    final_ds_eoc
%%time
# Write the stacked model ensembles. They only exist when they were recomputed
# above: with RECALC = False (in a fresh session) these names are undefined,
# and the stores from an earlier run are read back in the next cell instead.
if RECALC:
    final_ds_pi.to_zarr(folder_path  +'cmip6_pi_daily.zarr',mode='w')
    final_ds_eoc.to_zarr(folder_path +'cmip6_eoc_daily.zarr',mode='w')
# Re-open the ensembles lazily from zarr and keep only the analysed variable
# (DataArrays with a source_id dimension from here on).
final_ds_pi  = xr.open_zarr(folder_path+'cmip6_pi_daily.zarr')
final_ds_eoc = xr.open_zarr(folder_path+'cmip6_eoc_daily.zarr')
final_ds_pi  = final_ds_pi[var_name]
final_ds_eoc = final_ds_eoc[var_name]
final_ds_eoc

Detrend data and save

%%time
# Remove the linear trend along `year` in each window, pivoting on the window's
# reference year (pi_year / eoc_year). Results are lazy; nothing is written here.
ds_pi_det  = detrend_data(final_ds_pi,pi_year)
ds_eoc_det = detrend_data(final_ds_eoc,eoc_year)
# Rechunk: 30 years per chunk and one model per chunk.
ds_eoc_det = ds_eoc_det.chunk({'year':30,'source_id':1})
ds_pi_det = ds_pi_det.chunk({'year':30,'source_id':1})
ds_eoc_det
# Persist the detrended fields. The check below reads them back from these
# stores, so they must be rewritten whenever the inputs were recomputed;
# otherwise the check would silently use output from an earlier run.
if RECALC:
    ds_pi_det.rename(var_name).to_dataset().to_zarr(folder_path  +'cmip6_pi_ann_detrended.zarr',mode='w')
    ds_eoc_det.rename(var_name).to_dataset().to_zarr(folder_path +'cmip6_eoc_ann_detrended.zarr',mode='w')

Check if detrending worked

# Read the detrended ensembles back from zarr.
# NOTE(review): the store names say "ann" although the data are daily
# (year x dayofyear) — confirm the naming is intentional.
ds_pi_det  = xr.open_zarr(folder_path  +'cmip6_pi_ann_detrended.zarr')
ds_eoc_det = xr.open_zarr(folder_path +'cmip6_eoc_ann_detrended.zarr')
#
ds_pi_det  = ds_pi_det[var_name]
ds_eoc_det = ds_eoc_det[var_name]
%%time
# Multi-model mean at the grid point nearest Chicago on day-of-year 2, plotted
# against year: detrended first, then the raw pre-industrial series.
ds_pi_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
final_ds_pi.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
%%time
# Same comparison for the end-of-century (ssp370) window.
ds_eoc_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
final_ds_eoc.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()