Apply custom function over 3D rolling window with dask backend taking extremely long time, using very little CPU and a lot of RAM
I'm attempting to run a 3D rolling window operation on a 3D DataArray (time, lat, lon) with a custom function operating on each of the rolling windows.
An example can be seen below:
arr: [time: 300, lat: 313, lon: 400]
chunked 20x20 in lat and lon, and -1 (a single chunk) in time
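Roughly, the input array is set up like this (the dimension sizes and chunking match the above; the random values and the num_times value are just placeholders standing in for my real data):
import numpy as np
import xarray as xr

# placeholder array with the shape and chunking described above
arr = xr.DataArray(
    np.random.rand(300, 313, 400),
    dims=("time", "lat", "lon"),
).chunk({"time": -1, "lat": 20, "lon": 20})
num_times = 300  # assumed time window length used below (placeholder value)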
arr_rolling = arr.rolling(dim={"lon": 3, "lat": 3, "time": num_times}, center=True).construct(lon="wlon", lat="wlat", time="wtime")
output = xr.apply_ufunc(
    func,
    arr_rolling,
    input_core_dims=[["wlon", "wlat", "wtime"]],
    output_core_dims=[[]],
    vectorize=True,
    output_dtypes=[float],
    dask="parallelized",
)
The function outputs a single value at the central (lat, lon) point of each rolling window, hence I'm reducing over the window dimensions.
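For illustration, a stand-in for the kind of function being applied (not my actual function); with vectorize=True it receives one (wlon, wlat, wtime) window at a time and returns a single value:
def func(window):
    # placeholder reduction: collapse the whole window to one number
    return np.nanmean(window)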
I've also tried the rolling().reduce() combination (sketched below) with little success, though I could be doing something wrong there.
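The reduce() attempt looked roughly like this (a sketch; note that reduce() calls its function as func(x, axis=...) over the window axes, which is a different signature from the per-window call above):
# sketch of the rolling().reduce() attempt; the custom function would need to
# accept an axis argument here, unlike the per-window version used with apply_ufunc
output = arr.rolling(lon=3, lat=3, time=num_times, center=True).reduce(func)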
I'm running this on a dask LocalCluster and I'm seeing very low CPU usage from the workers (<10%) along with high disk read/write, which suggests a lot of data is being shuffled around. Processing this array takes a very long time even on a large cluster with plenty of memory, whereas the computation runs more efficiently (and eventually completes) when the input array is a pure numpy array rather than a dask array.
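The cluster setup is along these lines (exact worker and memory settings omitted here):
from dask.distributed import Client, LocalCluster

# placeholder cluster configuration; the real worker counts / memory limits may differ
cluster = LocalCluster()
client = Client(cluster)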
Eventually the code produces the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_126064/3685643056.py in <module>
----> 1 x = output.compute()
~/.local/lib/python3.8/site-packages/xarray/core/dataarray.py in compute(self, **kwargs)
1090 """
1091 new = self.copy(deep=False)
-> 1092 return new.load(**kwargs)
1093
1094 def persist(self: T_DataArray, **kwargs) -> T_DataArray:
~/.local/lib/python3.8/site-packages/xarray/core/dataarray.py in load(self, **kwargs)
1064 dask.compute
1065 """
-> 1066 ds = self._to_temp_dataset().load(**kwargs)
1067 new = self._from_temp_dataset(ds)
1068 self._variable = new._variable
~/.local/lib/python3.8/site-packages/xarray/core/dataset.py in load(self, **kwargs)
737
738 # evaluate all the dask arrays simultaneously
--> 739 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
740
741 for k, data in zip(lazy_data, evaluated_data):
...
--> 366 raise TypeError(msg, str(x)[:10000])
367 else: # pragma: nocover
368 raise ValueError(f"{on_error=}; expected 'message' or 'raise'")
TypeError: ('Long error message', "('Could not serialize object of type ndarray', '[[[[[[[nan nan nan ... 0. 0. 0.]\\n
.......
I'd like to leverage dask's parallelisation for this task to speed it up. Has anyone else had this problem and overcome it? Would love some suggestions! Thanks in advance!