Skip to content

Commit 53bad2d

Browse files
committed
Add an arithmetic_compat option to xr.set_options, which determines how non-index coordinates of the same name are compared for potential conflicts when performing binary operations.
Previously the behaviour was fixed at compat='minimal'; this change allows any of the compat options to be set. For now the default is still 'minimal', but this sets up arithmetic_compat to be migrated to 'override' alongside the other defaults for compat, as part of use_new_combine_kwarg_defaults. A warning is emitted for behaviour which will change under this migration. A couple of tests which were relying on or testing the arithmetic_compat='minimal' behaviour have been updated to avoid the warning, either by setting arithmetic_compat explicitly or by avoiding a coordinate clash.
1 parent ad92a16 commit 53bad2d

File tree

10 files changed

+179
-44
lines changed

10 files changed

+179
-44
lines changed

doc/whats-new.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ New Features
1717
- :py:func:`combine_nested` now supports :py:class:`DataTree` objects
1818
(:pull:`10849`).
1919
By `Stephan Hoyer <https://github.com/shoyer>`_.
20+
- :py:func:`set_options` now supports an `arithmetic_compat` option which determines how non-index coordinates
21+
of the same name are compared for potential conflicts when performing binary operations. The default for it is
22+
`arithmetic_compat='minimal'` which matches the existing behaviour, but it is slated to change to 'override'
23+
in future alongside the other defaults for `compat`, see below.
2024

2125
Breaking Changes
2226
~~~~~~~~~~~~~~~~
@@ -25,6 +29,10 @@ Breaking Changes
2529
Deprecations
2630
~~~~~~~~~~~~
2731

32+
- The default for the `arithmetic_compat` option will change from 'minimal' to 'override' in a future version,
33+
consistent with the existing deprecation cycle which will change the default for `compat` to 'override' elsewhere.
34+
A ``FutureWarning`` is emitted for behaviour which will change when this default changes.
35+
2836

2937
Bug Fixes
3038
~~~~~~~~~

xarray/core/coordinates.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
assert_no_index_corrupted,
2222
create_default_index_implicit,
2323
)
24-
from xarray.core.types import DataVars, ErrorOptions, Self, T_DataArray, T_Xarray
24+
from xarray.core.types import CompatOptions, DataVars, ErrorOptions, Self, T_DataArray, T_Xarray
2525
from xarray.core.utils import (
2626
Frozen,
2727
ReprObject,
@@ -31,6 +31,7 @@
3131
from xarray.core.variable import Variable, as_variable, calculate_dimensions
3232
from xarray.structure.alignment import Aligner
3333
from xarray.structure.merge import merge_coordinates_without_align, merge_coords
34+
from xarray.util.deprecation_helpers import CombineKwargDefault
3435

3536
if TYPE_CHECKING:
3637
from xarray.core.common import DataWithCoords
@@ -499,18 +500,18 @@ def _drop_coords(self, coord_names):
499500
# redirect to DatasetCoordinates._drop_coords
500501
self._data.coords._drop_coords(coord_names)
501502

502-
def _merge_raw(self, other, reflexive):
503+
def _merge_raw(self, other, reflexive, compat: CompatOptions | CombineKwargDefault):
503504
"""For use with binary arithmetic."""
504505
if other is None:
505506
variables = dict(self.variables)
506507
indexes = dict(self.xindexes)
507508
else:
508509
coord_list = [self, other] if not reflexive else [other, self]
509-
variables, indexes = merge_coordinates_without_align(coord_list)
510+
variables, indexes = merge_coordinates_without_align(coord_list, compat=compat)
510511
return variables, indexes
511512

512513
@contextmanager
513-
def _merge_inplace(self, other):
514+
def _merge_inplace(self, other, compat: CompatOptions | CombineKwargDefault):
514515
"""For use with in-place binary arithmetic."""
515516
if other is None:
516517
yield
@@ -523,12 +524,16 @@ def _merge_inplace(self, other):
523524
if k not in self.xindexes
524525
}
525526
variables, indexes = merge_coordinates_without_align(
526-
[self, other], prioritized
527+
[self, other], prioritized, compat=compat
527528
)
528529
yield
529530
self._update_coords(variables, indexes)
530531

