Skip to content

Commit 9a01cf1

Browse files
committed
allow align_to_coord with non-overlapping traces
1 parent 5326c3e commit 9a01cf1

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

dascore/proc/align.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _get_aligned_coords(patch, dim_name, meta):
189189
return patch.coords.update(**{dim_name: new_coord})
190190

191191

192-
def _get_shift_indices(coord_vals, dim, reverse, samples):
192+
def _get_shift_indices(coord_vals, dim, reverse, samples, mode):
193193
"""
194194
Get the indices for shifting.
195195
@@ -200,11 +200,19 @@ def _get_shift_indices(coord_vals, dim, reverse, samples):
200200
# Therefore, we need to insure positive values go into get_next_index.
201201
abs_vals = np.abs(coord_vals)
202202
sign_vals = np.sign(coord_vals)
203-
inds_abs = dim.get_next_index(
204-
abs_vals,
205-
samples=samples,
206-
relative=False if samples else True,
207-
)
203+
try:
204+
inds_abs = dim.get_next_index(
205+
abs_vals,
206+
samples=samples,
207+
relative=False if samples else True,
208+
allow_out_of_bounds=mode == "full",
209+
)
210+
except ValueError:
211+
msg = (
212+
f"Trace shift with align_to_coord results in some traces with no "
213+
f"overlaps. This is only possible with mode = 'full' not {mode}."
214+
)
215+
raise ParameterError(msg)
208216
# Reverse index if needed. This way the reference stays the same.
209217
inds = inds_abs * (sign_vals.astype(np.int64) * (-1 if reverse else 1))
210218
return inds
@@ -263,7 +271,8 @@ def align_to_coord(
263271
Determines the output shape of the patch. Options are:
264272
"full" - Regardless of shift, all original data are preserved.
265273
This can result in patches with many fill values along the
266-
aligned dimension.
274+
aligned dimension. It also allows cases where some traces are
275+
do not overlap at all.
267276
"same" - The patch will retain its shape, however, only one trace
268277
(and traces that weren't shifted) will remain complete. Parts
269278
of shifted traces will be discarded.
@@ -370,7 +379,7 @@ def align_to_coord(
370379
coord_dims = patch.coords.dim_map[coord_name]
371380
coord_axes = tuple(patch.dims.index(x) for x in coord_dims)
372381
# Get the metadata about shift and the indices for shifting.
373-
inds = _get_shift_indices(coord_vals, dim, reverse, samples)
382+
inds = _get_shift_indices(coord_vals, dim, reverse, samples, mode)
374383
meta = _calculate_shift_info(mode, inds, dim, dim_axis, patch.shape)
375384
# Apply shifts to data
376385
shifted_data = _apply_shifts_to_data(

tests/test_proc/test_align.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,17 @@ def patch_with_mixed_shifts(simple_patch_for_alignment):
113113
return patch.update_coords(shift_time_mixed=("distance", mixed_shifts))
114114

115115

116+
@pytest.fixture(scope="class")
117+
def patch_non_overlapping_shifts(random_patch):
118+
"""Patch with shifts that cause some traces to not overlap."""
119+
sub = (
120+
random_patch.select(distance=slice(0, 3), samples=True)
121+
.select(time=slice(0, 10), samples=True)
122+
.update_coords(shift=("distance", np.array([0, 8, 16])))
123+
)
124+
return sub
125+
126+
116127
class TestAlignToCoordValidation:
117128
"""Tests for align_to_coord validation logic."""
118129

@@ -375,3 +386,16 @@ def test_reverse_coordinate_accuracy(self, patch_with_known_shifts):
375386
assert np.isclose(
376387
reversed_patch.get_coord("time").step, original_time_coord.step
377388
)
389+
390+
def test_align_no_overlap(self, patch_non_overlapping_shifts):
391+
"""Test that a patch with shifts that cause traces to not overlap."""
392+
patch = patch_non_overlapping_shifts
393+
msg = "some traces with no overlaps"
394+
# Mode = valid and same should fail.
395+
with pytest.raises(ParameterError, match=msg):
396+
patch.align_to_coord(time="shift", mode="valid")
397+
with pytest.raises(ParameterError, match=msg):
398+
patch.align_to_coord(time="shift", mode="same")
399+
# But full should work.
400+
new = patch.align_to_coord(time="shift", mode="full")
401+
assert new.dropna("time").data.size == 0

0 commit comments

Comments
 (0)