Skip to content

Commit 9b94e78

Browse files
authored
raise error when select on bad patch coord (#556)
* raise error when select on bad patch coord * fix select typo * added coord manager select tests * select on non-dim coords * tests for selecting non-dimensional coord * review
1 parent bfc966f commit 9b94e78

File tree

5 files changed

+372
-20
lines changed

5 files changed

+372
-20
lines changed

dascore/core/coordmanager.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,31 @@ def _get_indexers_and_new_coords_dict(
115115
):
116116
"""Get reductions for each dimension."""
117117
dim_reductions = {x: slice(None, None) for x in cm.dims}
118-
dimap = cm.dim_map
119118
new_coords = dict(cm._get_dim_array_dict(keep_coord=True))
120119
for coord_name, vals in kwargs.items():
121-
# this is not a selectable coord, just skip.
122-
if coord_name not in cm.coord_map or not len(cm.dim_map[coord_name]):
123-
continue
120+
# All coordinates should exist in coord_map (filtered by
121+
# _get_single_dim_kwarg_list)
122+
assert coord_name in cm.coord_map
124123
coord = cm.coord_map[coord_name]
124+
coord_dims = cm.dim_map[coord_name]
125125
_ensure_1d_coord(coord, coord_name)
126-
dim_name = dimap[coord_name][0]
126+
# Handle non-dimensional coordinates (not tied to any dimension)
127+
if not len(coord_dims):
128+
# Apply operation directly to the non-dimensional coordinate
129+
method = getattr(coord, operation)
130+
new_coord, _ = method(vals, relative=relative, samples=samples)
131+
# Update only this coordinate in new_coords, don't affect array indexing
132+
new_coords[coord_name] = (coord_dims, new_coord)
133+
continue
134+
# Handle dimensional coordinates (tied to exactly one dimension)
135+
dim_name = coord_dims[0]
127136
# different logic if we are using indices or values
128137
method = getattr(coord, operation)
129138
new_coord, reductions = method(vals, relative=relative, samples=samples)
130139
# this handles the case of out-of-bound selections.
131140
# These should be converted to degenerate coords.
132141
dim_reductions[dim_name] = reductions
133-
new_coords[coord_name] = (dimap[coord_name], new_coord)
142+
new_coords[coord_name] = (coord_dims, new_coord)
134143
# update other coords affected by change.
135144
_indirect_coord_updates(cm, dim_name, coord_name, reductions, new_coords)
136145
indexers = tuple(dim_reductions[x] for x in cm.dims)

dascore/proc/coords.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010

1111
from dascore.constants import PatchType, select_values_description
1212
from dascore.core.coords import BaseCoord
13-
from dascore.exceptions import CoordError, ParameterError, PatchError
13+
from dascore.exceptions import (
14+
CoordError,
15+
ParameterError,
16+
PatchCoordinateError,
17+
PatchError,
18+
)
1419
from dascore.utils.docs import compose_docstring
1520
from dascore.utils.misc import get_parent_code_name, iterate
1621
from dascore.utils.patch import patch_function
@@ -487,7 +492,7 @@ def select(
487492
>>>
488493
>>> # Select only specific values along a dimension
489494
>>> distance = patch.get_array("distance")
490-
>>> new_distance_3 = patch.select(distace=distance[1::2])
495+
>>> new_distance_3 = patch.select(distance=distance[1::2])
491496
492497
Notes
493498
-----
@@ -497,14 +502,23 @@ def select(
497502
See [`Patch.order`](`dascore.Patch.order`).
498503
499504
"""
505+
# Check for and raise on invalid kwargs.
506+
if invalid_coords := set(kwargs) - set(patch.coords.coord_map):
507+
invalid_list = sorted(invalid_coords)
508+
valid_list = sorted(patch.coords.coord_map)
509+
msg = (
510+
f"Coordinate(s) {invalid_list} not found in patch coordinates: {valid_list}"
511+
)
512+
raise PatchCoordinateError(msg)
513+
500514
new_coords, data = patch.coords.select(
501515
**kwargs,
502516
array=patch.data,
503517
relative=relative,
504518
samples=samples,
505519
)
506-
# no slicing was performed, just return original.
507-
if data.shape == patch.data.shape:
520+
# no slicing was performed, just return original if coordinates also unchanged.
521+
if data.shape == patch.data.shape and new_coords == patch.coords:
508522
return patch
509523
if copy:
510524
data = data.copy()

dascore/proc/hampel.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,12 @@ def hampel_filter(
9696
>>> data[10, 5] = 10 # Add a large spike
9797
>>> patch = patch.update(data=data)
9898
>>>
99-
>>> # Apply hampel filter along time dimension with 1.0 unit window
100-
>>> filtered = patch.hampel_filter(time=1.0, threshold=3.5)
99+
>>> # Apply hampel filter along time dimension with 0.2 unit window
100+
>>> filtered = patch.hampel_filter(time=0.2, threshold=3.5)
101101
>>> assert filtered.data.shape == patch.data.shape
102102
>>> # The spike should be reduced
103103
>>> assert abs(filtered.data[10, 5]) < abs(patch.data[10, 5])
104104
>>>
105-
>>> # Apply filter with a lower threshold for more aggressive filtering
106-
>>> filtered_aggressive = patch.hampel_filter(time=1.0, threshold=2.0)
107-
>>> assert isinstance(filtered_aggressive, dc.Patch)
108-
>>>
109105
>>> # Apply filter along multiple dimensions:
110106
>>> filtered_2d = patch.hampel_filter(time=1.0, distance=5.0, threshold=3.5)
111107
>>> assert filtered_2d.data.shape == patch.data.shape
@@ -115,10 +111,6 @@ def hampel_filter(
115111
... time=3, distance=3, samples=True, threshold=3.5
116112
... )
117113
>>>
118-
>>> # Use separable filtering for faster processing (approximation)
119-
>>> filtered_fast = patch.hampel_filter(
120-
... time=1.0, distance=5.0, threshold=3.5, separable=True
121-
... )
122114
"""
123115
if threshold <= 0 or not np.isfinite(threshold):
124116
msg = "hampel_filter threshold must be finite and greater than zero"

tests/test_core/test_coordmanager.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,116 @@ def test_select_non_dim_coord(self, cm_basic):
621621
out, _ = new_cm.select(new_dim=(1, 20))
622622
assert new_cm == out
623623

624+
def test_select_non_dim_coord_shortens_coordinate(self, cm_basic):
625+
"""Test that selecting non-dimensional coords shortens only that coordinate."""
626+
# Add a non-dimensional coordinate with numeric values
627+
quality_scores = np.array([0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7])
628+
new_cm = cm_basic.update(quality=(None, quality_scores))
629+
# Select subset of quality scores using array indexing
630+
selected_indices = np.array([1, 3, 5]) # Select indices 1, 3, 5
631+
out, _ = new_cm.select(quality=selected_indices, samples=True)
632+
# Quality coordinate should be shortened
633+
expected_quality = quality_scores[selected_indices]
634+
assert np.array_equal(out.get_array("quality"), expected_quality)
635+
# Dimensional coordinates should be unchanged
636+
assert cm_basic.shape == out.shape
637+
assert np.array_equal(out.get_array("time"), new_cm.get_array("time"))
638+
assert np.array_equal(out.get_array("distance"), new_cm.get_array("distance"))
639+
640+
def test_select_non_dim_coord_with_boolean_mask(self, cm_basic):
641+
"""Test selecting non-dimensional coordinates using boolean arrays."""
642+
# Add a non-dimensional coordinate
643+
values = np.array([10, 20, 30, 40, 50, 60, 70])
644+
new_cm = cm_basic.update(sensor_values=(None, values))
645+
# Create boolean mask
646+
mask = values > 35 # Should select [40, 50, 60, 70]
647+
out, _ = new_cm.select(sensor_values=mask)
648+
# Only the non-dimensional coordinate should be affected
649+
expected_values = values[mask]
650+
assert np.array_equal(out.get_array("sensor_values"), expected_values)
651+
# Dimensional coordinates should remain unchanged
652+
for coord in set(cm_basic.coord_map) - {"sensor_values"}:
653+
assert cm_basic.get_coord(coord) == new_cm.get_coord(coord)
654+
655+
def test_select_multi_dim_coord_raises(self, cm_multidim):
656+
"""
657+
Coords that are associated with more than one dim cannot be selected
658+
because it could ruin the squareness of the patch.
659+
"""
660+
# A coord that is not itself a dimension but is tied to exactly one should work.
661+
lat = cm_multidim.get_array("latitude")
662+
lat_mean = np.mean(lat)
663+
out, _ = cm_multidim.select(latitude=(..., lat_mean))
664+
assert isinstance(out, dc.CoordManager)
665+
# Multi-dim coord should raise CoordError
666+
msg = "Only 1 dimensional coordinates"
667+
with pytest.raises(CoordError, match=msg):
668+
cm_multidim.select(quality=(1, 20))
669+
670+
def test_select_coord_tied_to_dimension_affects_others(self, cm_multidim):
671+
"""
672+
Test that selecting a coord tied to a dimension affects other coords
673+
on that dim.
674+
"""
675+
# cm_multidim should have coordinates that share dimensions
676+
# Get a coordinate that's tied to a dimension and has other coords sharing
677+
# that dim
678+
lat = cm_multidim.get_array("latitude")
679+
lat_mean = np.mean(lat)
680+
out, _ = cm_multidim.select(latitude=(..., lat_mean))
681+
# Check that the new lat is what we expect.
682+
new_lat = out.get_array("latitude")
683+
expected = lat[lat <= lat_mean]
684+
assert np.array_equal(new_lat, expected)
685+
# And the other coords associated with that dimension have the same length.
686+
for name, coord in out.coord_map.items():
687+
coord_dims = out.dim_map[name]
688+
# Skip coords not tied to distance dimension
689+
if "distance" not in coord_dims:
690+
continue
691+
axis = coord_dims.index("distance")
692+
assert coord.shape[axis] == len(new_lat)
693+
694+
def test_select_nonexistent_coordinate_ignores_gracefully(self, cm_basic):
695+
"""Test that selecting on a non-existent coordinate is ignored gracefully."""
696+
# Unknown names are dropped before _get_indexers_and_new_coords_dict is reached.
697+
original_shape = cm_basic.shape
698+
out, _ = cm_basic.select(nonexistent_coord=(1, 10))
699+
700+
# Should return unchanged coordinate manager
701+
assert out == cm_basic
702+
assert out.shape == original_shape
703+
704+
# Should work with multiple nonexistent coordinates
705+
out2, _ = cm_basic.select(
706+
fake_coord1=(1, 2), fake_coord2=slice(0, 5), another_fake=(10, 20)
707+
)
708+
assert out2 == cm_basic
709+
assert out2.shape == original_shape
710+
711+
def test_select_mix_valid_invalid_coordinates(self, cm_basic):
712+
"""Test selecting with mix of valid and invalid coordinate names."""
713+
# Unknown names are dropped while the valid coordinate selection still applies.
714+
time_vals = cm_basic.get_array("time")
715+
subset_time = (time_vals[1], time_vals[-2])
716+
717+
out, _ = cm_basic.select(
718+
time=subset_time, # valid coordinate
719+
nonexistent=(1, 10), # invalid coordinate - should be ignored
720+
fake_dim=slice(0, 5), # another invalid coordinate
721+
)
722+
723+
# Only the valid coordinate selection should have been applied
724+
assert (
725+
out.shape[cm_basic.get_axis("time")]
726+
< cm_basic.shape[cm_basic.get_axis("time")]
727+
)
728+
# Distance should be unchanged since it wasn't selected
729+
assert (
730+
out.shape[cm_basic.get_axis("distance")]
731+
== cm_basic.shape[cm_basic.get_axis("distance")]
732+
)
733+
624734

625735
class TestOrder:
626736
"""Tests for ordering coordinate managers."""

0 commit comments

Comments
 (0)