Skip to content

Commit 79de128

Browse files
authored
fix: support setting with histograms (#1036)
Signed-off-by: Henry Schreiner <[email protected]>
1 parent 622e0fe commit 79de128

File tree

5 files changed

+137
-40
lines changed

5 files changed

+137
-40
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ minversion = "6.0"
145145
junit_family = "xunit2"
146146
addopts = [
147147
"-ra",
148-
"--showlocals",
149148
"--strict-markers",
150149
"--strict-config",
151150
"--import-mode=importlib",

src/boost_histogram/axis/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,16 @@ def __iter__(
189189
) -> Iterator[float] | Iterator[str] | Iterator[tuple[float, float]]:
190190
return self._ax.__iter__() # type: ignore[no-any-return]
191191

192+
def _process_callable(self, value: AxCallOrInt | None, *, default: int) -> int:
193+
"""
194+
This processes a callable in start or stop. None gets replaced by default.
195+
"""
196+
if value is None:
197+
return default
198+
if callable(value):
199+
return value(self)
200+
return value
201+
192202
def _process_loc(
193203
self, start: AxCallOrInt | None, stop: AxCallOrInt | None
194204
) -> tuple[int, int]:
@@ -201,18 +211,15 @@ def _process_loc(
201211
is turned off if underflow is not None.
202212
"""
203213

204-
def _process_internal(item: AxCallOrInt | None, default: int) -> int:
205-
return default if item is None else item(self) if callable(item) else item
206-
207214
underflow = -1 if self._ax.traits_underflow else 0
208215
overflow = 1 if self._ax.traits_overflow else 0
209216

210217
# Non-ordered axes only use flow if integrating from None to None
211218
if not self._ax.traits_ordered and not (start is None and stop is None):
212219
overflow = 0
213220

214-
begin = _process_internal(start, underflow)
215-
end = _process_internal(stop, len(self) + overflow)
221+
begin = self._process_callable(start, default=underflow)
222+
end = self._process_callable(stop, default=len(self) + overflow)
216223

217224
return begin, end
218225

src/boost_histogram/histogram.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,37 +1128,33 @@ def __setitem__(self, index: IndexingExpr, value: ArrayLike | Accumulator) -> No
11281128
If an array is given that does not match, if it does match the
11291129
with-overflow size, it fills that.
11301130
1131-
PLANNED (not yet supported):
1132-
11331131
h[a:] = h2
11341132
11351133
If another histogram is given, that must either match with or without
11361134
overflow, where the overflow bins must be overflow bins (that is,
11371135
you cannot set a histogram's flow bins from another histogram that
1138-
is 2 larger). Bin edges must be a close match, as well. If you don't
1139-
want this level of type safety, just use ``h[...] = h2.view()``.
1136+
is 2 larger). If you don't want this level of type safety, just use
1137+
``h[...] = h2.view()``.
11401138
"""
11411139
indexes = self._compute_commonindex(index)
11421140

1143-
if isinstance(value, Histogram):
1144-
raise TypeError("Not supported yet")
1145-
1146-
value = np.asarray(value)
1141+
in_array = np.asarray(value)
11471142
view: Any = self.view(flow=True)
11481143

11491144
value_shape: tuple[int, ...]
1145+
11501146
# Support raw arrays for accumulators, the final dimension is the constructor values
11511147
if (
1152-
value.ndim > 0
1148+
in_array.ndim > 0
11531149
and len(view.dtype) > 0
1154-
and len(value.dtype) == 0
1155-
and len(view.dtype) == value.shape[-1]
1150+
and len(in_array.dtype) == 0
1151+
and len(view.dtype) == in_array.shape[-1]
11561152
):
1157-
value_shape = value.shape[:-1]
1158-
value_ndim = value.ndim - 1
1153+
value_shape = in_array.shape[:-1]
1154+
value_ndim = in_array.ndim - 1
11591155
else:
1160-
value_shape = value.shape
1161-
value_ndim = value.ndim
1156+
value_shape = in_array.shape
1157+
value_ndim = in_array.ndim
11621158

11631159
# NumPy does not broadcast partial slices, but we would need
11641160
# to allow it (because we do allow broadcasting up dimensions)
@@ -1174,43 +1170,67 @@ def __setitem__(self, index: IndexingExpr, value: ArrayLike | Accumulator) -> No
11741170
has_overflow = self.axes[n].traits.overflow
11751171

11761172
if isinstance(request, slice):
1173+
# This ensures that callable start/stop are handled
1174+
start, stop = self.axes[n]._process_loc(request.start, request.stop)
1175+
11771176
# Only consider underflow/overflow if the endpoints are not given
1178-
use_underflow = has_underflow and request.start is None
1179-
use_overflow = has_overflow and request.stop is None
1177+
use_underflow = has_underflow and start < 0
1178+
use_overflow = has_overflow and stop > len(self.axes[n])
1179+
1180+
# If the input is a histogram, we need to exactly match underflow/overflow
1181+
if isinstance(value, Histogram):
1182+
in_underflow = value.axes[n].traits.underflow
1183+
in_overflow = value.axes[n].traits.overflow
1184+
1185+
if use_underflow != in_underflow or use_overflow != in_overflow:
1186+
msg = (
1187+
f"Cannot set histogram with underflow={in_underflow} and overflow={in_overflow} "
1188+
f"to a histogram slice with underflow={use_underflow} and overflow={use_overflow}"
1189+
)
1190+
raise ValueError(msg)
1191+
1192+
# Convert to non-flow coordinates
1193+
start_real = start + 1 if has_underflow else start
1194+
stop_real = stop + 1 if has_underflow else stop
11801195

1181-
# Make the limits explicit since we may need to shift them
1182-
start = 0 if request.start is None else request.start
1183-
stop = len(self.axes[n]) if request.stop is None else request.stop
1184-
request_len = stop - start
1196+
# This is the total requested length without flow bins
1197+
request_len = min(stop, len(self.axes[n])) - max(start, 0)
11851198

11861199
# If set to a scalar, then treat it like broadcasting without flow bins
11871200
# Normal requests here too
1188-
if value_ndim == 0 or request_len == value_shape[value_n]:
1189-
start += has_underflow
1190-
stop += has_underflow
1201+
# Also single element broadcasting
1202+
if (
1203+
value_ndim == 0
1204+
or request_len == value_shape[value_n]
1205+
or value_shape[value_n] == 1
1206+
):
1207+
start_real += 1 if start < 0 else 0
1208+
stop_real -= 1 if stop > len(self.axes[n]) else 0
11911209

11921210
# Expanded setting
11931211
elif request_len + use_underflow + use_overflow == value_shape[value_n]:
1194-
start += has_underflow and not use_underflow
1195-
stop += has_underflow + (has_overflow and use_overflow)
1196-
1197-
# Single element broadcasting
1198-
elif value_shape[value_n] == 1:
1199-
start += has_underflow
1200-
stop += has_underflow
1212+
pass
12011213

12021214
else:
12031215
msg = f"Mismatched shapes in dimension {n}"
12041216
msg += f", {value_shape[n]} != {request_len}"
12051217
if use_underflow or use_overflow:
12061218
msg += f" or {request_len + use_underflow + use_overflow}"
12071219
raise ValueError(msg)
1208-
indexes[n] = slice(start, stop, request.step)
1220+
logger.debug(
1221+
"__setitem__: axis %i, start: %i (actual %i), stop: %i (actual %i)",
1222+
n,
1223+
start,
1224+
start_real,
1225+
stop,
1226+
stop_real,
1227+
)
1228+
indexes[n] = slice(start_real, stop_real, request.step)
12091229
value_n += 1
12101230
else:
12111231
indexes[n] = request + has_underflow
12121232

1213-
view[tuple(indexes)] = value
1233+
view[tuple(indexes)] = in_array
12141234

12151235
def project(self, *args: int) -> Self | float | Accumulator:
12161236
"""

tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ foreach(TEST_FILE IN LISTS BOOST_HIST_PY_TESTS)
2121
get_filename_component(TEST_NAME "${TEST_FILE}" NAME_WE)
2222
add_test(
2323
NAME ${TEST_NAME}
24-
COMMAND ${Python_EXECUTABLE} -m pytest "${TEST_FILE}" --rootdir=.
24+
COMMAND ${Python_EXECUTABLE} -m pytest "${TEST_FILE}" --rootdir=. --showlocals
2525
WORKING_DIRECTORY "${PROJECT_BINARY_DIR}")
2626
set_tests_properties(${TEST_NAME} PROPERTIES SKIP_RETURN_CODE 5)
2727
endforeach()

tests/test_histogram_indexing.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,18 @@ def test_set_range_with_scalar():
295295
assert h[5] == 0
296296

297297

298+
def test_set_range_with_scalar_callable():
299+
h = bh.Histogram(bh.axis.Integer(0, 10))
300+
h[2:len] = 42
301+
302+
assert h[1] == 0
303+
assert h[2] == 42
304+
assert h[3] == 42
305+
assert h[4] == 42
306+
assert h[5] == 42
307+
assert h[bh.overflow] == 0
308+
309+
298310
def test_set_all_with_scalar():
299311
h = bh.Histogram(bh.axis.Integer(0, 10))
300312
h[:] = 42
@@ -483,3 +495,62 @@ def test_large_index():
483495
)
484496
assert h.axes[0].value(6) == 99_999_001
485497
assert h.axes[0].index(99_999_001) == 6
498+
499+
500+
def test_scaling_slice():
501+
h = bh.Histogram(bh.axis.Regular(3, 0, 3), bh.axis.StrCategory(["a", "b"]))
502+
h.fill([1, 1, 2], "a")
503+
h.fill([0], "b")
504+
505+
h[:, bh.loc("a")] *= 2
506+
507+
assert h[1, 0] == approx(4)
508+
assert h[2, 0] == approx(2)
509+
assert h[0, 1] == approx(1)
510+
511+
512+
def test_scaling_slice_weight():
513+
h = bh.Histogram(
514+
bh.axis.Regular(3, 0, 3),
515+
bh.axis.StrCategory(["a", "b"]),
516+
storage=bh.storage.Weight(),
517+
)
518+
h.fill([1, 1, 2], "a")
519+
h.fill([0], "b")
520+
521+
h[:, bh.loc("a")] *= 2
522+
523+
assert h[1, 0].value == approx(4)
524+
assert h[2, 0].value == approx(2)
525+
assert h[0, 1].value == approx(1)
526+
527+
528+
def test_setting_histogram_mismatch():
529+
h = bh.Histogram(bh.axis.Regular(10, 0, 10), bh.axis.Integer(0, 20))
530+
531+
h[:, 0] = bh.Histogram(bh.axis.Regular(10, 0, 10))
532+
h[0:, 0] = bh.Histogram(bh.axis.Regular(10, 0, 10, underflow=False))
533+
h[:len, 0] = bh.Histogram(bh.axis.Regular(10, 0, 10, overflow=False))
534+
h[0:len, 0] = bh.Histogram(
535+
bh.axis.Regular(10, 0, 10, underflow=False, overflow=False)
536+
)
537+
with pytest.raises(ValueError, match="Cannot set histogram with underflow"):
538+
h[0:, 0] = bh.Histogram(bh.axis.Regular(10, 0, 10))
539+
with pytest.raises(ValueError, match="Cannot set histogram with underflow"):
540+
h[:len, 0] = bh.Histogram(bh.axis.Regular(10, 0, 10))
541+
with pytest.raises(ValueError, match="Cannot set histogram with underflow"):
542+
h[:, 0] = bh.Histogram(bh.axis.Regular(10, 0, 10, underflow=False))
543+
with pytest.raises(ValueError, match="Cannot set histogram with underflow"):
544+
h[:, 0] = bh.Histogram(bh.axis.Regular(10, 0, 10, overflow=False))
545+
with pytest.raises(ValueError, match="Cannot set histogram with underflow"):
546+
h[:, 0] = bh.Histogram(
547+
bh.axis.Regular(10, 0, 10, underflow=False, overflow=False)
548+
)
549+
with pytest.raises(ValueError, match="Cannot set histogram with underflow"):
550+
h[0:, 0] = bh.Histogram(
551+
bh.axis.Regular(10, 0, 10, underflow=False, overflow=False)
552+
)
553+
with pytest.raises(ValueError, match="Cannot set histogram with underflow"):
554+
h[:len, 0] = bh.Histogram(
555+
bh.axis.Regular(10, 0, 10, underflow=False, overflow=False)
556+
)

0 commit comments

Comments
 (0)