Skip to content

Commit c6ea042

Browse files
authored
Merge pull request atmtools#143 from gerritholl/fix-uada-ufunc
Make UADA work with __array_ufunc__
2 parents abd2367 + 11d65f9 commit c6ea042

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

typhon/physics/units/tools.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ class UnitsAwareDataArray(xarray.DataArray):
1818
"""Like xarray.DataArray, but transfers units
1919
"""
2020

21+
# need to keep both __array_wrap__ and __array_ufunc__. Although the
22+
# former supersedes the latter, xarrays methods explicitly call the
23+
# former sometimes.
2124
def __array_wrap__(self, obj, context=None):
2225
new_var = super().__array_wrap__(obj, context)
2326
if self.attrs.get("units"):
@@ -56,6 +59,33 @@ def _apply_rbinary_op_to_units(self, func, other, x):
5659
ureg.Quantity(1, self.attrs["units"]),).u)
5760
return x
5861

62+
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
63+
new_var = super().__array_ufunc__(ufunc, method, *args, **kwargs)
64+
# make sure we're still UADA
65+
new_var = self.__class__(new_var)
66+
if self.attrs.get("units"):
67+
if method == "__call__":
68+
q = ufunc(ureg.Quantity(1, self.attrs.get("units")))
69+
try:
70+
u = q.u
71+
except AttributeError:
72+
if (ureg(self.attrs["units"]).dimensionless or
73+
new_var.dtype.kind == "b"):
74+
# expected, see https://github.com/hgrecco/pint/issues/482
75+
u = ureg.dimensionless
76+
else:
77+
raise
78+
# for exp and log, values are not set correctly. I'm
79+
# not sure why. Perhaps related to
80+
# https://github.com/hgrecco/pint/issues/493
81+
new_var.values = ufunc(ureg.Quantity(self.values, self.units))
82+
new_var.attrs["units"] = str(u)
83+
else: # unary operators? always retain units?
84+
raise NotImplementedError("Not implented")
85+
new_var.attrs["units"] = str(self.attrs.get("units"))
86+
87+
return new_var
88+
5989
# pow is different because resulting unit depends on argument, not on
6090
# unit of argument
6191
def __pow__(self, other):

0 commit comments

Comments
 (0)