531-
def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
532+
def merge(
533+
self,
534+
other: Mapping[Any, Any] | None,
535+
compat: CompatOptions | CombineKwargDefault = "minimal",
536+
) -> Dataset:
532537
"""Merge two sets of coordinates to create a new Dataset
533538
534539
The method implements the logic used for joining coordinates in the
@@ -545,6 +550,9 @@ def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
545550
other : dict-like, optional
546551
A :py:class:`Coordinates` object or any mapping that can be turned
547552
into coordinates.
553+
compat : {"identical", "equals", "broadcast_equals", "no_conflicts",
554+
"override", "minimal"}, default: "minimal"
555+
Compatibility checks to use between coordinate variables.
548556
549557
Returns
550558
-------
@@ -559,7 +567,7 @@ def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
559567
if not isinstance(other, Coordinates):
560568
other = Dataset(coords=other).coords
561569

562-
coords, indexes = merge_coordinates_without_align([self, other])
570+
coords, indexes = merge_coordinates_without_align([self, other], compat=compat)
563571
coord_names = set(coords)
564572
return Dataset._construct_direct(
565573
variables=coords, coord_names=coord_names, indexes=indexes

xarray/core/dataarray.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4899,7 +4899,8 @@ def _binary_op(
48994899
if not reflexive
49004900
else f(other_variable_or_arraylike, self.variable)
49014901
)
4902-
coords, indexes = self.coords._merge_raw(other_coords, reflexive)
4902+
coords, indexes = self.coords._merge_raw(
4903+
other_coords, reflexive, compat=OPTIONS["arithmetic_compat"])
49034904
name = result_name([self, other])
49044905

49054906
return self._replace(variable, coords, name, indexes=indexes)
@@ -4919,7 +4920,7 @@ def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self:
49194920
other_coords = getattr(other, "coords", None)
49204921
other_variable = getattr(other, "variable", other)
49214922
try:
4922-
with self.coords._merge_inplace(other_coords):
4923+
with self.coords._merge_inplace(other_coords, compat=OPTIONS["arithmetic_compat"]):
49234924
f(self.variable, other_variable)
49244925
except MergeError as exc:
49254926
raise MergeError(

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7765,7 +7765,7 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
77657765
return type(self)(new_data_vars)
77667766

77677767
other_coords: Coordinates | None = getattr(other, "coords", None)
7768-
ds = self.coords.merge(other_coords)
7768+
ds = self.coords.merge(other_coords, compat=OPTIONS["arithmetic_compat"])
77697769

77707770
if isinstance(other, Dataset):
77717771
new_vars = apply_over_both(

xarray/core/options.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
import warnings
44
from collections.abc import Sequence
5-
from typing import TYPE_CHECKING, Any, Literal, TypedDict
5+
from typing import TYPE_CHECKING, Any, Literal, TypedDict, get_args
66

77
from xarray.core.utils import FrozenDict
8+
from xarray.core.types import CompatOptions
9+
from xarray.util.deprecation_helpers import CombineKwargDefault, _ARITHMETIC_COMPAT_DEFAULT
810

911
if TYPE_CHECKING:
1012
from matplotlib.colors import Colormap
1113

1214
Options = Literal[
15+
"arithmetic_compat",
1316
"arithmetic_join",
1417
"chunk_manager",
1518
"cmap_divergent",
@@ -40,6 +43,7 @@
4043

4144
class T_Options(TypedDict):
4245
arithmetic_broadcast: bool
46+
arithmetic_compat: CompatOptions | CombineKwargDefault
4347
arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
4448
chunk_manager: str
4549
cmap_divergent: str | Colormap
@@ -70,6 +74,7 @@ class T_Options(TypedDict):
7074

7175
OPTIONS: T_Options = {
7276
"arithmetic_broadcast": True,
77+
"arithmetic_compat": _ARITHMETIC_COMPAT_DEFAULT,
7378
"arithmetic_join": "inner",
7479
"chunk_manager": "dask",
7580
"cmap_divergent": "RdBu_r",
@@ -109,6 +114,7 @@ def _positive_integer(value: Any) -> bool:
109114

110115
_VALIDATORS = {
111116
"arithmetic_broadcast": lambda value: isinstance(value, bool),
117+
"arithmetic_compat": get_args(CompatOptions).__contains__,
112118
"arithmetic_join": _JOIN_OPTIONS.__contains__,
113119
"display_max_children": _positive_integer,
114120
"display_max_rows": _positive_integer,
@@ -178,18 +184,35 @@ class set_options:
178184
179185
Parameters
180186
----------
187+
arithmetic_broadcast : bool, default: True
188+
Whether to perform automatic broadcasting in binary operations.
189+
arithmetic_compat : {"identical", "equals", "broadcast_equals", "no_conflicts",
190+
"override", "minimal"}, default: "minimal"
191+
How to compare non-index coordinates of the same name for potential
192+
conflicts when performing binary operations. (For the alignment of index
193+
coordinates in binary operations, see `arithmetic_join`.)
194+
195+
- "identical": all values, dimensions and attributes of the coordinates
196+
must be the same.
197+
- "equals": all values and dimensions of the coordinates must be the
198+
same.
199+
- "broadcast_equals": all values of the coordinates must be equal after
200+
broadcasting to ensure common dimensions.
201+
- "no_conflicts": only values which are not null in both coordinates
202+
must be equal. The returned coordinate then contains the combination
203+
of all non-null values.
204+
- "override": skip comparing and take the coordinates from the first
205+
operand.
206+
- "minimal": drop conflicting coordinates.
181207
arithmetic_join : {"inner", "outer", "left", "right", "exact"}, default: "inner"
182-
DataArray/Dataset alignment in binary operations:
208+
DataArray/Dataset index alignment in binary operations:
183209
184210
- "outer": use the union of object indexes
185211
- "inner": use the intersection of object indexes
186212
- "left": use indexes from the first object with each dimension
187213
- "right": use indexes from the last object with each dimension
188214
- "exact": instead of aligning, raise `ValueError` when indexes to be
189215
aligned are not equal
190-
- "override": if indexes are of same size, rewrite indexes to be
191-
those of the first object with that dimension. Indexes for the same
192-
dimension must have the same size in all objects.
193216
chunk_manager : str, default: "dask"
194217
Chunk manager to use for chunked array computations when multiple
195218
options are installed.

xarray/structure/merge.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,18 @@ def merge_collected(
332332
FutureWarning,
333333
)
334334
except MergeError:
335+
if isinstance(compat, CombineKwargDefault) and compat == "minimal":
336+
emit_user_level_warning(
337+
compat.warning_message(
338+
f"Here we have dropped the variable {name!r} due to either a "
339+
"failed equality test between variables of this name, or an "
340+
"inability to perform such an equality test. The default in "
341+
"future will be to retain the first instance of variable "
342+
f"{name!r} without attempting an equality check."
343+
),
344+
FutureWarning,
345+
)
346+
335347
if compat != "minimal":
336348
# we need more than "minimal" compatibility (for which
337349
# we drop conflicting coordinates)
@@ -433,6 +445,7 @@ def merge_coordinates_without_align(
433445
prioritized: Mapping[Any, MergeElement] | None = None,
434446
exclude_dims: AbstractSet = frozenset(),
435447
combine_attrs: CombineAttrsOptions = "override",
448+
compat: CompatOptions | CombineKwargDefault = "minimal",
436449
) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
437450
"""Merge variables/indexes from coordinates without automatic alignments.
438451
@@ -457,7 +470,7 @@ def merge_coordinates_without_align(
457470
# TODO: indexes should probably be filtered in collected elements
458471
# before merging them
459472
merged_coords, merged_indexes = merge_collected(
460-
filtered, prioritized, combine_attrs=combine_attrs
473+
filtered, prioritized, compat=compat, combine_attrs=combine_attrs
461474
)
462475
merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords))
463476

