Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ API changes

- ``astropy.modeling``

- Note: Comparisons of model parameters with array-like values now
yields a Numpy boolean array as one would get with normal Numpy
array comparison. Previously this returned a scalar True or False,
with True only if the comparison was true for all elements compared,
which could lead to confusing circumstances. [#3912]

- Renamed the parameters of ``RotateNative2Celestial`` and
``RotateCelestial2Native`` from ``phi``, ``theta``, ``psi`` to
``lon``, ``lat`` and ``lon_pole``. [#3578]
Expand Down
20 changes: 10 additions & 10 deletions astropy/modeling/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,7 @@ def _create_value_wrapper(wrapper, model):

def __array__(self, dtype=None):
# Make np.asarray(self) work a little more straightforwardly
if self._model is None:
return np.array([], dtype=np.float)
else:
return np.asarray(self.value, dtype=dtype)
return np.asarray(self.value, dtype=dtype)

def __nonzero__(self):
if self._model is None:
Expand Down Expand Up @@ -728,22 +725,25 @@ def __rtruediv__(self, val):
return val / self.value

def __eq__(self, val):
return (np.asarray(self) == np.asarray(val)).all()
if self._model is None:
return super(Parameter, self).__eq__(val)

return self.__array__() == val

def __ne__(self, val):
return not (np.asarray(self) == np.asarray(val)).all()
return self.__array__() != val

def __lt__(self, val):
return (np.asarray(self) < np.asarray(val)).all()
return self.__array__() < val

def __gt__(self, val):
return (np.asarray(self) > np.asarray(val)).all()
return self.__array__() > val

def __le__(self, val):
return (np.asarray(self) <= np.asarray(val)).all()
return self.__array__() <= val

def __ge__(self, val):
return (np.asarray(self) >= np.asarray(val)).all()
return self.__array__() >= val

def __neg__(self):
return -self.value
Expand Down
4 changes: 2 additions & 2 deletions astropy/modeling/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_poly1d_multiple_sets(self):
p1 = models.Polynomial1D(3, n_models=3)
utils.assert_equal(p1.parameters, [0.0, 0.0, 0.0, 0, 0, 0,
0, 0, 0, 0, 0, 0])
utils.assert_equal(p1.c0, [0, 0, 0])
utils.assert_array_equal(p1.c0, [0, 0, 0])
p1.c0 = [10, 10, 10]
utils.assert_equal(p1.parameters, [10.0, 10.0, 10.0, 0, 0,
0, 0, 0, 0, 0, 0, 0])
Expand Down Expand Up @@ -289,7 +289,7 @@ def test_shift_model_parameters1d(self):
def test_scale_model_parametersnd(self):
sc1 = models.Scale([2, 2])
sc1.factor = [3, 3]
assert sc1.factor == [3, 3]
assert np.all(sc1.factor == [3, 3])
utils.assert_array_equal(sc1.factor.value, [3, 3])

def test_parameters_wrong_shape(self):
Expand Down