Stream: dask

Topic: parallelizing via dask.delayed


view this post on Zulip Michael Levy (Apr 29 2022 at 22:37):

I have a class that has a .run(nsteps) method and that method creates / populates self._ds (an xarray Dataset). So large_obj._ds does not exist when you initialize large_obj, but it gets created when you run large_obj.run(nsteps).

Class objects have datasets with an X dimension, and instead of just having a single object where X=81000, I've constructed a list of 9 objects each with X=9000. I was thinking that I could just do

tmp_list=[]
for smaller_obj in list_of_objs:
    tmp_list.append(dask.delayed(smaller_obj.run)(nsteps))
dask.compute(*tmp_list)

I'm running into two issues:

  1. The dask.compute() call raises a CancelledError; the entire traceback is inside the dask module, so if something in my code is causing the cancellation I can't tell where it is
  2. If I run dask.compute(tmp_list[0]), the task completes in ~1/9 the time of running large_obj.run(nsteps), but list_of_objs[0]._ds still does not exist afterwards

At this point, I'm more concerned about the second issue than the first -- I thought I was just making a list of tasks for dask to distribute over the cluster, but that's not quite what is happening because otherwise I'd be able to access the ._ds Dataset.

I wondered if maybe a copy was happening with smaller_obj, so I also tried

tmp_list=[]
for n in range(len(list_of_objs)):
    tmp_list.append(dask.delayed(list_of_objs[n].run)(nsteps))
dask.compute(*tmp_list)

But I ran into the same issue. Does anyone here have experience with running something like list_of_objs[:].run(nsteps) in parallel?

view this post on Zulip Deepak Cherian (Apr 30 2022 at 03:02):

For (2) you probably have to save the return value from dask.compute. If you're using distributed, it communicates over the network interface and ends up copying objects. So the thing you send is not the thing you get back.
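
A minimal sketch of that pattern (assuming run() returns the dataset it builds, rather than relying on the in-place self._ds assignment):

import dask

tmp_list = [dask.delayed(obj.run)(nsteps) for obj in list_of_objs]

# dask.compute returns a tuple with one result per delayed task;
# with distributed these are copies shipped back from the workers
results = dask.compute(*tmp_list)

# re-attach the returned datasets to the local objects
for obj, ds in zip(list_of_objs, results):
    obj._ds = ds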

For (1) I would try changing the dask version either up or down. Things are moving quite rapidly, so it could just be a bug.

view this post on Zulip Michael Levy (May 02 2022 at 14:54):

I updated to the latest dask (I think I went from a Dec 2021 release to an Apr 2022 one), and the problem persisted. I tried to make a simple example that mimicked the structure of my code, but it ran fine... so it seems likely the issue is in code I wrote. Bummer.

view this post on Zulip Michael Levy (May 02 2022 at 15:05):

Although I realized I was running my test case on a (casper) login node and the actual notebook on a compute node -- if I run the actual code from the login node, I get a StreamClosedError instead of the CancelledError; in that case, it does look like the jobs are being sent out to the cluster, and the error is thrown when the tasks return results

view this post on Zulip Deepak Cherian (May 02 2022 at 16:31):

Can you run a small subset of your problem in serial?

with dask.config.set(scheduler="sync"):
    dask.compute(...)

This is just a for-loop over tasks and lets you check for errors. Since dask.compute(tmp_list[0]) succeeded, maybe this won't help, but perhaps the problem is not in the first task but in some other one?
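
One way to localize the failure (a sketch, reusing the tmp_list of delayed calls from above):

with dask.config.set(scheduler="sync"):
    # compute one task at a time so the traceback points at the
    # specific object whose run() call is failing
    for n, task in enumerate(tmp_list):
        try:
            dask.compute(task)
        except Exception as err:
            print(f"task {n} failed: {err!r}")
            raise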

view this post on Zulip Michael Levy (May 02 2022 at 16:33):

@Deepak Cherian I can try that, but currently I'm running with

tmp_output = []
for n in range(len(feisty_drivers)):
    tmp_output.append(dask.delayed(feisty_drivers[n].run)(nsteps, return_ds=True))
results = dask.persist(*tmp_output)
for n, tmp in enumerate(results):
    feisty_drivers[n]._ds = tmp.compute()

and getting multiple tasks sent to the scheduler at once. It seems weird to combine the dask.persist and .compute() calls, but I'm getting ~3x speed-up when running on a single node with 9 workers, and I can run 5 years without a problem. I did run into an issue with a longer run -- when I tried a 9 year run, I started seeing

distributed.scheduler - ERROR - Couldn't gather keys

view this post on Zulip Michael Levy (May 02 2022 at 16:54):

Yup, calling dask.compute(*tmp_output) with dask.config.set(scheduler="sync") ran fine (and even dumped output from my .run() function to stdout instead of losing it to the void / log files that I never look at)

view this post on Zulip Deepak Cherian (May 02 2022 at 17:02):

and does results = dask.persist(*tmp_output) succeed? or is that when the errors start? Either way it looks like something weird in distributed.

view this post on Zulip Michael Levy (May 02 2022 at 17:09):

results = dask.persist(*tmp_output) succeeds, but it returns a list of Delayed objects; I know the Delayed objects are all lists of Datasets, but I've been using the feisty_drivers[n]._ds = tmp.compute() to actually access the lists. Is there a better option on that end? tmp.load()?

view this post on Zulip Deepak Cherian (May 02 2022 at 19:05):

So is feisty_drivers[n]._ds = tmp.compute() where it fails?

view this post on Zulip Michael Levy (May 02 2022 at 19:18):

Sorry, it got confusing since I was testing a few different things out while trying to keep this thread up-to-date. Hopefully this summary will make sense... consider

def methodA():
    tmp_output = []
    for n in range(len(feisty_drivers)):
        tmp_output.append(dask.delayed(feisty_drivers[n].run)(nsteps, return_ds=True))
    results = dask.compute(*tmp_output)
    for n in range(len(results)):
        feisty_drivers[n]._ds = results[n]

def methodB():
    tmp_output = []
    for n in range(len(feisty_drivers)):
        tmp_output.append(dask.delayed(feisty_drivers[n].run)(nsteps, return_ds=True))
    results = dask.persist(*tmp_output)
    for n, tmp in enumerate(results):
        feisty_drivers[n]._ds = tmp.compute()

view this post on Zulip Michael Levy (May 02 2022 at 19:19):

and yes, the error is always in the results = dask.compute(*tmp_output) or feisty_drivers[n]._ds = tmp.compute() line; dask.persist() seems to be okay.

view this post on Zulip Matt Long (May 03 2022 at 12:24):

Not sure I follow everything, but perhaps wrapping the driver in a function would help?

def run_feisty(obj, **kwargs):
  obj.run(...)
  return obj

view this post on Zulip Matt Long (May 03 2022 at 12:24):

then run dask.delayed on that function?
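
i.e. roughly (a sketch, building on the run_feisty wrapper above):

import dask

tmp_output = [dask.delayed(run_feisty)(obj) for obj in feisty_drivers]

# each task returns the (copied) object it ran on the worker,
# so re-bind the computed results locally
feisty_drivers = list(dask.compute(*tmp_output))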

view this post on Zulip Matt Long (May 03 2022 at 12:51):

Or do you need to scatter first?

objs_delayed = [client.scatter(obj) for obj in objs]

@dask.delayed
def run_feisty(obj, **kwargs):
  obj.run(...)
  return obj
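
and then, roughly (a sketch; client.scatter hands back Future objects, which dask resolves on the workers when the delayed calls run):

tasks = [run_feisty(obj_future) for obj_future in objs_delayed]
objs = dask.compute(*tasks)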

view this post on Zulip Michael Levy (May 03 2022 at 16:04):

@Matt Long one idea that came out of talking with @Deepak Cherian was to try to use xr.map_blocks -- map_blocks() will pass numpy arrays through even if they aren't dask-ified, so it will actually turn into a much cleaner interface. This is still a work in progress, but I think I'll just need to add a parallel argument to .run() that will chunk the forcing / domain / state_t datasets (and maybe refactor _solve() some)
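
A rough, self-contained sketch of the map_blocks shape (toy data and a stand-in solve, not the actual feisty code):

import numpy as np
import xarray as xr

# toy stand-in for the forcing dataset, chunked along X
forcing = xr.Dataset({"temp": ("X", np.random.rand(81000))}).chunk({"X": 9000})

def _solve_chunk(forcing_chunk):
    # placeholder for the per-chunk solve; must return a Dataset
    # covering the same X slice it was handed
    return forcing_chunk * 2.0

# map_blocks calls _solve_chunk once per X chunk (in parallel under dask)
result = xr.map_blocks(_solve_chunk, forcing).compute()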

view this post on Zulip Michael Levy (May 03 2022 at 16:05):

Using dask in self._ds is preferable to splitting that dataset into several small ones and then trying to concatenate them on the back end (in some of my tests, the xr.concat() was eating into a significant portion of the time savings from parallelization)