xarray/tests/test_dataarray.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DataArray,
2626
Dataset,
2727
IndexVariable,
28+
MergeError,
2829
Variable,
2930
align,
3031
broadcast,
@@ -2477,27 +2478,30 @@ def test_math_with_coords(self) -> None:
24772478
actual = 1 + orig
24782479
assert_identical(expected, actual)
24792480

2480-
actual = orig + orig[0, 0]
2481-
exp_coords = {k: v for k, v in coords.items() if k != "lat"}
2482-
expected = DataArray(
2483-
orig.values + orig.values[0, 0], exp_coords, dims=["x", "y"]
2484-
)
2485-
assert_identical(expected, actual)
2481+
with xr.set_options(arithmetic_compat='minimal'):
2482+
actual = orig + orig[0, 0]
2483+
exp_coords = {k: v for k, v in coords.items() if k != "lat"}
2484+
expected = DataArray(
2485+
orig.values + orig.values[0, 0], exp_coords, dims=["x", "y"]
2486+
)
2487+
assert_identical(expected, actual)
24862488

2487-
actual = orig[0, 0] + orig
2488-
assert_identical(expected, actual)
2489+
actual = orig[0, 0] + orig
2490+
assert_identical(expected, actual)
24892491

2490-
actual = orig[0, 0] + orig[-1, -1]
2491-
expected = DataArray(orig.values[0, 0] + orig.values[-1, -1], {"c": -999})
2492-
assert_identical(expected, actual)
2492+
actual = orig[0, 0] + orig[-1, -1]
2493+
expected = DataArray(
2494+
orig.values[0, 0] + orig.values[-1, -1],
2495+
{"c": -999})
2496+
assert_identical(expected, actual)
24932497

