Skip to content

Commit bfb7fef

Browse files
fix numpy.dot error if the first parameter is a pint quantity and the second parameter isn't a pint quantity (#2214)
* fix dot nonquantity * fix test * _dimensionless_if_needed * lint --------- Co-authored-by: unknown <[email protected]> Co-authored-by: oxygen-dioxide <[email protected]>
1 parent 6e0d038 commit bfb7fef

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

pint/facets/numpy/numpy_func.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,25 @@ def _correlate(a, v, mode="valid", **kwargs):
788788
return a.units._REGISTRY.Quantity(ret, units)
789789

790790

791+
def _dimensionless_if_needed(*args):
792+
registry = None
793+
for arg in args:
794+
if _is_quantity(arg):
795+
registry = arg.units._REGISTRY
796+
break
797+
if registry is None:
798+
raise ValueError(
799+
"At least one argument must be a Quantity to determine the registry."
800+
)
801+
new_args = []
802+
for arg in args:
803+
if _is_quantity(arg):
804+
new_args.append(arg)
805+
else:
806+
new_args.append(registry.Quantity(arg, "dimensionless"))
807+
return new_args
808+
809+
791810
def implement_mul_func(func):
792811
# If NumPy is not available, do not attempt implement that which does not exist
793812
if np is None:
@@ -797,15 +816,12 @@ def implement_mul_func(func):
797816

798817
@implements(func_str, "function")
799818
def implementation(a, b, **kwargs):
819+
a, b = _dimensionless_if_needed(a, b)
800820
a = _base_unit_if_needed(a)
801-
units = a.units
802-
if hasattr(b, "units"):
803-
b = _base_unit_if_needed(b)
804-
units *= b.units
805-
b = b._magnitude
806-
807-
mag = func(a._magnitude, b, **kwargs)
808-
return a.units._REGISTRY.Quantity(mag, units)
821+
b = _base_unit_if_needed(b)
822+
units = a.units * b.units
823+
mag = func(a._magnitude, b._magnitude, **kwargs)
824+
return mag * units
809825

810826

811827
for func_str in ("cross", "dot"):

pint/testsuite/test_numpy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,16 @@ def test_dot_numpy_func(self):
470470
3 * self.ureg.m,
471471
)
472472

473+
@helpers.requires_array_function_protocol()
474+
def test_dot_nonquantity(self):
475+
a = np.array([0, 0, 1, 0])
476+
b = self.Q_(self.q.ravel(), "m")
477+
expected = 3 * self.ureg.m
478+
result1 = np.dot(a, b)
479+
helpers.assert_quantity_equal(result1, expected)
480+
result2 = np.dot(b, a)
481+
helpers.assert_quantity_equal(result2, expected)
482+
473483
@helpers.requires_array_function_protocol()
474484
def test_einsum(self):
475485
a = np.arange(25).reshape(5, 5) * self.ureg.m

0 commit comments

Comments
 (0)