I have a large pandas MultiIndex (800,000 rows) that is taking a long time to unstack to sparse form.
```python
### Preprocessing Steps for each input dataset before merge
import pandas as pd


def preprocess(ds):
    """Pare down each input dataset to a single variable.

    The subsequent merge will eliminate unused coordinates automatically.
    This function does not allow additional arguments, so the target output
    variable needs to be defined globally in TARGET_VAR.
    """
    drop_vars = [var for var in ds.data_vars if var != TARGET_VAR]
    ds_fixed = ds.drop_vars(drop_vars)

    lats = list(ds.pfts1d_lat.astype('float32').data)
    lons = list(ds.pfts1d_lon.astype('float32').data)
    vegtype = list(ds.pfts1d_itype_veg.data)
    coltype = list(ds.pfts1d_itype_col.data)
    lunittype = list(ds.pfts1d_itype_lunit.data)
    active = list(ds.pfts1d_active.data)

    # Redefine the 'pft' dimension as a multi-index, which will increase
    # the number of dimensions.
    index = pd.MultiIndex.from_arrays(
        [lats, lons, vegtype, coltype, lunittype, active],
        names=('pftlat', 'pftlon', 'vegtype', 'coltype', 'lunittype', 'active'))
    ds_fixed['pft'] = index

    # Keep the data sparse if possible to avoid memory shortages.
    ds_fixed = ds_fixed.unstack(sparse=True)
    return ds_fixed
```
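For context, the function is applied to each input dataset roughly like this (the paths, engine, and `TARGET_VAR` value here are placeholders, not my real setup):

```python
import xarray as xr

input_paths = ["inputs/store1.zarr", "inputs/store2.zarr"]  # placeholder inputs
TARGET_VAR = "TLAI"                                         # placeholder target variable

# open_mfdataset calls `preprocess` on each dataset before merging them,
# which is why the function can only accept the dataset itself.
ds_merged = xr.open_mfdataset(
    input_paths, engine="zarr", preprocess=preprocess, parallel=True
)
```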
Does anyone know if there is a way to speed up unstack() with Dask?
I should add that the source data comes from a Zarr store where I tried to chunk the 'pft' dimension (the one being unstacked) into chunks of 10,000 elements each. I don't know if this helps or hurts.
No, the fast way requires "advanced indexing" over multiple dimensions, which neither dask nor sparse supports. See the discussion here: https://github.com/pydata/xarray/blob/c472f8a4c79f872edb9dcd7825f786ecb9aff5c0/xarray/core/dataset.py#L4129-L4152
Thanks for that info! Much appreciated.
One more related question: is it possible that chunking the "pft" dimension ahead of time just creates extra work for the unstack() operation?
Yes, it's forced to take the slow path. You'll want to call `unstack` on a numpy-backed Xarray object for max speed. You could use `xarray.map_blocks` for this as long as you're not chunked along `pft`. You'll also want to pass only the variables required by `preprocess` to the `map_blocks` call.
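In rough, untested pseudocode (this assumes `ds` already has the `pft` MultiIndex assigned, as in your `preprocess`, and that `TARGET_VAR` names the variable you need):

```python
import xarray as xr

def sparsify(block):
    # At compute time each block arrives numpy-backed, so unstack hits the fast path.
    return block.unstack(sparse=True)

# Pass only what the mapped function needs, keep 'pft' in a single chunk
# (chunking it forces the slow path), and get parallelism from time instead.
needed = ds[[TARGET_VAR]].chunk({"pft": -1, "time": 100})
result = xr.map_blocks(sparsify, needed)
```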
Hi Deepak,
I just want to make sure I followed your advice properly, but it looks like the slow behavior is unchanged with the following code. Did I do anything wrong here?
If there is any other way to produce sparse results that might be faster, I would be interested. Otherwise, we may decide to wait on producing these variables for the cloud.
```python
import pandas as pd
import xarray as xr


def compute_index(ds):
    """Compute the transform from 1D to sparse 6D."""
    lats = list(ds.pfts1d_lat.astype('float32').data)
    lons = list(ds.pfts1d_lon.astype('float32').data)
    vegtype = list(ds.pfts1d_itype_veg.data)
    coltype = list(ds.pfts1d_itype_col.data)
    lunittype = list(ds.pfts1d_itype_lunit.data)
    active = list(ds.pfts1d_active.data)

    # Redefine the 'pft' dimension as a multi-index, which will increase
    # the number of dimensions.
    index = pd.MultiIndex.from_arrays(
        [lats, lons, vegtype, coltype, lunittype, active],
        names=('pftlat', 'pftlon', 'vegtype', 'coltype', 'lunittype', 'active'))
    return index


def sparsify(chunk):
    chunk = chunk.unstack(sparse=True)
    return chunk


def preprocess(ds):
    index = compute_index(ds)
    ds['pft'] = index

    # Drop all but one PFT-related variable to make conversion faster
    drop_vars = [var for var in ds.data_vars if var not in PFT_VARS]
    ds = ds.drop_vars(drop_vars)

    ds = ds.chunk(chunks={"time": 100})
    # I also tried loading the data to see if it could reduce what looks like
    # pulling small chunks from disk.
    # ds = ds.load()

    ds_fixed = xr.map_blocks(sparsify, ds)
    return ds_fixed
```
I'm surprised you could do `map_blocks` without passing `template`. Does `ds_fixed` (before computing) look like what you expect? Is `compute_index` slow? Does it help to move `compute_index` into the `sparsify` call?
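For that last suggestion I mean something roughly like this (untested sketch; it assumes the `pfts1d_*` variables are still present in the object you pass to `map_blocks`):

```python
def sparsify(chunk):
    # Build the MultiIndex on the numpy-backed block itself, then unstack.
    chunk["pft"] = compute_index(chunk)
    return chunk.unstack(sparse=True)
```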
`compute_index()` takes 3-4 minutes. It appears that the call to `unstack()` is taking a long time (more than 3 hours for a single variable). The dask task graphs look like they are processing one timestep at a time, despite my efforts to chunk the dataset along the time dimension. With 50 workers present, each worker's CPU is at around 10%, which suggests to me that 90% of the time is spent waiting for values to come from disk. Of course, I'm not fully confident that I'm seeing things correctly.
I'm also not sure how to check `ds_fixed`. In the past when I tried putting print statements in `preprocess()`, I didn't get anything printed. Perhaps I should try again somehow?
If you want to put `print` statements in, use `.compute(scheduler="single-threaded")`; this will run it in serial. Otherwise, I think we should push this to the next office hours.
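For example:

```python
# Run the computation serially in the local process so print() output
# from preprocess()/sparsify() is actually visible.
ds_fixed.compute(scheduler="single-threaded")
```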
Another possible explanation for the 10% CPU per worker is that the `unstack()` operation is being done serially over a huge MultiIndex; even with chunking in time, it is always going to be slow going through the MultiIndex in a serial way.
Perhaps `unstack()` would be much faster if we could pull the data into memory all at once? I thought that `ds = ds.load()` just before the unstack would force this, but it doesn't seem to change the speed of `unstack()`. Is this a possible issue to raise, or am I misunderstanding something?
When you wrap a function in `map_blocks`, at compute time that function will receive an xarray object with data loaded into memory, so in theory it's using the fast path for numpy objects. You can test out the speed of unstacking by loading one timestep into memory and profiling the `unstack` call.
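Something along these lines (just a sketch; it assumes `ds` already carries the `pft` MultiIndex):

```python
import cProfile
import pstats

# Pull a single timestep into memory so unstack runs on numpy data,
# then profile just that in-memory call.
one = ds.isel(time=0).compute()

profiler = cProfile.Profile()
profiler.enable()
one.unstack(sparse=True)
profiler.disable()
pstats.Stats(profiler).sort_stats("cumtime").print_stats(20)
```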
After reducing to one timestep, I'm still getting the same dask pattern and speed. The profile says that 96% of the time is spent in `compute_index()`. Of that, 46% is in `slice_array()` and the rest is in `factorize_for_iterables()`. The code now looks like this; I hope I have done things correctly and that the intent is obvious:
```python
def preprocess(ds):
    TARGET_CHUNKS = {'time': 10}
    PFT_VARS = ['TLAI']

    index = compute_index(ds)
    ds['pft'] = index

    # Drop unneeded variables as soon as possible.
    drop_vars = [var for var in ds.data_vars if var not in PFT_VARS]
    ds = ds.drop_vars(drop_vars)

    ds = ds.load()
    ds = ds.chunk(chunks=TARGET_CHUNKS)

    for var in PFT_VARS:
        # Try limiting to one timestep.
        ds[var] = ds[var].isel(time=0)
        ds[var] = xr.map_blocks(sparsify, ds[var])
    return ds
```
I think the profile results suggest that `unstack()` is looping over the MultiIndex and processing one element (of 800,000) at a time, in order.
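If it comes to that, one idea I might try is skipping `unstack()` entirely and building the sparse array directly from the factorized MultiIndex levels. A rough, untested sketch (assuming the data variable has dims `(time, pft)`, the index comes from `compute_index()`, and its entries are unique):

```python
import numpy as np
import pandas as pd
import sparse
import xarray as xr

def unstack_manually(da: xr.DataArray, index: pd.MultiIndex) -> xr.DataArray:
    """Sketch: construct the sparse 6D+time array directly instead of unstacking.

    Assumes da has dims (time, pft) with data in memory, and that the index
    entries are unique (duplicate coordinates would be summed by COO).
    """
    ntime = da.sizes["time"]
    codes = np.stack(index.codes)                     # shape (nlevels, npft)
    levels = [np.asarray(lev) for lev in index.levels]

    # One (time, pftlat, ..., active) coordinate tuple per stored value.
    time_codes = np.repeat(np.arange(ntime), index.size)
    pft_codes = np.tile(codes, ntime)
    coords = np.vstack([time_codes, pft_codes])

    shape = (ntime,) + tuple(len(lev) for lev in levels)
    arr = sparse.COO(coords, da.values.ravel(), shape=shape, fill_value=np.nan)

    return xr.DataArray(
        arr,
        dims=("time", *index.names),
        coords={"time": da["time"], **dict(zip(index.names, levels))},
    )
```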