2494-
actual = orig[:, 0] + orig[0, :]
2495-
exp_values = orig[:, 0].values[:, None] + orig[0, :].values[None, :]
2496-
expected = DataArray(exp_values, exp_coords, dims=["x", "y"])
2497-
assert_identical(expected, actual)
2498+
actual = orig[:, 0] + orig[0, :]
2499+
exp_values = orig[:, 0].values[:, None] + orig[0, :].values[None, :]
2500+
expected = DataArray(exp_values, exp_coords, dims=["x", "y"])
2501+
assert_identical(expected, actual)
24982502

2499-
actual = orig[0, :] + orig[:, 0]
2500-
assert_identical(expected.transpose(transpose_coords=True), actual)
2503+
actual = orig[0, :] + orig[:, 0]
2504+
assert_identical(expected.transpose(transpose_coords=True), actual)
25012505

25022506
actual = orig - orig.transpose(transpose_coords=True)
25032507
expected = DataArray(np.zeros((2, 3)), orig.coords)
@@ -2507,14 +2511,53 @@ def test_math_with_coords(self) -> None:
25072511
assert_identical(expected.transpose(transpose_coords=True), actual)
25082512

25092513
alt = DataArray([1, 1], {"x": [-1, -2], "c": "foo", "d": 555}, "x")
2510-
actual = orig + alt
2511-
expected = orig + 1
2512-
expected.coords["d"] = 555
2513-
del expected.coords["c"]
2514-
assert_identical(expected, actual)
25152514

2516-
actual = alt + orig
2517-
assert_identical(expected, actual)
2515+
with xr.set_options(arithmetic_compat='minimal'):
2516+
actual = orig + alt
2517+
expected = orig + 1
2518+
expected.coords["d"] = 555
2519+
del expected.coords["c"]
2520+
assert_identical(expected, actual)
2521+
2522+
actual = alt + orig
2523+
assert_identical(expected, actual)
2524+
2525+
def test_math_with_arithmetic_compat_options(self) -> None:
2526+
# Setting up a clash of non-index coordinate 'foo':
2527+
a = xr.DataArray(
2528+
data=[0, 0, 0],
2529+
dims=["x"],
2530+
coords={
2531+
"x": [1, 2, 3],
2532+
"foo": (["x"], [1.0, 2.0, np.nan]),
2533+
}
2534+
)
2535+
b = xr.DataArray(
2536+
data=[0, 0, 0],
2537+
dims=["x"],
2538+
coords={
2539+
"x": [1, 2, 3],
2540+
"foo": (["x"], [np.nan, 2.0, 3.0]),
2541+
}
2542+
)
2543+
2544+
with xr.set_options(arithmetic_compat="minimal"):
2545+
assert_equal(a + b, a.drop_vars("foo"))
2546+
2547+
with xr.set_options(arithmetic_compat="override"):
2548+
assert_equal(a + b, a)
2549+
assert_equal(b + a, b)
2550+
2551+
with xr.set_options(arithmetic_compat="no_conflicts"):
2552+
expected = a.assign_coords(foo=(["x"], [1.0, 2.0, 3.0]))
2553+
assert_equal(a + b, expected)
2554+
assert_equal(b + a, expected)
2555+
2556+
with xr.set_options(arithmetic_compat="equals"):
2557+
with pytest.raises(MergeError):
2558+
a + b
2559+
with pytest.raises(MergeError):
2560+
b + a
25182561

25192562
def test_index_math(self) -> None:
25202563
orig = DataArray(range(3), dims="x", name="x")

0 commit comments

Comments
 (0)