Stream: xarray

Topic: How to speed up unstack(sparse=True)


Brian Bonnlander (Jul 01 2021 at 21:27):

I have a large pandas MultiIndex (800,000 rows) that is taking a long time to unstack to sparse form.

import pandas as pd

### Preprocessing Steps for each input dataset before merge
def preprocess(ds):
    """Pare down each input dataset to a single variable.

    The subsequent merge will eliminate unused coordinates automatically.
    This function does not allow additional arguments, so the target
    output variable needs to be defined globally in TARGET_VAR.
    """
    drop_vars = [var for var in ds.data_vars
                 if var != TARGET_VAR]

    ds_fixed = ds.drop_vars(drop_vars)

    lats = list(ds.pfts1d_lat.astype('float32').data)
    lons = list(ds.pfts1d_lon.astype('float32').data)
    vegtype = list(ds.pfts1d_itype_veg.data)
    coltype = list(ds.pfts1d_itype_col.data)
    lunittype = list(ds.pfts1d_itype_lunit.data)
    active = list(ds.pfts1d_active.data)

    # Redefine the 'pft' dimension as a multi-index, which will increase the number of dimensions.
    index = pd.MultiIndex.from_arrays([lats, lons, vegtype, coltype, lunittype, active],
                                      names=('pftlat', 'pftlon', 'vegtype', 'coltype', 'lunittype', 'active'))
    ds_fixed['pft'] = index


    # Keep the data sparse if possible to avoid memory shortages.
    ds_fixed = ds_fixed.unstack(sparse=True)
    return ds_fixed

Does anyone know if there is a way to speed up unstack() with Dask?

Brian Bonnlander (Jul 01 2021 at 21:37):

I should add that the source data comes from a Zarr store where I tried to chunk the 'pft' dimension (the one being unstacked) into chunks of 10,000 elements each. I don't know if this helps or hurts.

Deepak Cherian (Jul 01 2021 at 21:39):

No. The fast way requires "advanced indexing" over multiple dimensions, which neither dask nor sparse supports. See the discussion here: https://github.com/pydata/xarray/blob/c472f8a4c79f872edb9dcd7825f786ecb9aff5c0/xarray/core/dataset.py#L4129-L4152
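To make "advanced indexing over multiple dimensions" concrete: a `pandas.MultiIndex` stores, for each level, an integer `codes` array pointing into that level's sorted unique `levels`. On NumPy data, unstacking can then be a single vectorized scatter that uses all the code arrays together as one N-D index. The sketch below illustrates the idea only; it is not xarray's actual implementation, `scatter_unstack` is a hypothetical name, and it assumes the index has no missing values and no duplicate entries.

```python
import numpy as np
import pandas as pd


def scatter_unstack(values, index, fill_value=np.nan):
    """Unstack a 1-D array labelled by a MultiIndex into an N-D array.

    Output axis k has one entry per value in index.levels[k]; positions
    that never occur in the index are set to fill_value.
    """
    shape = tuple(len(level) for level in index.levels)
    out = np.full(shape, fill_value, dtype=np.result_type(values, fill_value))
    # One integer array per level, used together as a multi-dimensional
    # advanced index: out[c0[i], c1[i], ...] = values[i] for every i.
    out[tuple(np.asarray(c) for c in index.codes)] = values
    return out


index = pd.MultiIndex.from_arrays([[2.0, 2.0, 1.0], ["b", "a", "a"]],
                                  names=("lat", "veg"))
dense = scatter_unstack(np.array([10.0, 20.0, 30.0]), index)
```

With the optional `sparse` package, `np.stack(index.codes)` already has the `(ndim, nnz)` layout that `sparse.COO(coords, data, shape=...)` expects, which would avoid allocating the dense array; that variant is not shown or tested here.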

Brian Bonnlander (Jul 01 2021 at 21:40):

Thanks for that info! Much appreciated.

Brian Bonnlander (Jul 01 2021 at 21:48):

One more related question: is it possible that chunking the "pft" dimension ahead of time just creates extra work for the unstack() operation?

Deepak Cherian (Jul 01 2021 at 21:52):

Yes, it's forced to take the slow path. You'll want to call unstack on a NumPy-backed Xarray object for maximum speed. You could use xarray.map_blocks for this, as long as you're not chunked along pft. You'll also want to pass only the variables required by preprocess to the map_blocks call.
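The "not chunked along pft" condition matters because the pft-to-new-dimensions mapping can then be computed once and reused for every time block. The NumPy sketch below mimics that per-block step outside xarray: `unstack_time_block` and `unstack_in_time_chunks` are hypothetical names, and the Python loop only stands in for what `xarray.map_blocks` would do over time chunks. On the xarray side, `ds.chunk({"time": 100, "pft": -1})` is one way to request a single chunk along `pft` (`-1` means the full length); if the source store is already chunked along `pft`, rechunking only `time` may leave those pft chunks in place.

```python
import numpy as np
import pandas as pd


def unstack_time_block(block, codes, shape, fill_value=np.nan):
    """Unstack the last axis (pft) of a NumPy block shaped (time, pft).

    Returns an array shaped (time, *shape). The block must contain the
    whole pft axis, because each code array has one entry per pft element.
    """
    out = np.full(block.shape[:-1] + shape, fill_value,
                  dtype=np.result_type(block, fill_value))
    # Ellipsis keeps the leading time axis; the code arrays scatter pft.
    out[(Ellipsis, *codes)] = block
    return out


def unstack_in_time_chunks(data, index, chunk=100):
    """Apply unstack_time_block to successive time chunks of data."""
    # Computed once from the MultiIndex and shared by every chunk.
    codes = tuple(np.asarray(c) for c in index.codes)
    shape = tuple(len(level) for level in index.levels)
    blocks = [unstack_time_block(data[start:start + chunk], codes, shape)
              for start in range(0, data.shape[0], chunk)]
    return np.concatenate(blocks, axis=0)
```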

Brian Bonnlander (Jul 02 2021 at 13:09):

Hi Deepak,

I just want to make sure I followed your advice properly, but it looks like the slow behavior is unchanged with the following code. Did I do anything wrong here?

If there is any other way to produce sparse results that might be faster, I would be interested. Otherwise, we may decide to wait on producing these variables for the cloud.

import pandas as pd
import xarray as xr

def compute_index(ds):
    """Compute the transform from 1D to sparse 6D."""
    lats = list(ds.pfts1d_lat.astype('float32').data)
    lons = list(ds.pfts1d_lon.astype('float32').data)
    vegtype = list(ds.pfts1d_itype_veg.data)
    coltype = list(ds.pfts1d_itype_col.data)
    lunittype = list(ds.pfts1d_itype_lunit.data)
    active = list(ds.pfts1d_active.data)

    # Redefine the 'pft' dimension as a multi-index, which will increase the number of dimensions.
    index = pd.MultiIndex.from_arrays([lats, lons, vegtype, coltype, lunittype, active],
                                      names=('pftlat', 'pftlon', 'vegtype', 'coltype', 'lunittype', 'active'))

    return index

def sparsify(chunk):
    chunk = chunk.unstack(sparse=True)
    return chunk

def preprocess(ds):
    index = compute_index(ds)
    ds['pft'] = index

    # Drop all but the PFT-related variable(s) listed in PFT_VARS
    # (assumed to be defined globally) to make conversion faster.
    drop_vars = [var for var in ds.data_vars
                 if var not in PFT_VARS]
    ds = ds.drop_vars(drop_vars)

    ds = ds.chunk(chunks={"time": 100})

    # I also tried loading the data to see if it could reduce what looks like pulling small chunks from disk.
    # ds = ds.load()
    ds_fixed = xr.map_blocks(sparsify, ds)

    return ds_fixed

Deepak Cherian (Jul 02 2021 at 15:51):

I'm surprised you could use map_blocks without passing template. Does ds_fixed (before computing) look like what you expect? Is compute_index slow? Does it help to move compute_index into the sparsify call?

Brian Bonnlander (Jul 02 2021 at 16:16):

compute_index() takes 3-4 minutes. The call to unstack() appears to take much longer (more than 3 hours for a single variable). The dask task graphs look like they are processing one timestep at a time, despite my efforts to chunk the dataset along the time dimension. With 50 workers, each worker's CPU sits at around 10%, which suggests to me that 90% of the time is spent waiting for values to come from disk. Of course, I'm not fully confident that I'm seeing things correctly.

Brian Bonnlander (Jul 02 2021 at 16:18):

I'm also not sure how to check ds_fixed. In the past, when I put print statements in preprocess(), nothing was printed. Perhaps I should try again somehow?

Deepak Cherian (Jul 02 2021 at 16:19):

If you want to use print statements, call .compute(scheduler="single-threaded"); this will run everything serially. Otherwise, I think we should push this to the next office hours.

Brian Bonnlander (Jul 02 2021 at 18:48):

Another possible explanation for the 10% CPU per worker is that unstack() runs serially over a huge MultiIndex, so even with chunking in time, walking through the MultiIndex one element at a time will always be slow.

Perhaps unstack() would be much faster if we could pull the data into memory all at once? I thought that calling ds = ds.load() just before the unstack would force this, but it doesn't seem to change the speed of unstack(). Is this a possible issue to raise, or am I misunderstanding something?

Deepak Cherian (Jul 02 2021 at 19:02):

When you wrap a function in map_blocks, at compute time that function receives an xarray object with its data loaded into memory, so in theory it's using the fast path for NumPy objects.

You can test the speed of unstacking by loading one timestep into memory and profiling the unstack call.
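A minimal profiling helper using only the standard library's `cProfile` and `pstats`, in case it is useful for the test described above. `profile_call` is a hypothetical name; the commented usage assumes the thread's dataset `ds`, its `TLAI` variable, and a `pft` MultiIndex that has already been assigned. Loading the single timestep before profiling keeps disk reads out of the measurement.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, top=15, **kwargs):
    """Run func(*args, **kwargs) under cProfile.

    Returns (result, report), where report lists the `top` entries
    sorted by cumulative time.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()


# Example usage (assumes `ds` from this thread, pft MultiIndex assigned):
# one_step = ds["TLAI"].isel(time=0).load()   # NumPy-backed, in memory
# unstacked, report = profile_call(one_step.unstack, "pft", sparse=True)
# print(report)
```

Profiling `compute_index(ds)` and the `unstack` call separately, as above, helps show which of the two dominates.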

Brian Bonnlander (Jul 02 2021 at 20:23):

After reducing to one timestep, I'm still getting the same dask pattern and speed. The profile says that 96% of the time is spent in compute_index(); of that, 46% is in slice_array() and the rest is in factorize_for_iterables(). The code now looks like this; I hope I have done things correctly and the intent is obvious:

def preprocess(ds):
    TARGET_CHUNKS = {'time': 10}
    PFT_VARS = ['TLAI']

    index = compute_index(ds)
    ds['pft'] = index

    # Drop unneeded variables as soon as possible.
    drop_vars = [var for var in ds.data_vars
                 if var not in PFT_VARS]
    ds = ds.drop_vars(drop_vars)

    ds = ds.load()

    ds = ds.chunk(chunks=TARGET_CHUNKS)

    for var in PFT_VARS:
        # Try limiting to one timestep.
        ds[var] = ds[var].isel(time=0)
        ds[var] = xr.map_blocks(sparsify, ds[var])

    return ds

Brian Bonnlander (Jul 02 2021 at 20:28):

I think the profile results suggest that unstack() is looping over the MultiIndex and processing one element (of 800,000) at a time in order.
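One reading of the profile results above, offered as a hypothesis rather than a diagnosis: `slice_array` matches the name of a function in `dask.array.slicing`, and calling `list()` on a dask-backed array iterates it one element at a time, building a tiny slicing task for each of the 800,000 entries. In the latest `preprocess`, `compute_index(ds)` runs before `ds.load()`, so the `pfts1d_*` variables are probably still dask-backed at that point. If so, converting each variable to a NumPy array in one step before building the index may avoid that per-element loop. A minimal sketch, where `LEVEL_SOURCES` and `compute_index_vectorized` are hypothetical names and the variable and level names are copied from `compute_index()`:

```python
import numpy as np
import pandas as pd

# Hypothetical lookup table: index level name -> (source variable, dtype).
# dtype None keeps the variable's own dtype; "float32" mirrors the
# .astype('float32') calls in compute_index().
LEVEL_SOURCES = {
    "pftlat": ("pfts1d_lat", "float32"),
    "pftlon": ("pfts1d_lon", "float32"),
    "vegtype": ("pfts1d_itype_veg", None),
    "coltype": ("pfts1d_itype_col", None),
    "lunittype": ("pfts1d_itype_lunit", None),
    "active": ("pfts1d_active", None),
}


def compute_index_vectorized(ds):
    """Build the 6-level pft MultiIndex from whole arrays.

    `ds` can be anything indexable by variable name that returns an
    array-like, such as an xarray Dataset or a dict of NumPy arrays.
    np.asarray converts each variable in one step (for a dask-backed
    DataArray this triggers a single compute) instead of element by element.
    """
    arrays = [np.asarray(ds[var], dtype=dtype)
              for var, dtype in LEVEL_SOURCES.values()]
    return pd.MultiIndex.from_arrays(arrays, names=list(LEVEL_SOURCES))
```

In `preprocess`, `ds['pft'] = compute_index_vectorized(ds)` would take the place of the `compute_index(ds)` call.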


Last updated: May 16 2025 at 17:14 UTC