From 776bc5a6f75a7b98199aee22db4e940d7a497c8f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 14:44:06 +0100 Subject: [PATCH 01/35] use cumsum from flox --- xarray/core/_aggregations.py | 43 +++++++++++++++--- xarray/core/groupby.py | 88 +++++++++++++++++++++++++++++------- 2 files changed, 108 insertions(+), 23 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index adc064840de..75e11d41e1c 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -6647,6 +6647,13 @@ def _flox_reduce( ) -> DataArray: raise NotImplementedError() + def _flox_scan( + self, + dim: Dims, + **kwargs: Any, + ) -> DataArray: + raise NotImplementedError() + def count( self, dim: Dims = None, @@ -7904,13 +7911,35 @@ def cumsum( * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 labels (time) tuple[Hashable, ...]: + parsed_dim: tuple[Hashable, ...] + if isinstance(dim, str): + parsed_dim = (dim,) + elif dim is None: + parsed_dim_list = list() + # preserve order + for dim_ in itertools.chain( + *(grouper.codes.dims for grouper in self.groupers) + ): + if dim_ not in parsed_dim_list: + parsed_dim_list.append(dim_) + parsed_dim = tuple(parsed_dim_list) + elif dim is ...: + parsed_dim = tuple(obj.dims) + else: + parsed_dim = tuple(dim) + + return parsed_dim + def _flox_reduce( self, dim: Dims, @@ -1088,22 +1108,7 @@ def _flox_reduce( # set explicitly to avoid unnecessarily accumulating count kwargs["min_count"] = 0 - parsed_dim: tuple[Hashable, ...] 
- if isinstance(dim, str): - parsed_dim = (dim,) - elif dim is None: - parsed_dim_list = list() - # preserve order - for dim_ in itertools.chain( - *(grouper.codes.dims for grouper in self.groupers) - ): - if dim_ not in parsed_dim_list: - parsed_dim_list.append(dim_) - parsed_dim = tuple(parsed_dim_list) - elif dim is ...: - parsed_dim = tuple(obj.dims) - else: - parsed_dim = tuple(dim) + parsed_dim = self._parse_dim(dim) # Do this so we raise the same error message whether flox is present or not. # Better to control it here than in flox. @@ -1202,6 +1207,57 @@ def _flox_reduce( return result + def _flox_scan( + self, + dim: Dims, + *, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> DataArray: + + from flox import groupby_scan + + # def groupby_scan( + # array: np.ndarray | DaskArray, + # *by: T_By, + # func: T_Scan, + # expected_groups: T_ExpectedGroupsOpt = None, + # axis: int | tuple[int] = -1, + # dtype: np.typing.DTypeLike = None, + # method: T_MethodOpt = None, + # engine: T_EngineOpt = None, + # ) -> np.ndarray | DaskArray: + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + parsed_dim = self._parse_dim(dim) + obj = self._original_obj + codes = tuple(g.codes for g in self.groupers) + a = 2 + g = groupby_scan( + obj.data, + *codes, + func=kwargs["func"], + expected_groups=None, + axis=obj.get_axis_num(parsed_dim), + dtype=None, + method=None, + engine=None, + ) + + return obj.copy(data=g) + + # xarray_reduce( + # obj.drop_vars(non_numeric.keys()), + # *codes, + # dim=parsed_dim, + # expected_groups=expected_groups, + # isbin=False, + # keep_attrs=keep_attrs, + # **kwargs, + # ) + def fillna(self, value: Any) -> T_Xarray: """Fill missing values in this object by group. 
From ae276323902ca46ccb16c35b87d941c3ad57df0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:44:51 +0000 Subject: [PATCH 02/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 4f80cdba3b7..1b6a726adff 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1214,7 +1214,6 @@ def _flox_scan( keep_attrs: bool | None = None, **kwargs: Any, ) -> DataArray: - from flox import groupby_scan # def groupby_scan( From a5f93265db48cee749212af73198bfe1d2f5f1b5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 14:47:26 +0100 Subject: [PATCH 03/35] Update groupby.py --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 1b6a726adff..45b0e7906bf 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1042,7 +1042,7 @@ def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]: parsed_dim_list.append(dim_) parsed_dim = tuple(parsed_dim_list) elif dim is ...: - parsed_dim = tuple(obj.dims) + parsed_dim = tuple(self._original_obj.dims) else: parsed_dim = tuple(dim) From 50ccca4ad3b415ab9a0e749b6ae95da8b5c3d04f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 14:48:52 +0100 Subject: [PATCH 04/35] Update groupby.py --- xarray/core/groupby.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 45b0e7906bf..9daff837dcf 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -252,9 +252,7 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d( - group: T_Group, obj: 
T_DataWithCoords -) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, @@ -1233,7 +1231,7 @@ def _flox_scan( parsed_dim = self._parse_dim(dim) obj = self._original_obj codes = tuple(g.codes for g in self.groupers) - a = 2 + g = groupby_scan( obj.data, *codes, From f55531ef0ac4de3a523623f44366ed2eff470a1f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:49:56 +0000 Subject: [PATCH 05/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9daff837dcf..806ca838ec3 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -252,7 +252,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ +def _ensure_1d( + group: T_Group, obj: T_DataWithCoords +) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, From 06ac3724cd709f73ea497746a2a0931b41205b8e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:00:49 +0100 Subject: [PATCH 06/35] Update groupby.py --- xarray/core/groupby.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9daff837dcf..c9aefcc1cde 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1209,41 +1209,40 @@ def _flox_scan( self, dim: Dims, *, + func: str, keep_attrs: bool | None = None, + skipna: bool | None = None **kwargs: Any, ) -> DataArray: from flox import groupby_scan - # def groupby_scan( - # array: np.ndarray | DaskArray, - # *by: T_By, - # func: T_Scan, - # expected_groups: T_ExpectedGroupsOpt = None, - # axis: int | tuple[int] = 
-1, - # dtype: np.typing.DTypeLike = None, - # method: T_MethodOpt = None, - # engine: T_EngineOpt = None, - # ) -> np.ndarray | DaskArray: + obj = self._original_obj + + if skipna or ( + skipna is None and isinstance(func, str) and obj.dtype.kind in "cfO" + ): + if "nan" not in func and func not in ["all", "any", "count"]: + func = f"nan{func}" if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) parsed_dim = self._parse_dim(dim) - obj = self._original_obj codes = tuple(g.codes for g in self.groupers) g = groupby_scan( obj.data, *codes, - func=kwargs["func"], + func=func, expected_groups=None, axis=obj.get_axis_num(parsed_dim), dtype=None, method=None, engine=None, ) + result = obj.copy(data=g) - return obj.copy(data=g) + return result # xarray_reduce( # obj.drop_vars(non_numeric.keys()), From dd475368ed5b5cb5d89f28ba418b4729bf88f66c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:02:30 +0100 Subject: [PATCH 07/35] Update groupby.py --- xarray/core/groupby.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index cad7f210e6e..c38a1b592d7 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -252,9 +252,7 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d( - group: T_Group, obj: T_DataWithCoords -) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, @@ -1213,7 +1211,7 @@ def _flox_scan( *, func: str, keep_attrs: bool | None = None, - skipna: bool | None = None + skipna: bool | None = None, **kwargs: Any, ) -> DataArray: from flox import groupby_scan From e867f12c880a156bbb740a2424dc21a1046476e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 14:03:01 +0000 Subject: [PATCH 08/35] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index c38a1b592d7..3a8f5dfdb28 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -252,7 +252,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ +def _ensure_1d( + group: T_Group, obj: T_DataWithCoords +) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, From 88e0ebc31f1da827c094702fd0e337746cc75263 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:20:12 +0100 Subject: [PATCH 09/35] Update groupby.py --- xarray/core/groupby.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3a8f5dfdb28..ec67ad4cb13 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -252,9 +252,7 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d( - group: T_Group, obj: T_DataWithCoords -) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, @@ -1230,14 +1228,16 @@ def _flox_scan( keep_attrs = _get_keep_attrs(default=True) parsed_dim = self._parse_dim(dim) - codes = tuple(g.codes for g in self.groupers) + axis_ = obj.get_axis_num(parsed_dim) + axis = (axis_,) if isinstance(axis_, int) else axis_ + codes = tuple(g.codes for g in self.groupers) g = groupby_scan( obj.data, *codes, func=func, expected_groups=None, - axis=obj.get_axis_num(parsed_dim), + axis=axis, dtype=None, method=None, engine=None, From 181d4a38f51e411e86ffff32561d5bef9fc99ccf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 14:20:37 +0000 Subject: [PATCH 10/35] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ec67ad4cb13..a4fb8a6359e 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -252,7 +252,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ +def _ensure_1d( + group: T_Group, obj: T_DataWithCoords +) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, From a82ec398c8bec3c17b07fadcc3ae407bc8f27168 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:44:38 +0100 Subject: [PATCH 11/35] use apply_ufunc for dataset and dataarray handling --- xarray/core/_aggregations.py | 8 ------ xarray/core/groupby.py | 53 ++++++++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 75e11d41e1c..aa4bec25a11 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -7911,14 +7911,6 @@ def cumsum( * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 
2001-06-30 labels (time) Date: Sat, 6 Dec 2025 14:45:22 +0000 Subject: [PATCH 12/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index cf34ef7feac..603c52fa4c9 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -13,11 +13,11 @@ from packaging.version import Version from xarray.computation import ops +from xarray.computation.apply_ufunc import apply_ufunc from xarray.computation.arithmetic import ( DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic, ) -from xarray.computation.apply_ufunc import apply_ufunc from xarray.core import dtypes, duck_array_ops, nputils from xarray.core._aggregations import ( DataArrayGroupByAggregations, From d8d0eaa2c23e7f7db54e67ecb49093b450a40fd9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 16:21:11 +0100 Subject: [PATCH 13/35] Update groupby.py --- xarray/core/groupby.py | 48 +++++++++++------------------------------- 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index cf34ef7feac..345845da163 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -253,9 +253,7 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d( - group: T_Group, obj: T_DataWithCoords -) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, @@ -1221,36 +1219,23 @@ def _flox_scan( obj = self._original_obj - if skipna or ( - skipna is None and isinstance(func, str) and obj.dtype.kind in "cfO" - ): - if "nan" not in func and func not in ["all", "any", "count"]: - func = f"nan{func}" - - # if keep_attrs is None: - # keep_attrs = _get_keep_attrs(default=True) - parsed_dim = self._parse_dim(dim) axis_ = 
obj.get_axis_num(parsed_dim) axis = (axis_,) if isinstance(axis_, int) else axis_ codes = tuple(g.codes for g in self.groupers) - # g = groupby_scan( - # obj.data, - # *codes, - # func=func, - # expected_groups=None, - # axis=axis, - # dtype=None, - # method=None, - # engine=None, - # ) - # result = obj.copy(data=g) - - # return result + + def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): + if skipna or ( + skipna is None and isinstance(func, str) and obj.dtype.kind in "cfO" + ): + if "nan" not in func: + func = f"nan{func}" + + return groupby_scan(array, *codes, func=func, **kwargs) actual = apply_ufunc( - groupby_scan, + wrapper, obj, *codes, # input_core_dims=input_core_dims, @@ -1267,6 +1252,7 @@ def _flox_scan( ), kwargs=dict( func=func, + skipna=skipna, expected_groups=None, axis=axis, dtype=None, @@ -1277,16 +1263,6 @@ def _flox_scan( return actual - # xarray_reduce( - # obj.drop_vars(non_numeric.keys()), - # *codes, - # dim=parsed_dim, - # expected_groups=expected_groups, - # isbin=False, - # keep_attrs=keep_attrs, - # **kwargs, - # ) - def fillna(self, value: Any) -> T_Xarray: """Fill missing values in this object by group. 
From 33d136079021896ce9ec48818475e95750424faa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:21:45 +0000 Subject: [PATCH 14/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a56e37e63f9..755c06d0e9f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -253,7 +253,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ +def _ensure_1d( + group: T_Group, obj: T_DataWithCoords +) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, From c97ae98bb4e0443cfa4063d80b2e673952e713c8 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 16:36:09 +0100 Subject: [PATCH 15/35] sync protocols with each other --- xarray/core/_aggregations.py | 4 ++++ xarray/core/groupby.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index aa4bec25a11..0f4c4a9dd49 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -6650,6 +6650,10 @@ def _flox_reduce( def _flox_scan( self, dim: Dims, + *, + func: str, + skipna: bool | None = None, + keep_attrs: bool | None = None, **kwargs: Any, ) -> DataArray: raise NotImplementedError() diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a56e37e63f9..85bfb957e9d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1211,8 +1211,8 @@ def _flox_scan( dim: Dims, *, func: str, - keep_attrs: bool | None = None, skipna: bool | None = None, + keep_attrs: bool | None = None, **kwargs: Any, ) -> DataArray: from flox import groupby_scan From 84f9b4430addb927b32573c088ddb25311fb1844 Mon 
Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 16:46:01 +0100 Subject: [PATCH 16/35] typing --- xarray/core/groupby.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 4aefe5fb8f0..8576cefe1ca 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -253,9 +253,7 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d( - group: T_Group, obj: T_DataWithCoords -) -> tuple[ +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, @@ -1223,14 +1221,12 @@ def _flox_scan( parsed_dim = self._parse_dim(dim) - axis_ = obj.get_axis_num(parsed_dim) - axis = (axis_,) if isinstance(axis_, int) else axis_ + axis = obj.get_axis_num(parsed_dim) + # axis = (axis_,) if isinstance(axis_, int) else axis_ codes = tuple(g.codes for g in self.groupers) def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): - if skipna or ( - skipna is None and isinstance(func, str) and obj.dtype.kind in "cfO" - ): + if skipna or (skipna is None and obj.dtype.kind in "cfO"): if "nan" not in func: func = f"nan{func}" From 297887709d81fd70cb7d76ccb45456012ea79406 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:46:59 +0000 Subject: [PATCH 17/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 8576cefe1ca..b82c60dca81 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -253,7 +253,9 @@ def to_array(self) -> DataArray: T_Group = Union["T_DataArray", _DummyGroup] -def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ +def _ensure_1d( + group: T_Group, 
obj: T_DataWithCoords +) -> tuple[ T_Group, T_DataWithCoords, Hashable | None, From 0a9adee78419c3853ef19436d69cf56a21130487 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 19:07:45 +0100 Subject: [PATCH 18/35] add dataset and version requirement --- xarray/core/_aggregations.py | 42 ++++++++++++++++++++++------ xarray/util/generate_aggregations.py | 21 ++++++++++++-- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 0f4c4a9dd49..13ee955b3fd 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -3655,6 +3655,17 @@ def _flox_reduce( ) -> Dataset: raise NotImplementedError() + def _flox_scan( + self, + dim: Dims, + *, + func: str, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> DataArray: + raise NotImplementedError() + def count( self, dim: Dims = None, @@ -5015,14 +5026,28 @@ def cumsum( Data variables: da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 nan """ - return self.reduce( - duck_array_ops.cumsum, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if ( + flox_available + and OPTIONS["use_flox"] + and module_available("flox", minversion="0.10.5") + and contains_only_chunked_or_numpy(self._obj) + ): + return self._flox_scan( + func="cumsum", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def cumprod( self, @@ -7918,6 +7943,7 @@ def cumsum( if ( flox_available and OPTIONS["use_flox"] + and module_available("flox", minversion="0.10.5") and contains_only_chunked_or_numpy(self._obj) ): return self._flox_scan( diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index e386b96f63d..8c6bc34cfb8 100644 --- 
a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -15,7 +15,7 @@ import textwrap from dataclasses import dataclass, field -from typing import NamedTuple +from typing import NamedTuple, Literal MODULE_PREAMBLE = '''\ """Mixin classes with reduction operations.""" @@ -132,6 +132,17 @@ def _flox_reduce( dim: Dims, **kwargs: Any, ) -> {obj}: + raise NotImplementedError() + + def _flox_scan( + self, + dim: Dims, + *, + func: str, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> DataArray: raise NotImplementedError()""" TEMPLATE_REDUCTION_SIGNATURE = ''' @@ -284,6 +295,7 @@ def __init__( see_also_methods=(), min_flox_version=None, additional_notes="", + flox_aggregation_type: Literal["reduce", "scan"] = "reduce", ): self.name = name self.extra_kwargs = extra_kwargs @@ -292,6 +304,7 @@ def __init__( self.see_also_methods = see_also_methods self.min_flox_version = min_flox_version self.additional_notes = additional_notes + self.flox_aggregation_type = flox_aggregation_type if bool_reduce: self.array_method = f"array_{name}" self.np_example_array = ( @@ -444,7 +457,7 @@ def generate_code(self, method, has_keep_attrs): # median isn't enabled yet, because it would break if a single group was present in multiple # chunks. 
The non-flox code path will just rechunk every group to a single chunk and execute the median - method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod") + method_is_not_flox_supported = method.name in ("median", "cumprod") if method_is_not_flox_supported: indent = 12 else: @@ -476,7 +489,7 @@ def generate_code(self, method, has_keep_attrs): + f""" and contains_only_chunked_or_numpy(self._obj) ): - return self._flox_reduce( + return self._flox_{method.flox_aggregation_type}( func="{method.name}", dim=dim,{extra_kwargs} # fill_value=fill_value, @@ -537,6 +550,8 @@ def generate_code(self, method, has_keep_attrs): numeric_only=True, see_also_methods=("cumulative",), additional_notes=_CUM_NOTES, + min_flox_version="0.10.5", + flox_aggregation_type="scan", ), Method( "cumprod", From c056d1f8b0ead31241e2f7094073891ce7e41e93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 18:08:23 +0000 Subject: [PATCH 19/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/util/generate_aggregations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index 8c6bc34cfb8..efe55a5eae5 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -15,7 +15,7 @@ import textwrap from dataclasses import dataclass, field -from typing import NamedTuple, Literal +from typing import Literal, NamedTuple MODULE_PREAMBLE = '''\ """Mixin classes with reduction operations.""" From d4873b992c1561b967dbcde1d09584095293044e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 19:37:25 +0100 Subject: [PATCH 20/35] Update _aggregations.py --- xarray/core/_aggregations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/_aggregations.py 
b/xarray/core/_aggregations.py index 13ee955b3fd..e7fb84cde01 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -3663,7 +3663,7 @@ def _flox_scan( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Dataset: raise NotImplementedError() def count( From 21cbde201ab6edd4596b96318a2075770803023d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 21:11:06 +0100 Subject: [PATCH 21/35] Update xarray/core/groupby.py Co-authored-by: Deepak Cherian --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b82c60dca81..8a44e271bb6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1223,7 +1223,7 @@ def _flox_scan( parsed_dim = self._parse_dim(dim) - axis = obj.get_axis_num(parsed_dim) + axis = obj.transpose(..., *parsed_dim) # axis = (axis_,) if isinstance(axis_, int) else axis_ codes = tuple(g.codes for g in self.groupers) From 4aebc4739cef22d009721cd0fe733bde4eaddf6a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 21:14:51 +0100 Subject: [PATCH 22/35] Update groupby.py --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 8a44e271bb6..92e8d728775 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1228,7 +1228,7 @@ def _flox_scan( codes = tuple(g.codes for g in self.groupers) def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): - if skipna or (skipna is None and obj.dtype.kind in "cfO"): + if skipna or (skipna is None and array.dtype.kind in "cfO"): if "nan" not in func: func = f"nan{func}" From f4cab24f4f132101b6407d6981cc9efed4a17c21 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 21:23:32 
+0100 Subject: [PATCH 23/35] Update groupby.py --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 92e8d728775..a6524125acc 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1216,7 +1216,7 @@ def _flox_scan( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> T_Xarray: from flox import groupby_scan obj = self._original_obj From 23d9d50cd79e9211ab50fdfbf390f32308938799 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 22:11:25 +0100 Subject: [PATCH 24/35] Update groupby.py --- xarray/core/groupby.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a6524125acc..f58a69d88c7 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1220,10 +1220,11 @@ def _flox_scan( from flox import groupby_scan obj = self._original_obj - parsed_dim = self._parse_dim(dim) - axis = obj.transpose(..., *parsed_dim) + obj = obj.transpose(..., *parsed_dim) + axis = range(-len(parsed_dim), 0) + # axis = (axis_,) if isinstance(axis_, int) else axis_ codes = tuple(g.codes for g in self.groupers) From 9b64db2cf3ba19d0578e68e0009d651f6edfc1de Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 6 Dec 2025 22:11:31 +0100 Subject: [PATCH 25/35] Update generate_aggregations.py --- xarray/util/generate_aggregations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index efe55a5eae5..a6b6d1fb25a 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -142,7 +142,7 @@ def _flox_scan( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> {obj}: raise NotImplementedError()""" 
TEMPLATE_REDUCTION_SIGNATURE = ''' From 928b158e181a3f7c973c104e7faca169cc00a3db Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 Dec 2025 01:17:04 +0100 Subject: [PATCH 26/35] Remove workaround in test --- xarray/tests/test_groupby.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e96d8b6828b..f2dbc0a4ce5 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2568,7 +2568,8 @@ def test_groupby_cumsum() -> None: ) # TODO: Remove drop_vars when GH6528 is fixed # when Dataset.cumsum propagates indexes, and the group variable? - assert_identical(expected.drop_vars(["x", "group_id"]), actual) + # assert_identical(expected.drop_vars(["x", "group_id"]), actual) + assert_identical(expected, actual) actual = ds.foo.groupby("group_id").cumsum(dim="x") expected.coords["group_id"] = ds.group_id From 130f98e9e7657405216db232e0519e557e535e82 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 Dec 2025 12:40:30 +0100 Subject: [PATCH 27/35] Update _aggregations.py --- xarray/core/_aggregations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index e7fb84cde01..28cfc2897bb 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -5011,9 +5011,11 @@ def cumsum( da (time) float64 48B 1.0 2.0 3.0 0.0 2.0 nan >>> ds.groupby("labels").cumsum() - Size: 48B + Size: 120B Dimensions: (time: 6) - Dimensions without coordinates: time + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 
2001-06-30 + labels (time) Date: Sun, 7 Dec 2025 12:43:04 +0100 Subject: [PATCH 28/35] Update _aggregations.py --- xarray/core/_aggregations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 28cfc2897bb..ad742366077 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -5022,9 +5022,11 @@ def cumsum( Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").cumsum(skipna=False) - Size: 48B + Size: 120B Dimensions: (time: 6) - Dimensions without coordinates: time + Coordinates: + * time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) Date: Sun, 7 Dec 2025 12:56:57 +0100 Subject: [PATCH 29/35] Update test_groupby.py --- xarray/tests/test_groupby.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index f2dbc0a4ce5..d79eca3925e 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -16,6 +16,8 @@ from xarray import DataArray, Dataset, Variable, date_range from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible +from xarray.core.utils import module_available + from xarray.groupers import ( BinGrouper, EncodedGroups, @@ -2566,10 +2568,13 @@ def test_groupby_cumsum() -> None: "group_id": ds.group_id, }, ) - # TODO: Remove drop_vars when GH6528 is fixed - # when Dataset.cumsum propagates indexes, and the group variable? - # assert_identical(expected.drop_vars(["x", "group_id"]), actual) - assert_identical(expected, actual) + + if xr.get_options()["use_flox"] and module_available("flox", minversion="0.10.5"): + assert_identical(expected, actual) + else: + # TODO: Remove drop_vars when GH6528 is fixed + # when Dataset.cumsum propagates indexes, and the group variable? 
+ assert_identical(expected.drop_vars(["x", "group_id"]), actual) actual = ds.foo.groupby("group_id").cumsum(dim="x") expected.coords["group_id"] = ds.group_id From 3bc8dc7b2a69d94d89f5231f10dbcce54b56aa05 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 7 Dec 2025 11:59:03 +0000 Subject: [PATCH 30/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/tests/test_groupby.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index d79eca3925e..89abaa38568 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -17,7 +17,6 @@ from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible from xarray.core.utils import module_available - from xarray.groupers import ( BinGrouper, EncodedGroups, From ec8ffd6fd9182b721df9febe65e5219d604122c9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:01:37 +0100 Subject: [PATCH 31/35] clean ups --- xarray/core/groupby.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index f58a69d88c7..c70da2a16ea 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1219,13 +1219,9 @@ def _flox_scan( ) -> T_Xarray: from flox import groupby_scan - obj = self._original_obj parsed_dim = self._parse_dim(dim) - - obj = obj.transpose(..., *parsed_dim) + obj = self._original_obj.transpose(..., *parsed_dim) axis = range(-len(parsed_dim), 0) - - # axis = (axis_,) if isinstance(axis_, int) else axis_ codes = tuple(g.codes for g in self.groupers) def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): From 07a4d351ee558727d1574a208fae4414fba36f61 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: 
Tue, 9 Dec 2025 00:24:47 +0100 Subject: [PATCH 32/35] Add expected groups, add options --- xarray/core/groupby.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index c70da2a16ea..839800d1fda 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1047,6 +1047,13 @@ def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]: else: parsed_dim = tuple(dim) + d not in grouper.codes.dims and d not in self._original_obj.dims + for d in parsed_dim + ): + # TODO: Not a helpful error, it's a sanity check that dim actually exist + # either in self.groupers or self._original_obj + raise ValueError(f"cannot reduce over dimensions {dim}.") + return parsed_dim def _flox_reduce( @@ -1111,14 +1118,6 @@ def _flox_reduce( parsed_dim = self._parse_dim(dim) - # Do this so we raise the same error message whether flox is present or not. - # Better to control it here than in flox. - for grouper in self.groupers: - if any( - d not in grouper.codes.dims and d not in obj.dims for d in parsed_dim - ): - raise ValueError(f"cannot reduce over dimensions {dim}.") - has_missing_groups = ( self.encoded.unique_coord.size != self.encoded.full_index.size ) @@ -1224,6 +1223,11 @@ def _flox_scan( axis = range(-len(parsed_dim), 0) codes = tuple(g.codes for g in self.groupers) + # pass RangeIndex as a hint to flox that `by` is already factorized + expected_groups = tuple( + pd.RangeIndex(len(grouper)) for grouper in self.groupers + ) + def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): if skipna or (skipna is None and array.dtype.kind in "cfO"): if "nan" not in func: @@ -1235,26 +1239,18 @@ def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): wrapper, obj, *codes, - # input_core_dims=input_core_dims, - # for xarray's test_groupby_duplicate_coordinate_labels - # exclude_dims=set(dim_tuple), - # output_core_dims=[output_core_dims], dask="allowed", - # 
dask_gufunc_kwargs=dict( - # output_sizes=output_sizes, - # output_dtypes=[dtype] if dtype is not None else None, - # ), keep_attrs=( _get_keep_attrs(default=True) if keep_attrs is None else keep_attrs ), kwargs=dict( func=func, skipna=skipna, - expected_groups=None, + expected_groups=expected_groups, axis=axis, - dtype=None, - method=None, - engine=None, + dtype=kwargs.get("dtype", None), + method=kwargs.get("method", None), + engine=kwargs.get("engine", None), ), ) From d0f7ed212e3c98e166fb33c3a531c5bd29ba1b98 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 9 Dec 2025 00:25:06 +0100 Subject: [PATCH 33/35] Update groupby.py --- xarray/core/groupby.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 839800d1fda..bdfb6eb3c0e 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1047,6 +1047,10 @@ def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]: else: parsed_dim = tuple(dim) + # Do this so we raise the same error message whether flox is present or not. + # Better to control it here than in flox. 
+ for grouper in self.groupers: + if any( d not in grouper.codes.dims and d not in self._original_obj.dims for d in parsed_dim ): From 098be30fd47cd5e691ef7d0a4119a959408e3ec2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 23:25:38 +0000 Subject: [PATCH 34/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/groupby.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index bdfb6eb3c0e..47841d973f5 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1252,9 +1252,9 @@ def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): skipna=skipna, expected_groups=expected_groups, axis=axis, - dtype=kwargs.get("dtype", None), - method=kwargs.get("method", None), - engine=kwargs.get("engine", None), + dtype=kwargs.get("dtype"), + method=kwargs.get("method"), + engine=kwargs.get("engine"), ), ) From 87d5f77a4d231f52f313891427a21df5db0099c8 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 9 Dec 2025 00:34:27 +0100 Subject: [PATCH 35/35] expeced_groups not supported in groupby_scan --- xarray/core/groupby.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index bdfb6eb3c0e..0d9ffb88f50 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1227,11 +1227,6 @@ def _flox_scan( axis = range(-len(parsed_dim), 0) codes = tuple(g.codes for g in self.groupers) - # pass RangeIndex as a hint to flox that `by` is already factorized - expected_groups = tuple( - pd.RangeIndex(len(grouper)) for grouper in self.groupers - ) - def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): if skipna or (skipna is None and array.dtype.kind in "cfO"): if "nan" not in func: @@ -1250,7 +1245,7 @@ def wrapper(array, *by, 
func: str, skipna: bool | None, **kwargs): kwargs=dict( func=func, skipna=skipna, - expected_groups=expected_groups, + expected_groups=None, # TODO: Should be same as _flox_reduce? axis=axis, dtype=kwargs.get("dtype", None), method=kwargs.get("method", None),