Skip to content

Commit 69af3cf

Browse files
authored
fix: serialization was broken for N>1D complex storage histograms (#1043)
Signed-off-by: Henry Schreiner <[email protected]>
1 parent d83aea0 commit 69af3cf

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

src/boost_histogram/histogram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1250,7 +1250,7 @@ def __setitem__(self, index: IndexingExpr, value: ArrayLike | Accumulator) -> No
12501250
pass
12511251

12521252
else:
1253-
msg = f"Mismatched shapes in dimension {n}"
1253+
msg = f"Mismatched shapes {value_shape} in dimension {n}"
12541254
msg += f", {value_shape[n]} != {request_len}"
12551255
if use_underflow or use_overflow:
12561256
msg += f" or {request_len + use_underflow + use_overflow}"

src/boost_histogram/serialization/_storage.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,9 @@ def _data_from_dict(data: dict[str, Any], /) -> np.typing.NDArray[Any]:
108108
if storage_type in {"int", "double"}:
109109
return data["values"]
110110
if storage_type == "weighted":
111-
return np.stack([data["values"], data["variances"]]).T
111+
return np.stack([data["values"], data["variances"]], axis=-1)
112112
if storage_type == "mean":
113-
return np.stack(
114-
[data["counts"], data["values"], data["variances"]],
115-
).T
113+
return np.stack([data["counts"], data["values"], data["variances"]], axis=-1)
116114
if storage_type == "weighted_mean":
117115
return np.stack(
118116
[
@@ -121,6 +119,7 @@ def _data_from_dict(data: dict[str, Any], /) -> np.typing.NDArray[Any]:
121119
data["values"],
122120
data["variances"],
123121
],
124-
).T
122+
axis=-1,
123+
)
125124

126125
raise TypeError(f"Unsupported storage type: {storage_type}")

tests/test_serialization_uhi.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,6 @@ def test_round_trip_weighted() -> None:
151151
data = to_uhi(h)
152152
h2 = from_uhi(data)
153153

154-
print(h.view())
155-
print(h2.view())
156-
157154
assert pytest.approx(np.array(h.axes[0])) == np.array(h2.axes[0])
158155
assert np.asarray(h) == pytest.approx(h2)
159156

@@ -285,3 +282,29 @@ def test_remove_writer_info() -> None:
285282
"writer_info": {"boost-histogram": {"foo": "bar"}},
286283
}
287284
assert remove_writer_info(d, library="c") == d
285+
286+
287+
def test_convert_weight() -> None:
288+
h = bh.Histogram(
289+
bh.axis.Regular(3, 13, 10, __dict__={"name": "x"}),
290+
bh.axis.StrCategory(["one", "two"]),
291+
storage=bh.storage.Weight(),
292+
)
293+
data = h._to_uhi_()
294+
h2 = bh.Histogram(data)
295+
296+
assert h == h2
297+
298+
299+
def test_convert_weightmean() -> None:
300+
h = bh.Histogram(
301+
bh.axis.Regular(12, 0, 1),
302+
bh.axis.StrCategory(["a", "b", "c", "d", "e", "f", "g"]),
303+
bh.axis.Boolean(),
304+
bh.axis.Integer(1, 18),
305+
storage=bh.storage.WeightedMean(),
306+
)
307+
data = h._to_uhi_()
308+
h2 = bh.Histogram(data)
309+
310+
assert h.axes == h2.axes

0 commit comments

Comments
 (0)