Skip to content

Commit b115a76

Browse files
authored
Patch.where (#550)
* add patch.where * fix doctest * mind the rabbit
1 parent 04cd9f8 commit b115a76

File tree

3 files changed

+240
-0
lines changed

3 files changed

+240
-0
lines changed

dascore/core/patch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def iselect(self, *args, **kwargs):
421421
resample = dascore.proc.resample
422422
pad = dascore.proc.pad
423423
roll = dascore.proc.roll
424+
where = dascore.proc.where
424425

425426
@deprecate(
426427
"patch.iresample is deprecated. Please use patch.resample " "with samples=True",

dascore/proc/basic.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from scipy.fft import next_fast_len
1111

1212
import dascore as dc
13+
from dascore.compat import array
1314
from dascore.constants import PatchType
1415
from dascore.core.attrs import PatchAttrs
1516
from dascore.core.coordmanager import CoordManager, get_coord_manager
@@ -19,6 +20,7 @@
1920
from dascore.utils.misc import _get_nullish
2021
from dascore.utils.models import ArrayLike
2122
from dascore.utils.patch import (
23+
align_patch_coords,
2224
get_dim_axis_value,
2325
patch_function,
2426
)
@@ -688,3 +690,66 @@ def roll(patch, samples=False, update_coord=False, **kwargs):
688690
patch = patch.update_coords(**{dim: new_coord})
689691

690692
return patch.new(data=roll_arr)
693+
694+
695+
@patch_function()
696+
def where(
697+
patch: PatchType, cond: ArrayLike | PatchType, other: Any | PatchType = np.nan
698+
) -> PatchType:
699+
"""
700+
Return elements from patch where condition is True, else fill with other.
701+
702+
Parameters
703+
----------
704+
patch
705+
The input patch
706+
cond
707+
Condition array. Should be a boolean array with the same shape as patch data,
708+
or a patch with boolean data that is broadcastable to the patch's shape.
709+
other
710+
Value to use for locations where cond is False. Can be a scalar value,
711+
array, or patch that is broadcastable to the patch's shape. Default is NaN.
712+
713+
Returns
714+
-------
715+
PatchType
716+
A new patch with values from patch where cond is True, and other elsewhere.
717+
718+
Examples
719+
--------
720+
>>> import dascore as dc
721+
>>> import numpy as np
722+
>>> patch = dc.get_example_patch()
723+
>>>
724+
>>> # Where data > 0 fill with original patch values else nan.
725+
>>> condition = patch.data > 0
726+
>>> out = patch.where(condition)
727+
>>>
728+
>>> # Use another patch as condition
729+
>>> threshold = patch.data.mean()
730+
>>> boolean_patch = patch.new(data=(patch.data > threshold))
731+
>>> out = patch.where(boolean_patch, other=0)
732+
>>>
733+
>>> # Replace values below threshold with 0
734+
>>> out = patch.where(patch.data > patch.data.mean(), other=0)
735+
"""
736+
cls = patch.__class__ # Use this so it works with subclasses
737+
# Align patch and cond
738+
if isinstance(cond, cls):
739+
patch, cond = align_patch_coords(patch, cond)
740+
# Align patch and other, may need to re-align cond
741+
if isinstance(other, cls):
742+
patch, other = align_patch_coords(patch, other)
743+
if isinstance(cond, cls):
744+
patch, cond = align_patch_coords(patch, cond)
745+
746+
cond_array, other_array = array(cond), array(other)
747+
748+
# Ensure condition is boolean
749+
if not np.issubdtype(cond_array.dtype, np.bool_):
750+
msg = "Condition must be a boolean array or patch with boolean data"
751+
raise ValueError(msg)
752+
753+
# Use numpy.where to apply condition
754+
new_data = np.where(cond_array, patch.data, other_array)
755+
return patch.new(data=new_data)

tests/test_proc/test_basic.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,177 @@ def test_dist_coord_roll(self, random_patch):
598598
random_patch.coords.get_array("distance")[0]
599599
== rolled_patch.coords.get_array("distance")[value]
600600
)
601+
602+
603+
class TestWhere:
604+
"""Tests for the where method of Patch."""
605+
606+
def test_where_with_boolean_array(self, random_patch):
607+
"""Test where with a boolean array condition."""
608+
condition = random_patch.data > random_patch.data.mean()
609+
result = random_patch.where(condition)
610+
611+
# Check that the result has the same shape
612+
assert result.shape == random_patch.shape
613+
614+
# Check that values where condition is True are preserved
615+
assert np.allclose(result.data[condition], random_patch.data[condition])
616+
617+
# Check that values where condition is False are NaN
618+
assert np.all(np.isnan(result.data[~condition]))
619+
620+
def test_where_with_other_value(self, random_patch):
621+
"""Test where with a replacement value."""
622+
condition = random_patch.data > 0
623+
other_value = -999
624+
result = random_patch.where(condition, other=other_value)
625+
626+
# Check that values where condition is True are preserved
627+
assert np.allclose(result.data[condition], random_patch.data[condition])
628+
629+
# Check that values where condition is False are replaced
630+
assert np.all(result.data[~condition] == other_value)
631+
632+
def test_where_with_patch_condition(self, random_patch):
633+
"""Test where with another patch as condition."""
634+
boolean_data = (random_patch.data > random_patch.data.mean()).astype(bool)
635+
boolean_patch = random_patch.new(data=boolean_data)
636+
result = random_patch.where(boolean_patch, other=0)
637+
638+
# Check that the result has the same shape
639+
assert result.shape == random_patch.shape
640+
641+
# Check that values where condition is True are preserved
642+
true_mask = boolean_data
643+
assert np.allclose(result.data[true_mask], random_patch.data[true_mask])
644+
645+
# Check that values where condition is False are 0
646+
false_mask = ~boolean_data
647+
assert np.all(result.data[false_mask] == 0)
648+
649+
def test_where_preserves_metadata(self, random_patch):
650+
"""Test that where preserves patch metadata."""
651+
condition = random_patch.data > 0
652+
result = random_patch.where(condition)
653+
654+
# Check that coordinates are preserved
655+
assert result.coords == random_patch.coords
656+
657+
# Check that dimensions are preserved
658+
assert result.dims == random_patch.dims
659+
660+
# Check that attributes are preserved (except history)
661+
assert result.attrs.model_dump(
662+
exclude={"history"}
663+
) == random_patch.attrs.model_dump(exclude={"history"})
664+
665+
def test_where_non_boolean_condition_raises(self, random_patch):
666+
"""Test that non-boolean condition raises ValueError."""
667+
non_boolean_condition = random_patch.data # Not boolean
668+
669+
with pytest.raises(ValueError, match="Condition must be a boolean array"):
670+
random_patch.where(non_boolean_condition)
671+
672+
def test_where_broadcasts_condition(self, random_patch):
673+
"""Test that condition can be broadcast to patch shape."""
674+
# Create a condition that can be broadcast to the full shape
675+
# Create a boolean array that matches the first dimension
676+
condition = np.ones(random_patch.shape[0], dtype=bool)
677+
condition[0] = False # Make first element False
678+
679+
# This should broadcast across the second dimension
680+
result = random_patch.where(condition[:, np.newaxis], other=-1)
681+
assert result.shape == random_patch.shape
682+
683+
# Check that first row is all -1 and others are preserved
684+
assert np.all(result.data[0, :] == -1)
685+
assert np.allclose(result.data[1:, :], random_patch.data[1:, :])
686+
687+
def test_where_with_broadcastable_patch_other(self, random_patch):
688+
"""Test where with a broadcastable patch as other parameter."""
689+
# Get the actual dimensions of the patch to create the right broadcasting
690+
broadcastable_patch1 = random_patch.mean("distance").squeeze()
691+
broadcastable_patch2 = random_patch.mean("time").squeeze()
692+
693+
# Create condition
694+
condition = random_patch.data > random_patch.data.mean()
695+
696+
for castable in [broadcastable_patch1, broadcastable_patch2]:
697+
result = random_patch.where(condition, other=castable)
698+
assert result.shape == random_patch.shape
699+
# Check that values where condition is True are preserved
700+
assert np.allclose(result.data[condition], random_patch.data[condition])
701+
# Check that values where condition is False come from the broadcasted other
702+
false_mask = ~condition
703+
# The exact values depend on how the broadcasting worked
704+
assert np.all(
705+
~np.isnan(result.data[false_mask])
706+
) # Should have valid values
707+
708+
def test_where_with_misaligned_coords(self, random_patch):
709+
"""Test where with condition patch having misaligned coordinates."""
710+
# Create a subset of the original patch with partial overlap
711+
time_coord = random_patch.coords.get_array("time")
712+
# Take only part of the time coordinates to create a partial overlap
713+
partial_time = time_coord[10:20] # Use a subset
714+
715+
# Create a boolean condition patch with partial time coordinates
716+
shifted_patch = random_patch.new(
717+
coords={
718+
"time": partial_time,
719+
"distance": random_patch.coords.get_array("distance"),
720+
},
721+
data=(random_patch.data[:, 10:20] > random_patch.data[:, 10:20].mean()),
722+
)
723+
724+
# This should work with coordinate alignment (union)
725+
result = random_patch.where(shifted_patch, other=0)
726+
727+
# The result should have the union of coordinates and correct shape
728+
assert result is not None
729+
assert isinstance(result.data, np.ndarray)
730+
# After alignment, coords should have the overlapping range
731+
result_time = result.coords.get_array("time")
732+
partial_time_len = len(partial_time)
733+
assert len(result_time) == partial_time_len
734+
735+
def test_where_both_cond_and_other_misaligned(self, random_patch):
736+
"""Test where with both condition and other patches having misaligned coords."""
737+
# Create condition patch with partial time overlap (first part)
738+
time_coord = random_patch.coords.get_array("time")
739+
cond_time = time_coord[5:15] # indices 5-14
740+
741+
condition_patch = random_patch.new(
742+
coords={
743+
"time": cond_time,
744+
"distance": random_patch.coords.get_array("distance"),
745+
},
746+
data=(random_patch.data[:, 5:15] > random_patch.data[:, 5:15].mean()),
747+
)
748+
749+
# Create other patch with different partial time overlap (shifted range)
750+
other_time = time_coord[8:18] # indices 8-17, overlaps with condition
751+
other_patch = random_patch.new(
752+
coords={
753+
"time": other_time,
754+
"distance": random_patch.coords.get_array("distance"),
755+
},
756+
data=random_patch.data[:, 8:18] * 0.5, # Use different values
757+
)
758+
759+
# This should work with coordinate alignment handling both patches
760+
result = random_patch.where(condition_patch, other=other_patch)
761+
762+
# The result should have coordinates that are the intersection of all three
763+
assert result is not None
764+
assert isinstance(result.data, np.ndarray)
765+
766+
# After alignment, the time coordinate should be the intersection
767+
result_time = result.coords.get_array("time")
768+
# The intersection of [5:15], [8:18], and full range should be [8:15]
769+
expected_overlap_len = 7 # indices 8, 9, 10, 11, 12, 13, 14
770+
assert len(result_time) == expected_overlap_len
771+
772+
# Verify the actual time values match the expected overlap
773+
expected_time_values = time_coord[8:15]
774+
assert np.array_equal(result_time, expected_time_values)

0 commit comments

Comments
 (0)