Skip to content

Commit d75879e

Browse files
committed
add support for percent in taper and coord select
1 parent 5edff1a commit d75879e

File tree

8 files changed

+117
-53
lines changed

8 files changed

+117
-53
lines changed

dascore/compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import suppress
1111

1212
import numpy as np
13-
from numpy import floor, interp # NOQA
13+
from numpy import floor, interp, ndarray # NOQA
1414
from numpy.random import RandomState
1515
from rich.progress import Progress # NOQA
1616
from scipy.interpolate import interp1d # NOQA

dascore/core/coords.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
get_factor_and_unit,
3131
get_quantity,
3232
get_quantity_str,
33+
percent,
3334
)
3435
from dascore.utils.display import get_nice_text
3536
from dascore.utils.docs import compose_docstring
@@ -505,11 +506,26 @@ def simplify_units(self) -> Self:
505506
_, unit = get_factor_and_unit(self.units, simplify=True)
506507
return self.convert_units(unit)
507508

508-
def coord_range(self):
509-
"""Return a scaler value for the coordinate (e.g., number of seconds)."""
510-
if not self.evenly_sampled:
509+
def coord_range(self, exact: bool = True):
510+
"""
511+
Return a scaler value for the coordinate range (e.g., number of seconds).
512+
513+
Parameters
514+
----------
515+
exact
516+
If true, only exact ranges are returned. This accounts for
517+
spacing at the end of each sample. Consequently, exact is only
518+
possible for evenly sampled coords. If false, just disregard
519+
this if coord isnt't evenly sampled.
520+
521+
"""
522+
if not self.evenly_sampled and exact:
511523
raise CoordError("coord_range has to be called on an evenly sampled data.")
512-
return self.max() - self.min() + self.step
524+
step = getattr(self, "step", None)
525+
coord_range = self.max() - self.min()
526+
if step is not None:
527+
coord_range += step
528+
return coord_range
513529

514530
@abc.abstractmethod
515531
def sort(self, reverse=False) -> tuple[BaseCoord, slice | ArrayLike]:
@@ -599,7 +615,12 @@ def _get_compatible_value(self, value, relative=False):
599615
"""
600616
# strip units and v
601617
if hasattr(value, "units"):
602-
value = convert_units(value.magnitude, self.units, value.units)
618+
mag, unit = value.magnitude, value.units
619+
if unit == percent:
620+
value = (mag / 100.0) * self.coord_range(exact=False)
621+
relative = True
622+
else:
623+
value = convert_units(value.magnitude, self.units, value.units)
603624
# if null or ... just return None
604625
if not is_array(value) and (pd.isnull(value) or value is Ellipsis):
605626
return None

dascore/proc/basic.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _fast_attr_update(self, attrs):
108108
return self.__class__(self.data, **out)
109109

110110

111-
def equals(self: PatchType, other: Any, only_required_attrs=True) -> bool:
111+
def equals(self: PatchType, other: Any, only_required_attrs=True, close=False) -> bool:
112112
"""
113113
Determine if the current patch equals another.
114114
@@ -119,6 +119,9 @@ def equals(self: PatchType, other: Any, only_required_attrs=True) -> bool:
119119
only_required_attrs
120120
If True, only compare required attributes. This helps avoid issues
121121
with comparing histories or custom attrs of patches, for example.
122+
close
123+
If True, the data can be "close" using np.allclose, otherwise
124+
all data must be equal.
122125
"""
123126
# different types are not equal
124127
if not isinstance(other, type(self)):
@@ -145,8 +148,12 @@ def equals(self: PatchType, other: Any, only_required_attrs=True) -> bool:
145148
}
146149
if not_equal:
147150
return False
148-
149-
return np.equal(self.data, other.data).all()
151+
# Test data equality or proximity.
152+
if close and not np.allclose(self.data, other.data):
153+
return False
154+
elif not close and not np.equal(self.data, other.data).all():
155+
return False
156+
return True
150157

151158

152159
def update(

dascore/proc/taper.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,16 @@
77
from operator import add
88

99
import numpy as np
10-
from scipy.signal import windows # the best operating system?
1110

1211
from dascore.constants import PatchType
1312
from dascore.exceptions import ParameterError
1413
from dascore.units import Quantity
1514
from dascore.utils.docs import compose_docstring
1615
from dascore.utils.misc import broadcast_for_index
1716
from dascore.utils.patch import get_dim_axis_value, patch_function
17+
from dascore.utils.signal import WINDOW_FUNCTIONS, _get_window_function
1818
from dascore.utils.time import to_float
1919

20-
TAPER_FUNCTIONS = dict(
21-
barthann=windows.barthann,
22-
bartlett=windows.bartlett,
23-
blackman=windows.blackman,
24-
blackmanharris=windows.blackmanharris,
25-
bohman=windows.bohman,
26-
hamming=windows.hamming,
27-
hann=windows.hann,
28-
cos=windows.hann,
29-
nuttall=windows.nuttall,
30-
parzen=windows.parzen,
31-
triang=windows.triang,
32-
ramp=windows.triang,
33-
)
34-
3520

3621
def _get_taper_slices(patch, kwargs):
3722
"""Get slice for start/end of patch."""
@@ -42,7 +27,7 @@ def _get_taper_slices(patch, kwargs):
4227
start, stop = value[0], value[1]
4328
else:
4429
start, stop = value, value
45-
dur = coord.max() - coord.min()
30+
dur = coord.coord_range(exact=False)
4631
# either let units pass through or multiply by d_len
4732
clses = (Quantity, np.timedelta64)
4833
start = start if isinstance(start, clses) or start is None else start * dur
@@ -53,19 +38,6 @@ def _get_taper_slices(patch, kwargs):
5338
return axis, (start, stop), inds_1, inds_2
5439

5540

56-
def _get_window_function(window_type):
57-
"""Get the window function to use for taper."""
58-
# get taper function or raise if it isn't known.
59-
if window_type not in TAPER_FUNCTIONS:
60-
msg = (
61-
f"'{window_type}' is not a known window type. "
62-
f"Options are: {sorted(TAPER_FUNCTIONS)}"
63-
)
64-
raise ParameterError(msg)
65-
func = TAPER_FUNCTIONS[window_type]
66-
return func
67-
68-
6941
def _validate_windows(samps, start_slice, end_slice, shape, axis):
7042
"""Validate the windows don't overlap or exceed dim len."""
7143
max_len = shape[axis]
@@ -87,7 +59,7 @@ def _validate_windows(samps, start_slice, end_slice, shape, axis):
8759

