Debugging dask workflows: Detrending#

Detrending, that is, subtracting a trend (commonly a linear fit) along the time dimension, is a common workflow in the climate sciences.

Here’s an example:

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

time = np.arange(100)
data = time + 4 * np.sin(2 * np.pi / 10 * time)
da = xr.DataArray(data, dims="time", coords={"time": time})

# fit a straight line along time
p = da.polyfit(dim="time", deg=1, skipna=False)
# evaluate the linear trend as a timeseries
fit = xr.polyval(da["time"], p.polyfit_coefficients)


da.plot(label="data")
fit.plot(label="linear trend")
(da - fit).plot(label="detrended")
plt.legend()
[figure: the data, its linear trend, and the detrended result]

Detrending with dask is a consistent source of headaches:

  1. Polyfit performance on large datasets - Suboptimal dask task graph

  2. Unexpected chunking of 3d DataArray in polyfit()

  3. Memory errors with detrending + rolling operations

Challenges of detrending#

Detrending has 3 steps:

  1. Fit a polynomial using polyfit

  2. Evaluate the trend using polyval and the fit coefficients

  3. Subtract the trend from the original data

def detrend_dim(da, dim, deg=1):
    # detrend along a single dimension
    p = da.polyfit(dim=dim, deg=deg, skipna=False)
    fit = xr.polyval(da[dim], p.polyfit_coefficients)
    return da - fit

We usually want to subtract a linear trend in time. The main challenge is that model output is usually chunked with small chunk sizes in time and big chunk sizes in space. Naively calling polyfit along time using xarray.apply_ufunc(numpy.polyfit, ..., dask="parallelized") would require rechunking the dataset so that time is in a single chunk. This can be a good way to proceed if there are other time-series analyses that will be performed; in that case use rechunker. A sketch of this route follows.
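Here is a minimal, hedged sketch of the rechunk-first approach, assuming da is a dask-backed DataArray with a "time" dimension. The helper _fit_last_axis is invented here for illustration, and note that np.polynomial orders coefficients lowest degree first, unlike DataArray.polyfit.

import numpy as np
import xarray as xr


def _fit_last_axis(arr, x, deg):
    # apply_ufunc moves core dims to the end, while np.polynomial fits
    # along the first axis, so reshape to (time, points) before fitting
    flat = arr.reshape(-1, arr.shape[-1]).T
    coeffs = np.polynomial.polynomial.polyfit(x, flat, deg)  # (deg + 1, points)
    return coeffs.T.reshape(arr.shape[:-1] + (deg + 1,))


rechunked = da.chunk({"time": -1})  # potentially expensive: one chunk along time
coeffs = xr.apply_ufunc(
    _fit_last_axis,
    rechunked,
    kwargs={"x": np.arange(rechunked.sizes["time"]), "deg": 1},
    input_core_dims=[["time"]],
    output_core_dims=[["degree"]],
    dask="parallelized",
    dask_gufunc_kwargs={"output_sizes": {"degree": 2}},
    output_dtypes=[float],
)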

DataArray.polyfit does actually support fitting along chunked dimensions. It uses dask.array.linalg.lstsq for an out-of-core parallel fitting algorithm.

However, lstsq requires 2D arrays as input, while our model output is commonly 3D or 4D (time, depth, latitude, longitude). polyfit uses DataArray.stack to reshape the array to 2D, does the fit, and reshapes back. Easy, no?
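Roughly, a hedged sketch of that recipe (names simplified; assume da is a 3D dask-backed DataArray with dimensions (time, latitude, longitude); the real implementation handles more cases, such as NaNs):

import dask.array
import numpy as np

# stack space into one dimension: 2D (time, points)
stacked = da.stack(points=("latitude", "longitude"))
# design matrix for a linear fit along time
vander = dask.array.from_array(np.vander(np.arange(da.sizes["time"]), 2))
# out-of-core least squares; coeffs has shape (2, points), then unstack
coeffs, residuals, rank, sv = dask.array.linalg.lstsq(vander, stacked.data)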

Turns out… not easy. Reshaping in parallel can be really expensive! The dask docs illustrate this well. We have a possible solution, but it hasn’t been implemented yet. It also turned out that this wasn’t the issue behind Memory errors with detrending + rolling operations.

In this notebook, I explore the detrending problem a little bit to find what other challenges affect detrending workflows.

Summary#

  1. We use simple techniques like tracking chunk sizes and number of tasks to isolate a few problems, particularly in polyval.

  2. Fixing polyval doesn’t improve the calculation, so we then examine the dask graph for polyfit in a smaller artificial problem that replicates the main issues.

  3. This lets us trace the issue back to a serious regression in dask’s tensordot function.

  4. That issue has since been fixed (available in dask >= 2022.03.0).

  5. Now things work well!

Note

The first half of this notebook was built with dask==2021.12.0.

Setup#

import dask.array
import distributed
import numpy as np
import xarray as xr
from IPython.display import Image
client = distributed.Client()
client

Client: Client-da6eb586-ab59-11ec-8493-3af9d394f1c6
Connection method: Cluster object
Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

Example detrending operation#

I’ll use the example dataset in Memory errors with detrending + rolling operations but make it smaller in time.

nt = 8772 // 4
ny = 489
nx = 655
# chunks like the data is stored on disk
# small in time, big in space
# because the chunk sizes are -1 along lat, lon;
# reshaping this array to (time, lat, lon) prior to fitting is pretty cheap
chunks = (8, -1, -1)

da = xr.DataArray(
    dask.array.random.random((nt, ny, nx), chunks=chunks),
    dims=("ocean_time", "eta_rho", "xi_rho"),
)
da
<xarray.DataArray 'random_sample-ffbb27b2e849e89719a8455875172787' (
                                                                    ocean_time: 2193,
                                                                    eta_rho: 489,
                                                                    xi_rho: 655)>
dask.array<random_sample, shape=(2193, 489, 655), dtype=float64, chunksize=(8, 489, 655), chunktype=numpy.ndarray>
Dimensions without coordinates: ocean_time, eta_rho, xi_rho
# Function to detrend
# Source: https://gist.github.com/rabernat/1ea82bb067c3273a6166d1b1f77d490f


def detrend_dim(da, dim, deg=1):
    """detrend along a single dimension."""
    # calculate polynomial coefficients
    p = da.polyfit(dim=dim, deg=deg, skipna=False)
    # evaluate trend
    fit = xr.polyval(da[dim], p.polyfit_coefficients)
    # remove the trend
    return da - fit


detrended = detrend_dim(da, dim="ocean_time")
detrended
<xarray.DataArray (ocean_time: 2193, eta_rho: 489, xi_rho: 655)>
dask.array<sub, shape=(2193, 489, 655), dtype=float64, chunksize=(8, 489, 655), chunktype=numpy.ndarray>
Coordinates:
  * ocean_time  (ocean_time) int64 0 1 2 3 4 5 ... 2187 2188 2189 2190 2191 2192
  * eta_rho     (eta_rho) int64 0 1 2 3 4 5 6 7 ... 482 483 484 485 486 487 488
  * xi_rho      (xi_rho) int64 0 1 2 3 4 5 6 7 ... 648 649 650 651 652 653 654

Warning

The dask dashboard screenshot for the following cell was captured with dask 2021.12.0.

detrended.compute()

Calling compute doesn’t work so well, and that single tensordot task looks funny. We have many input chunks and many output chunks, so a single task that accumulates many input chunks is a likely bottleneck.

Image("../images/detrend-dask-dashboard-1.png")

What is going wrong?

Let’s look at detrend#

Detrend has 3 steps:

  1. polyfit

  2. polyval

  3. subtract

def detrend_dim(da, dim, deg=1):
    # detrend along a single dimension
    p = da.polyfit(dim=dim, deg=deg, skipna=False)
    fit = xr.polyval(da[dim], p.polyfit_coefficients)
    return da - fit

First, let’s try just polyfit.

# these are arguments provided to detrend
dim = "ocean_time"
deg = 1

p = da.polyfit(dim=dim, deg=deg, skipna=False)
p
<xarray.Dataset>
Dimensions:               (eta_rho: 489, xi_rho: 655, degree: 2)
Coordinates:
  * eta_rho               (eta_rho) int64 0 1 2 3 4 5 ... 484 485 486 487 488
  * xi_rho                (xi_rho) int64 0 1 2 3 4 5 ... 649 650 651 652 653 654
  * degree                (degree) int64 1 0
Data variables:
    polyfit_coefficients  (degree, eta_rho, xi_rho) float64 dask.array<chunksize=(2, 489, 655), meta=np.ndarray>

Seems OK, by which I mean the output’s spatial chunks match the input’s (time has been reduced to the tiny degree dimension).
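A quick way to check that programmatically (the values match the reprs above):

da.data.chunksize                      # (8, 489, 655)
p.polyfit_coefficients.data.chunksize  # (2, 489, 655): spatial chunks preserved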

How about polyval?

fit = xr.polyval(da[dim], p.polyfit_coefficients)
fit
<xarray.DataArray (ocean_time: 2193, eta_rho: 489, xi_rho: 655)>
dask.array<sum-aggregate, shape=(2193, 489, 655), dtype=float64, chunksize=(2193, 489, 655), chunktype=numpy.ndarray>
Coordinates:
  * ocean_time  (ocean_time) int64 0 1 2 3 4 5 ... 2187 2188 2189 2190 2191 2192
  * eta_rho     (eta_rho) int64 0 1 2 3 4 5 6 7 ... 482 483 484 485 486 487 488
  * xi_rho      (xi_rho) int64 0 1 2 3 4 5 6 7 ... 648 649 650 651 652 653 654

Aha! Now suddenly we have one giant chunk for the entire array!

Image("../images/dask-repr.png")

Tip

The first step to checking your dask workflow is to look for sudden and large changes in number of tasks or chunk size. Pay attention to the dask repr.
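For example, both diagnostics are one-liners on the dask array underlying fit (a sketch; the numbers shown are for this example):

fit.data.npartitions            # 1: the whole array is a single chunk
len(fit.data.__dask_graph__())  # total number of tasks in the graph
fit.data.chunksize              # (2193, 489, 655): ~5.6 GB of float64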

Building a better polyval#

To fix this, we need to look closely at the code of polyval.

def polyval(coord, coeffs, degree_dim="degree"):
    from .dataarray import DataArray
    from .missing import get_clean_interp_index

    x = get_clean_interp_index(coord, coord.name, strict=False)

    deg_coord = coeffs[degree_dim]

    lhs = DataArray(
        np.vander(x, int(deg_coord.max()) + 1),
        dims=(coord.name, degree_dim),
        coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
    )
    return (lhs * coeffs).sum(degree_dim)

We call this using

    fit = xr.polyval(da[dim], p.polyfit_coefficients)

The core issue is that the da[dim] coord is unchunked. So lhs is a numpy array; when combined with coeffs (a dask array), it gives us a dask array with a single giant chunk: 2193 × 489 × 655 × 8 bytes ≈ 5.6 GB.

da[dim]
<xarray.DataArray 'ocean_time' (ocean_time: 2193)>
array([   0,    1,    2, ..., 2190, 2191, 2192])
Dimensions without coordinates: ocean_time
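We can reproduce this failure mode with a couple of lines of plain dask (a sketch; the shapes mirror lhs and coeffs above, and nothing is computed):

lhs_np = np.vander(np.arange(2193), 2)  # numpy, shape (2193, 2)
coeffs_dask = dask.array.ones((2, 489, 655), chunks=(2, 489, 655))
# dask wraps the numpy operand as a single-chunk array, so the result
# has one chunk spanning all 2193 time steps
out = lhs_np[:, :, None, None] * coeffs_dask
out.chunksize  # (2193, 2, 489, 655)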

Let’s instead create a chunked “ocean_time” coordinate. This is a little roundabout because Xarray doesn’t let you chunk dimension coordinates using .chunk.

# first create a chunked version of the "ocean_time" dimension
chunked_dim = xr.DataArray(
    dask.array.from_array(da[dim].data, chunks=da.chunksizes[dim]), dims=dim, name=dim
)
chunked_dim
<xarray.DataArray 'ocean_time' (ocean_time: 2193)>
dask.array<array, shape=(2193,), dtype=int64, chunksize=(8,), chunktype=numpy.ndarray>
Dimensions without coordinates: ocean_time

Here is a rewritten version of polyval that can handle a chunked coord as input.

Note

This function will not work properly for an actual time coordinate. Here ocean_time is just ints. If it were a datetime vector, we would need to convert that to a vector of floats or ints.
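If the coordinate were datetime64, one hedged way to build such a float vector (not what xarray does internally):

times = xr.DataArray(
    np.arange("2000-01-01", "2000-01-05", dtype="datetime64[D]").astype("datetime64[ns]"),
    dims="time",
)
x = (times - times[0]) / np.timedelta64(1, "D")  # float days since start: 0.0 ... 3.0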

def polyval(coord, coeffs, degree_dim="degree"):
    x = coord.data

    deg_coord = coeffs[degree_dim]
    N = int(deg_coord.max()) + 1

    lhs = xr.DataArray(
        np.stack([x ** (N - 1 - i) for i in range(N)], axis=1),
        dims=(coord.name, degree_dim),
        coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
    )
    return (lhs * coeffs).sum(degree_dim)


fit = polyval(chunked_dim, p.polyfit_coefficients)
fit
<xarray.DataArray (ocean_time: 2193, eta_rho: 489, xi_rho: 655)>
dask.array<sum-aggregate, shape=(2193, 489, 655), dtype=float64, chunksize=(8, 489, 655), chunktype=numpy.ndarray>
Coordinates:
  * ocean_time  (ocean_time) int64 0 1 2 3 4 5 ... 2187 2188 2189 2190 2191 2192
  * eta_rho     (eta_rho) int64 0 1 2 3 4 5 6 7 ... 482 483 484 485 486 487 488
  * xi_rho      (xi_rho) int64 0 1 2 3 4 5 6 7 ... 648 649 650 651 652 653 654

An older dask version#

Warning

The following cells describe an issue present for 2.30.0 < dask < 2022.03.0.

The fit is now much better chunked, but we still have issues computing it.

fit.compute()
Image("../images/detrend-dask-dashboard-2.png")

Hmm… that single tensordot task still looks problematic. It is probably why we see very skewed memory consumption (mostly on one worker).

Sadly and surprisingly, that wasn’t fixed by our new polyval!

Let’s look at polyfit again#

A dask bug#

But now for a much smaller problem.

Tip

Visualizing the dask graph can sometimes be instructive but really only works for small problems. Usually we want to see long parallel chains.

nt = 24
ny = 489
nx = 655
chunks = (8, -1, -1)

smol_da = xr.DataArray(
    dask.array.random.random((nt, ny, nx), chunks=chunks),
    dims=("ocean_time", "eta_rho", "xi_rho"),
)
smol_da
<xarray.DataArray 'random_sample-267f8112b7b54af2dbd898356bfcb6f7' (
                                                                    ocean_time: 24,
                                                                    eta_rho: 489,
                                                                    xi_rho: 655)>
dask.array<random_sample, shape=(24, 489, 655), dtype=float64, chunksize=(8, 489, 655), chunktype=numpy.ndarray>
Dimensions without coordinates: ocean_time, eta_rho, xi_rho
p = smol_da.polyfit(dim="ocean_time", deg=1, skipna=False)
dask.visualize(p, rankdir="LR")
[dask task graph visualization]

Ouch. We see all the tasks funnel through a single tensordot task.

This cannot work well. It effectively means that one worker is going to receive all time steps as a single chunk, which implies a lot of memory use and network transfer.

Searching the dask issue tracker leads to this issue, so we were stuck.

Is it fixed?#

Thanks to some upstream work in dask, we expect that the tensordot problem is fixed. So let’s try it again.

Upgrading dask to dask==2022.03.0 shows a much better graph.

dask.__version__
'2022.03.0'
p = smol_da.polyfit(dim="ocean_time", deg=1, skipna=False)
dask.visualize(p, rankdir="LR")
[dask task graph visualization]

Yes! We now see 3 tensordot tasks for this smaller problem, instead of the single one in the previous buggy version.

Attempt 2#

# these are arguments provided to detrend
dim = "ocean_time"
deg = 1

p = da.polyfit(dim=dim, deg=deg, skipna=False)
p
<xarray.Dataset>
Dimensions:               (eta_rho: 489, xi_rho: 655, degree: 2)
Coordinates:
  * eta_rho               (eta_rho) int64 0 1 2 3 4 5 ... 484 485 486 487 488
  * xi_rho                (xi_rho) int64 0 1 2 3 4 5 ... 649 650 651 652 653 654
  * degree                (degree) int64 1 0
Data variables:
    polyfit_coefficients  (degree, eta_rho, xi_rho) float64 dask.array<chunksize=(2, 489, 655), meta=np.ndarray>

Computing works very well.

p.compute()
<xarray.Dataset>
Dimensions:               (eta_rho: 489, xi_rho: 655, degree: 2)
Coordinates:
  * eta_rho               (eta_rho) int64 0 1 2 3 4 5 ... 484 485 486 487 488
  * xi_rho                (xi_rho) int64 0 1 2 3 4 5 ... 649 650 651 652 653 654
  * degree                (degree) int64 1 0
Data variables:
    polyfit_coefficients  (degree, eta_rho, xi_rho) float64 1.756e-06 ... 0.5212
fit = polyval(chunked_dim, p.polyfit_coefficients)
fit
<xarray.DataArray (ocean_time: 2193, eta_rho: 489, xi_rho: 655)>
dask.array<sum-aggregate, shape=(2193, 489, 655), dtype=float64, chunksize=(8, 489, 655), chunktype=numpy.ndarray>
Coordinates:
  * ocean_time  (ocean_time) int64 0 1 2 3 4 5 ... 2187 2188 2189 2190 2191 2192
  * eta_rho     (eta_rho) int64 0 1 2 3 4 5 6 7 ... 482 483 484 485 486 487 488
  * xi_rho      (xi_rho) int64 0 1 2 3 4 5 6 7 ... 648 649 650 651 652 653 654
fit.compute()
<xarray.DataArray (ocean_time: 2193, eta_rho: 489, xi_rho: 655)>
array([[[0.5057565 , 0.50509786, 0.5102827 , ..., 0.49117019,
         0.51031117, 0.50260723],
        [0.50364757, 0.4991787 , 0.49981748, ..., 0.50275095,
         0.49462038, 0.51076364],
        [0.48684994, 0.49536617, 0.48945836, ..., 0.49770676,
         0.49409571, 0.51575231],
        ...,
        [0.50350304, 0.49995467, 0.48793556, ..., 0.49612795,
         0.50501344, 0.49929657],
        [0.52682289, 0.50002593, 0.5005749 , ..., 0.49589506,
         0.49912113, 0.49788316],
        [0.50034415, 0.50321058, 0.49636848, ..., 0.49642915,
         0.4738819 , 0.5212416 ]],

       [[0.50575825, 0.50509024, 0.51028026, ..., 0.49118142,
         0.51029671, 0.50260979],
        [0.50365382, 0.49917254, 0.49981401, ..., 0.50274792,
         0.49462967, 0.51075928],
        [0.48685308, 0.49536343, 0.48946273, ..., 0.49770946,
         0.49410785, 0.51574038],
...
        [0.47291   , 0.49901151, 0.49777905, ..., 0.4923025 ,
         0.49520834, 0.50307457],
        [0.48737022, 0.48375397, 0.51454096, ..., 0.49594455,
         0.5089312 , 0.48435021],
        [0.51752643, 0.50832082, 0.48726435, ..., 0.48644143,
         0.50965033, 0.49448089]],

       [[0.50960524, 0.48840649, 0.50493417, ..., 0.51578852,
         0.47861806, 0.5082273 ],
        [0.51734334, 0.48567938, 0.49220683, ..., 0.49611448,
         0.51498191, 0.50122301],
        [0.49373412, 0.48936224, 0.49903297, ..., 0.50361574,
         0.52071202, 0.48960257],
        ...,
        [0.47289604, 0.49901108, 0.49778354, ..., 0.49230075,
         0.49520386, 0.50307629],
        [0.48735222, 0.48374655, 0.51454733, ..., 0.49594458,
         0.50893568, 0.48434403],
        [0.51753427, 0.50832316, 0.48726019, ..., 0.48643687,
         0.50966666, 0.49446867]]])
Coordinates:
  * ocean_time  (ocean_time) int64 0 1 2 3 4 5 ... 2187 2188 2189 2190 2191 2192
  * eta_rho     (eta_rho) int64 0 1 2 3 4 5 6 7 ... 482 483 484 485 486 487 488
  * xi_rho      (xi_rho) int64 0 1 2 3 4 5 6 7 ... 648 649 650 651 652 653 654

This ends up working well!

Image("../images/detrend-dask-dashboard-3.png")

Finally#

Here’s everything put together:

def polyval(coord, coeffs, degree_dim="degree"):
    x = coord.data

    deg_coord = coeffs[degree_dim]
    N = int(deg_coord.max()) + 1

    lhs = xr.DataArray(
        np.stack([x ** (N - 1 - i) for i in range(N)], axis=1),
        dims=(coord.name, degree_dim),
        coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
    )
    return (lhs * coeffs).sum(degree_dim)


# Function to detrend
# Modified from source: https://gist.github.com/rabernat/1ea82bb067c3273a6166d1b1f77d490f
def detrend_dim(da, dim, deg=1):
    """detrend along a single dimension."""
    # calculate polynomial coefficients
    p = da.polyfit(dim=dim, deg=deg, skipna=False)
    # create a chunked version of the dimension coordinate
    chunked_dim = xr.DataArray(
        dask.array.from_array(da[dim].data, chunks=da.chunksizes[dim]),
        dims=dim,
        name=dim,
    )
    # evaluate the trend
    fit = polyval(chunked_dim, p.polyfit_coefficients)
    # remove the trend
    return da - fit


detrended = detrend_dim(da, dim="ocean_time")
detrended
<xarray.DataArray (ocean_time: 2193, eta_rho: 489, xi_rho: 655)>
dask.array<sub, shape=(2193, 489, 655), dtype=float64, chunksize=(8, 489, 655), chunktype=numpy.ndarray>
Coordinates:
  * ocean_time  (ocean_time) int64 0 1 2 3 4 5 ... 2187 2188 2189 2190 2191 2192
  * eta_rho     (eta_rho) int64 0 1 2 3 4 5 6 7 ... 482 483 484 485 486 487 488
  * xi_rho      (xi_rho) int64 0 1 2 3 4 5 6 7 ... 648 649 650 651 652 653 654
detrended.compute()
Image("../images/detrend-dask-dashboard-4.png")

It works!

Summary#

In this notebook, I explored the detrending problem to see which challenges affect detrending workflows.

  1. We used simple techniques, like tracking chunk sizes and the number of tasks, to isolate a few problems.

  2. Fixing those didn’t improve the calculation, so we then examined the dask graph for a smaller artificial problem that replicated the main issues.

  3. This let us trace the issue back to a serious regression in dask’s tensordot function.

  4. That issue has since been fixed (available in dask >= 2022.03.0).

  5. And things work!

Future work#

  1. Some improvements to xarray.polyval are needed so that it can handle a chunked coordinate as input.

  2. Dask could also detect that a giant ~5.6 GB chunk is constructed by this line in polyval:

    (lhs * coeffs).sum(degree_dim)
    

    and automatically avoid this problem by rechunking lhs prior to multiplying.
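
    A user-level version of that workaround inside the original polyval might look like this (a hedged sketch; a time chunk size of 8 matches this dataset's on-disk chunks):

    # chunk lhs along the coordinate dimension before the broadcasting multiply,
    # so the product inherits small chunks instead of one giant block
    lhs = lhs.chunk({coord.name: 8})
    fit = (lhs * coeffs).sum(degree_dim)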