- Daily temperature data are discovered with an intake-ESM catalog and loaded via the OSDF protocol
- We apply three different bias-correction techniques to pairs of CMIP6 models, treating one model in each pair as pseudo-observations
- The bias-correction techniques are Moment Delta Mapping, Quantile Delta Mapping plus sorting, and a simple shift in DMT (see the short sketch below)
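For reference, here is a minimal sketch of the quantile-delta-mapping idea on synthetic one-dimensional arrays, treating one series as pseudo-observations. The array names, sample sizes, and quantile resolution are illustrative assumptions and not the implementation used later in this notebook.
import numpy as np

rng = np.random.default_rng(0)
obs_hist = rng.normal(288.0, 5.0, 10950)   # pseudo-observations, historical period
mod_hist = rng.normal(290.0, 6.0, 10950)   # model, historical period (biased)
mod_fut  = rng.normal(293.0, 6.5, 10950)   # model, end-of-century period

taus = np.linspace(0.01, 0.99, 99)
obs_q  = np.quantile(obs_hist, taus)
hist_q = np.quantile(mod_hist, taus)
fut_q  = np.quantile(mod_fut, taus)

# Non-exceedance probability of each future value within the future model distribution
tau_of_fut = np.interp(mod_fut, fut_q, taus)
# Additive change signal at those quantiles, added to the matching pseudo-observed quantiles
corrected = np.interp(tau_of_fut, taus, obs_q) + np.interp(tau_of_fut, taus, fut_q - hist_q)
print(corrected.mean(), corrected.std())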
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask_jobqueue
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import cartopy.io.shapereader as shpreader
import cartopy.feature as cfeature
import intake
import fsspec
import xarray_regrid
#import seaborn as sns
import s3fs
import cftime
import pandas as pd
import dask
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from dask.distributed import performance_report
# Decide whether to re-calculate everything
RECALC = True
#
pi_year = 1865   # central year of the pre-industrial window (1850-1879)
eoc_year = 2085  # central year of the end-of-century window (2071-2100)
#
chic_lat = 41.8781              # Chicago latitude
chic_lon = (360-87.6298)%360    # Chicago longitude, converted to 0-360 degrees east
ben_lat = 12.9716               # Bengaluru latitude
ben_lon = 77.5946               # Bengaluru longitude
#
lustre_scratch = "/lustre/desc1/scratch/harshah"
gdex_url = 'https://data.gdex.ucar.edu/'
catalog_url = gdex_url + 'd850001/catalogs/osdf/cmip6-aws/cmip6-osdf-zarr.json'
# catalog_url = 'https://cmip6-pds.s3.amazonaws.com/pangeo-cmip6.json'
gdex_data = '/gdex/data/special_projects/harshah/osdf_data/'
#
tmean_path = gdex_data + 'tmean/'
tmax_path = gdex_data + 'tmax/'
tmin_path = gdex_data + 'tmin/'
# Create a PBS cluster object
cluster = PBSCluster(
job_name = 'dask-wk25-mdm',
cores = 1,
memory = '8GiB',
processes = 1,
local_directory = lustre_scratch+'/dask/spill',
log_directory = lustre_scratch + '/dask/logs/',
resource_spec = 'select=1:ncpus=1:mem=8GB',
queue = 'casper',
walltime = '5:00:00',
interface = 'ext'
)
# Create the client to load the Dashboard
client = Client(cluster)
n_workers = 8
cluster.scale(n_workers)
client.wait_for_workers(n_workers = n_workers)
cluster
# calculate global means
def get_lat_name(ds):
for lat_name in ['lat', 'latitude']:
if lat_name in ds.coords:
return lat_name
raise RuntimeError("Couldn't find a latitude coordinate")
def global_mean(ds):
lat = ds[get_lat_name(ds)]
weight = np.cos(np.deg2rad(lat))
weight /= weight.mean()
other_dims = set(ds.dims) - {'quantile'}
return (ds * weight).mean(other_dims)
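# A small, illustrative check of global_mean: for a field that is constant in
# space, the cosine-latitude weighting should return that constant. The toy
# grid and the 300 K value below are assumptions for illustration only.
toy_const = xr.Dataset(
    {"tas": (("lat", "lon"), np.full((121, 240), 300.0))},
    coords={"lat": np.arange(-90, 91, 1.5), "lon": np.arange(0, 360, 1.5)},
)
print(float(global_mean(toy_const)["tas"]))  # expected: ~300.0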
def detrend_data(ds, central_year):
# Assumes that ds has coordinates day, year and member.
# Fit a linear function and extract the slope
pcoeffs = ds.polyfit(dim='year',deg=1)
slope = pcoeffs.polyfit_coefficients.sel(degree=1)
#Calculate trend
ds_trend = slope*(ds['year']- central_year)
#Detrend by subtracting the trend from the data
ds_detrended = ds - ds_trend
return ds_detrended
Section 2: Load Data
col = intake.open_esm_datastore(catalog_url)
col
var_name = 'tas'
folder_path = tmean_path
variable = ['tas']  # Other variables of interest: 'tasmax', 'tasmin'
# 2. Search for daily temperature
expts = ['ssp370','historical']
query = dict(
experiment_id=expts,
table_id='day',
variable_id= variable,
member_id = 'r1i1p1f1',
#activity_id = 'CMIP',
)
col_subset = col.search(require_all_on=["source_id"], **query)
col_subset.df.groupby("source_id")[["experiment_id", "variable_id", "table_id","member_id"]].nunique()
df = col_subset.df
# model_counts = df.groupby('source_id').size()
# print(model_counts)
df.head()
df['activity_id'].unique()
array(['CMIP', 'ScenarioMIP', 'AerChemMIP'], dtype=object)
# Keep only rows with CMIP (historical) or ScenarioMIP (ssp370) for consistency.
df_filtered = col_subset.df[col_subset.df['activity_id'].isin(['CMIP', 'ScenarioMIP'])]
print("Filtered DataFrame shape:", df_filtered.shape)
# print("Filtered activity_id values:", df_filtered['activity_id'])Filtered DataFrame shape: (53, 11)
df_filtered.groupby("source_id")[["experiment_id", "variable_id", "table_id","activity_id"]].nunique()
%%time
# dsets = col_subset.to_dataset_dict(storage_options={'anon': 'True'})
dsets = col_subset.to_dataset_dict()
print(f"\nDataset dictionary keys:\n {dsets.keys()}")Loading...
def drop_all_bounds(ds):
drop_vars = [vname for vname in ds.coords
if (('_bounds') in vname ) or ('_bnds') in vname]
return ds.drop_vars(drop_vars)
def open_dset(df):
assert len(df) == 1
# Force anonymous access for public datasets
storage_options = {'anon': True}
# ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0], **storage_options),consolidated=True) #For s3fs protocol
ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0]), consolidated=True) #Use for PelicanFS
return drop_all_bounds(ds)
def open_delayed(df):
return dask.delayed(open_dset)(df)
from collections import defaultdict
dsets = defaultdict(dict)
for group, df in col_subset.df.groupby(by=['source_id', 'experiment_id']):
dsets[group[0]][group[1]] = open_delayed(df)
%%time
# Trigger computation
dsets_ = dask.compute(dict(dsets))[0]
# Define a coarse 1.5 x 1.5 degree target grid to regrid onto
ds_out = xr.Dataset({'lat': (['lat'], np.arange(-90, 91, 1.5)),
'lon': (['lon'], np.arange(0, 361, 1.5))})
ds_out
def drop_feb29(ds):
# Convert any non-'360_day' calendar to a '365_day' (no-leap) calendar, which drops Feb 29
calendar = ds.time.encoding.get('calendar', None)
print(ds.attrs['source_id'],calendar)
if calendar != '360_day':
ds = ds.convert_calendar('365_day')
return ds
def to_daily(ds):
# Check and deal with different datetime types
if isinstance(ds['time'].values[0], np.datetime64):
pass
elif isinstance(ds['time'].values[0], cftime.datetime):
pass
else:
# convert time coordinate to datetime64 objects
ds['time'] = ds['time'].astype('datetime64[ns]')
year = ds.time.dt.year
dayofyear = ds.time.dt.dayofyear
# assign new coords
ds = ds.assign_coords(year=("time", year.data), dayofyear=("time", dayofyear.data))
# reshape the array to (..., "day", "year")
return ds.set_index(time=("year", "dayofyear")).unstack("time")
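# Quick, illustrative sanity check of to_daily on a tiny synthetic no-leap
# dataset; the dates, record length, and variable name are assumptions.
toy_times = xr.cftime_range("2000-01-01", periods=730, freq="D", calendar="noleap")
toy_daily = xr.Dataset({"tas": ("time", np.arange(730.0))}, coords={"time": toy_times})
print(to_daily(toy_daily)["tas"].sizes)  # expected sizes: year=2, dayofyear=365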
def extract_data(ds):
"""
Extract data from the dataset 'ds' for specific time and spatial range.
Parameters:
- ds (xarray.Dataset): Input dataset
Returns:
- xarray.Dataset: Dataset subsetted for required years and the specified space and time range.
"""
subset1 = ds.sel(year=slice(1850, 1879))
subset2 = ds.sel(year=slice(2071, 2100))
subset = xr.concat([subset1, subset2], dim='year')
return subset
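# Quick, illustrative check that extract_data keeps only the two 30-year
# windows; the toy dataset below is an assumption for illustration.
toy_years = xr.Dataset({"tas": ("year", np.zeros(251))},
                       coords={"year": np.arange(1850, 2101)})
print(extract_data(toy_years).sizes)  # expected: year=60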
def is_leap(year):
"""Check if a year is a leap year."""
return (year % 4 == 0) and ((year % 100 != 0) or (year % 400 == 0))
quants = np.linspace(0,1.0,30)
def compute_quantiles(ds, quantiles=quants):
return ds.chunk(dict(year=-1)).quantile(quantiles, dim='year',skipna=False)
def regrid(ds, ds_out):
experiment_id = ds.attrs['experiment_id']
source_id = ds.attrs['source_id']
ds_new = ds.regrid.nearest(ds_out)
# Assign back attributes, as the regridder would have deleted them
ds_new.attrs['experiment_id'] = experiment_id
ds_new.attrs['source_id'] = source_id
#print(ds_new.attrs['experiment_id'],ds_new.attrs['source_id'])
#print(ds_new)
return ds_new
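# Illustrative check of the regrid step: a coarse 10-degree toy field should
# land on the 1.5-degree target grid in ds_out. The toy grid and attribute
# values are assumptions for illustration only.
toy_coarse = xr.Dataset(
    {"tas": (("lat", "lon"), np.zeros((19, 37)))},
    coords={"lat": np.arange(-90, 91, 10.0), "lon": np.arange(0, 361, 10.0)},
    attrs={"experiment_id": "historical", "source_id": "toy"},
)
print(regrid(toy_coarse, ds_out).sizes)  # expected: lat=121, lon=241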
def process_data(ds, quantiles=quants):
ds = ds.pipe(drop_feb29).pipe(to_daily).pipe(extract_data)
if len(ds['year']) == 0:
print("The dataset is empty. Skipping...")
return None
if len(ds['dayofyear'])<365:
print('The dataset has fewer than 365 days. Skipping...')
return None
# # Remove 'time' coordinate
# ds = ds.set_index(time=("year", "dayofyear")).unstack("time")
return ds.pipe(regrid, ds_out=ds_out)
Section 3: Computations
- Evaluate the processed datasets: trigger the delayed computations, concatenate the models along a new source_id dimension, and write the results to Zarr
%%time
if RECALC:
with progress.ProgressBar():
expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
coords={'experiment_id': expts})
# Initialize an Empty Dictionary for Aligned Datasets:
dsets_aligned = {}
# Iterate Over dsets_ Dictionary:
for k, v in tqdm(dsets_.items()):
# Initialize a dictionary for this source_id
dsets_aligned[k] = {}
skip_source_id = False
for expt in expts:
ds = v[expt].pipe(process_data)
# Check if the dataset is empty and skip this source_id if so
if ds is None:
print(f"Skipping {expt} for {k} because the dataset is empty")
skip_source_id = True
break
# Store the dataset in the dictionary
# dsets_aligned[k][expt] = ds
# Compute the dataset and store it in the dictionary
dsets_aligned[k][expt] = ds.compute()
print(dsets_aligned[k][expt])
if skip_source_id:
del dsets_aligned[k]
continue
# dsets_aligned.keys()
%%time
if RECALC:
source_ids = list(dsets_aligned.keys())
source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
coords={'source_id': source_ids})
final_ds_pi = xr.concat([ds['historical'].reset_coords(drop=True)
for ds in dsets_aligned.values()],
dim=source_da)
final_ds_eoc = xr.concat([ds['ssp370'].reset_coords(drop=True)
for ds in dsets_aligned.values()],
dim=source_da)
final_ds_eoc
%%time
final_ds_pi.to_zarr(folder_path +'cmip6_pi_daily.zarr',mode='w')
final_ds_eoc.to_zarr(folder_path +'cmip6_eoc_daily.zarr',mode='w')
final_ds_pi = xr.open_zarr(folder_path+'cmip6_pi_daily.zarr')
final_ds_eoc = xr.open_zarr(folder_path+'cmip6_eoc_daily.zarr')
final_ds_pi = final_ds_pi[var_name]
final_ds_eoc = final_ds_eoc[var_name]
final_ds_eoc
Detrend data and save
%%time
ds_pi_det = detrend_data(final_ds_pi,pi_year)
ds_eoc_det = detrend_data(final_ds_eoc,eoc_year)
ds_eoc_det = ds_eoc_det.chunk({'year':30,'source_id':1})
ds_pi_det = ds_pi_det.chunk({'year':30,'source_id':1})
ds_eoc_det
# %%time
# ds_pi_det.rename(var_name).to_dataset().to_zarr(folder_path +'cmip6_pi_ann_detrended.zarr',mode='w')
# %%time
# ds_eoc_det.rename(var_name).to_dataset().to_zarr(folder_path +'cmip6_eoc_ann_detrended.zarr',mode='w')
Check if detrending worked
ds_pi_det = xr.open_zarr(folder_path +'cmip6_pi_ann_detrended.zarr')
ds_eoc_det = xr.open_zarr(folder_path +'cmip6_eoc_ann_detrended.zarr')
#
ds_pi_det = ds_pi_det[var_name]
ds_eoc_det = ds_eoc_det[var_name]
%%time
ds_pi_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
final_ds_pi.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
%%time
ds_eoc_det.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()
final_ds_eoc.sel(lat=chic_lat,lon=chic_lon,method='nearest').sel(dayofyear=2).mean('source_id').plot()