Best practice for applying function over 3D rolling windows Xarray with Dask backend #7505

@GrahamReveley

Description

Applying a custom function over a 3D rolling window with a dask backend takes an extremely long time, uses very little CPU, and consumes a lot of RAM.

I'm attempting to run a 3D rolling window operation on a 3D DataArray (time, lat, lon) with a custom function operating on each of the rolling windows.

An example can be seen below:

arr [time: 300, lat: 313, lon: 400]

chunked 20 x 20 in lat and lon, and -1 (a single chunk) in time
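
For reference, the input can be reproduced with something like the following (random data standing in for the real array; the shape, chunking, and the num_times value used below are assumptions for illustration):

import numpy as np
import xarray as xr

# Illustrative stand-in for the real data: same shape and chunking as described above
arr = xr.DataArray(
    np.random.rand(300, 313, 400),
    dims=("time", "lat", "lon"),
    name="arr",
).chunk({"time": -1, "lat": 20, "lon": 20})

# Time-window length used in the rolling call below (assumed to span the full time dimension)
num_times = arr.sizes["time"]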

arr_rolling = arr.rolling(dim={"lon": 3, "lat": 3, "time": num_times}, center=True).construct(lon="wlon", lat="wlat", time="wtime")

output = xr.apply_ufunc(
    func,
    arr_rolling,
    input_core_dims=[["wlon", "wlat", "wtime"]],
    output_core_dims=[[]],
    vectorize=True,
    output_dtypes=[float],
    dask="parallelized",
)

The function essentially outputs a single value at the central lat/lon point of each rolling window, hence I'm reducing over the window dimensions.
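
The real function is more involved, but as a stand-in it has roughly this shape (takes one window, returns a scalar); a simple nanmean illustrates the pattern:

def func(window):
    # With vectorize=True, window arrives with shape (wlon, wlat, wtime);
    # the real function does something more involved than a nanmean.
    return np.nanmean(window)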

I've tried the rolling().reduce() combination with little success, but I could be doing something wrong there.
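
Roughly what I tried there (note that reduce() passes the window axes via an axis argument, so the callable has to accept it; the lambda is a placeholder for the real reduction):

output = arr.rolling(lon=3, lat=3, time=num_times, center=True).reduce(
    lambda a, axis: np.nanmean(a, axis=axis)
)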

I'm running this on a dask LocalCluster and I'm seeing very low CPU usage from the workers (<10%) along with high disk read/write stats, suggesting a lot of reordering of data. Processing this array takes a very long time and a lot of memory even on a large cluster; it runs more efficiently (and eventually completes) when the input array is a plain numpy array rather than dask.
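
The cluster setup is nothing special, roughly (actual worker counts and memory limits omitted):

from dask.distributed import Client, LocalCluster

cluster = LocalCluster()
client = Client(cluster)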

Eventually the code produces the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_126064/3685643056.py in <module>
----> 1 x = output.compute()

~/.local/lib/python3.8/site-packages/xarray/core/dataarray.py in compute(self, **kwargs)
   1090         """
   1091         new = self.copy(deep=False)
-> 1092         return new.load(**kwargs)
   1093 
   1094     def persist(self: T_DataArray, **kwargs) -> T_DataArray:

~/.local/lib/python3.8/site-packages/xarray/core/dataarray.py in load(self, **kwargs)
   1064         dask.compute
   1065         """
-> 1066         ds = self._to_temp_dataset().load(**kwargs)
   1067         new = self._from_temp_dataset(ds)
   1068         self._variable = new._variable

~/.local/lib/python3.8/site-packages/xarray/core/dataset.py in load(self, **kwargs)
    737 
    738             # evaluate all the dask arrays simultaneously
--> 739             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    740 
    741             for k, data in zip(lazy_data, evaluated_data):
...
--> 366         raise TypeError(msg, str(x)[:10000])
    367     else:  # pragma: nocover
    368         raise ValueError(f"{on_error=}; expected 'message' or 'raise'")

TypeError: ('Long error message', "('Could not serialize object of type ndarray', '[[[[[[[nan nan nan ...  0.  0.  0.]\\n 
.......

I'd like to leverage dask's parallelisation for this task to speed it up. Has anyone else had this problem and overcome it? Would love some suggestions! Thanks in advance!