8860

8961
@patch_function()
90-
@compose_docstring(taper_type=sorted(TAPER_FUNCTIONS))
62+
@compose_docstring(taper_type=sorted(WINDOW_FUNCTIONS))
9163
def taper(
9264
patch: PatchType,
9365
window_type: str = "hann",
@@ -219,7 +191,7 @@ def _get_range_envelope(coord, inds, window_type, invert):
219191

220192

221193
@patch_function()
222-
@compose_docstring(taper_type=sorted(TAPER_FUNCTIONS))
194+
@compose_docstring(taper_type=sorted(WINDOW_FUNCTIONS))
223195
def taper_range(
224196
patch: PatchType,
225197
window_type: str = "hann",

dascore/utils/signal.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
Utilities for signal processing.
3+
"""
4+
5+
from scipy.signal import windows
6+
7+
from dascore.exceptions import ParameterError
8+
9+
WINDOW_FUNCTIONS = dict(
10+
barthann=windows.barthann,
11+
bartlett=windows.bartlett,
12+
blackman=windows.blackman,
13+
blackmanharris=windows.blackmanharris,
14+
bohman=windows.bohman,
15+
hamming=windows.hamming,
16+
hann=windows.hann,
17+
cos=windows.hann,
18+
nuttall=windows.nuttall,
19+
parzen=windows.parzen,
20+
triang=windows.triang,
21+
ramp=windows.triang,
22+
)
23+
24+
25+
def _get_window_function(window_type):
26+
"""Get the window function to use for taper."""
27+
# get taper function or raise if it isn't known.
28+
if window_type not in WINDOW_FUNCTIONS:
29+
msg = (
30+
f"'{window_type}' is not a known window type. "
31+
f"Options are: {sorted(WINDOW_FUNCTIONS)}"
32+
)
33+
raise ParameterError(msg)
34+
func = WINDOW_FUNCTIONS[window_type]
35+
return func

tests/test_core/test_coords.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
get_coord,
2626
)
2727
from dascore.exceptions import CoordError, ParameterError
28-
from dascore.units import get_quantity
28+
from dascore.units import get_quantity, percent
2929
from dascore.utils.misc import all_close, register_func
3030
from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float
3131

@@ -357,9 +357,13 @@ def test_both_values_and_data_raises(self):
357357

358358
def test_coord_range(self, monotonic_float_coord):
359359
"""Ensure that coord_range raises an error for not evenly sampled patches."""
360+
coord = monotonic_float_coord
360361
msg = "has to be called on an evenly sampled"
361362
with pytest.raises(CoordError, match=msg):
362-
monotonic_float_coord.coord_range()
363+
coord.coord_range()
364+
# But when exact=False it should work.
365+
out = coord.coord_range(exact=False)
366+
assert out == (coord.max() - coord.min())
363367

364368
def test_get_coord_datetime(self):
365369
"""Ensure get_coord accepts a datetime object. See #467."""
@@ -745,6 +749,12 @@ def test_select_samples_no_int_raises(self, random_coord):
745749
with pytest.raises(ParameterError, match=expected):
746750
random_coord.select((1, 1.2), samples=True)
747751

752+
def test_percentage(self, coord):
753+
"""Ensure selecting by percentage works."""
754+
out, indexer = coord.select((10 * percent, -20 * percent))
755+
if coord.evenly_sampled:
756+
assert abs((len(out) / len(coord)) - 0.70) < len(coord) / 100.0
757+
748758

749759
class TestOrder:
750760
"""Tests for ordering coordinates."""

tests/test_core/test_patch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,12 @@ def test_negative_equals(self, random_patch):
401401
assert -random_patch == -random_patch
402402
assert (-random_patch).abs() == random_patch
403403

404+
def test_close(self, random_patch):
405+
"""Test the `close` parameter"""
406+
new = random_patch * 0.999999999999999
407+
assert not new.equals(random_patch, close=False)
408+
assert new.equals(random_patch, close=True)
409+
404410

405411
class TestTranspose:
406412
"""Tests for switching dimensions."""

tests/test_proc/test_taper.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import dascore as dc
99
from dascore.exceptions import ParameterError
10-
from dascore.proc.taper import TAPER_FUNCTIONS, taper
11-
from dascore.units import m
10+
from dascore.proc.taper import taper
11+
from dascore.units import m, percent
1212
from dascore.utils.misc import broadcast_for_index
13+
from dascore.utils.signal import WINDOW_FUNCTIONS
1314

1415
gen = np.random.default_rng(32)
1516

@@ -21,7 +22,7 @@ def patch_ones(random_patch):
2122
return patch
2223

2324

24-
@pytest.fixture(scope="session", params=sorted(TAPER_FUNCTIONS))
25+
@pytest.fixture(scope="session", params=sorted(WINDOW_FUNCTIONS))
2526
def time_tapered_patch(request, patch_ones):
2627
"""Return a tapered trace."""
2728
# first get a patch with all ones for easy testing
@@ -39,6 +40,14 @@ def _get_start_end_indices(patch, dim):
3940
return inds_start, inds_end
4041

4142

43+
@pytest.fixture(scope="class")
44+
def patch_sorted_time(random_patch):
45+
"""Return a patch with sorted but not evenly spaced time dim."""
46+
times = gen.random(len(random_patch.get_coord("time")))
47+
new_times = dc.to_datetime64(np.sort(times))
48+
return random_patch.update_coords(time=new_times)
49+
50+
4251
class TestTaperBasics:
4352
"""Ensure each taper runs."""
4453

@@ -125,17 +134,21 @@ def test_timedelta_taper(self, random_patch):
125134
patch2 = random_patch.taper(time=time2)
126135
assert patch1 == patch2
127136

137+
def test_percentage_taper(self, patch_ones):
138+
"""Ensure a percentage unit can be used in addition to fraction."""
139+
out1 = patch_ones.taper(time=(0.1, 0.2))
140+
out2 = patch_ones.taper(time=(10 * percent, 20 * percent))
141+
assert out1.equals(out2, close=True)
142+
143+
def test_uneven_time_coord(self, patch_sorted_time):
144+
"""Ensure taper works on patches without even sampling."""
145+
out = patch_sorted_time.taper(time=(0.1, 0.2))
146+
assert isinstance(out, dc.Patch)
147+
128148

129149
class TestTaperRange:
130150
"""Test for tapering a range of values."""
131151

132-
@pytest.fixture(scope="class")
133-
def patch_sorted_time(self, random_patch):
134-
"""Return a patch with sorted but not evenly spaced time dim."""
135-
times = gen.random(len(random_patch.get_coord("time")))
136-
new_times = dc.to_datetime64(np.sort(times))
137-
return random_patch.update_coords(time=new_times)
138-
139152
def test_dims(self, patch_ones):
140153
"""Ensure both dimensions work."""
141154
for ax, dim in enumerate(patch_ones.dims):

0 commit comments

Comments
 (0)