Commit ebdbecb

add type hints to patch functions, fix febus utils and trapz (#590)
* add type hints to patch functions
* review
* debug profile
* try fix benchmarks
* continue debug febus block time
* more unpacking
* fix trapz issue
* remove trapz reference
* global trap function
1 parent 1ebc98f commit ebdbecb

14 files changed: 67 additions, 36 deletions

.github/workflows/profile.yml
Lines changed: 1 addition & 1 deletion

@@ -33,4 +33,4 @@ jobs:
       with:
         mode: instrumentation
         run: ./.github/test_code.sh profile
-        token: ${{ secrets.CODSPEED_TOKEN }} # Optional for public repos
+        token: ${{ secrets.CODSPEED_TOKEN }}

dascore/io/febus/utils.py
Lines changed: 3 additions & 3 deletions

@@ -33,9 +33,9 @@ def _get_block_time(feb):
     # Some files have this set. We haven't yet seen any files where this
     # values exists and is wrong, so we trust it (for now). This is probably
     # much faster than reading the whole time vector.
-    br = feb.zone.attrs.get("BlockRate", 0) / 1_000
+    br = _maybe_unpack(feb.zone.attrs.get("BlockRate", 0) / 1_000)
     if br > 0:
-        return 1 / br
+        return float(1 / br)
     # Otherwise we have to try to use the time vector. Here be dragons.
     time_shape = feb.source["time"].shape
     # Not sure why but time has the shape of [1, n] for some files and just
@@ -54,7 +54,7 @@ def _get_block_time(feb):
     # After removing outliers, the mean seems to work better than the median
     # for the test files we have. There is still a concerning amount of
     # variability.
-    return np.mean(d_time)
+    return float(_maybe_unpack(np.mean(d_time)))


 @cache
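Why the wrapping matters: HDF5 attributes and reductions over small arrays can come back as 0-d or length-1 NumPy arrays rather than plain scalars, and NumPy 2 is stricter about implicitly treating such arrays as numbers. `_maybe_unpack` is an existing dascore helper; the sketch below is only a hypothetical illustration of the kind of unwrapping involved, not the library's implementation.

    import numpy as np

    def _maybe_unpack_sketch(value):
        # Hypothetical stand-in for dascore's _maybe_unpack: collapse 0-d or
        # size-1 arrays to a plain Python scalar, pass anything else through.
        if isinstance(value, np.ndarray) and value.size == 1:
            return value.item()
        return value

    block_rate = np.array([500.0])  # attribute read back as a 1-element array
    br = _maybe_unpack_sketch(block_rate / 1_000)
    block_time = float(1 / br) if br > 0 else None  # plain float, 2.0 here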

dascore/proc/coords.py
Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@


 @patch_function()
-def snap_coords(patch: PatchType, *coords, reverse=False):
+def snap_coords(patch: PatchType, *coords, reverse: bool = False) -> PatchType:
     """
     Snap coordinates to evenly spaced samples.

@@ -59,7 +59,7 @@ def snap_coords(patch: PatchType, *coords, reverse=False):


 @patch_function()
-def sort_coords(patch: PatchType, *coords, reverse=False):
+def sort_coords(patch: PatchType, *coords, reverse: bool = False) -> PatchType:
     """
     Sort one or more coordinates.

dascore/proc/correlate.py
Lines changed: 4 additions & 2 deletions

@@ -35,7 +35,9 @@ def _get_source_fft(patch, dim, source, source_axis, samples):


 @patch_function()
-def correlate_shift(patch, dim, undo_weighting=True):
+def correlate_shift(
+    patch: PatchType, dim: str, undo_weighting: bool = True
+) -> PatchType:
     """
     Apply a shift to the patch data to undo correlation in frequency domain.

@@ -87,7 +89,7 @@ def correlate_shift(patch, dim, undo_weighting=True):
 @patch_function()
 def correlate(
     patch: PatchType,
-    samples=False,
+    samples: bool = False,
     lag=None,
     **kwargs,
 ) -> PatchType:

dascore/proc/detrend.py
Lines changed: 5 additions & 1 deletion

@@ -2,14 +2,18 @@

 from __future__ import annotations

+from typing import Literal
+
 from scipy.signal import detrend as scipy_detrend

 from dascore.constants import PatchType
 from dascore.utils.patch import patch_function


 @patch_function()
-def detrend(patch: PatchType, dim, type="linear") -> PatchType:
+def detrend(
+    patch: PatchType, dim: str, type: Literal["linear", "constant"] = "linear"
+) -> PatchType:
     """
     Perform detrending along a given dimension (distance or time) of a patch.
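The `Literal["linear", "constant"]` annotation spells out the only values `scipy.signal.detrend` accepts for its `type` argument, so a static checker can flag typos before runtime. A minimal sketch of that behavior (a generic typing example, not dascore code):

    from typing import Literal

    def pick_detrend_type(kind: Literal["linear", "constant"] = "linear") -> str:
        # mypy/pyright accept only the two listed strings for `kind`.
        return kind

    pick_detrend_type("constant")      # fine
    # pick_detrend_type("quadratic")   # rejected by a type checker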

dascore/proc/filter.py
Lines changed: 24 additions & 6 deletions

@@ -95,7 +95,9 @@ def _get_sos(sr, filt_min, filt_max, corners):


 @patch_function()
-def pass_filter(patch: PatchType, corners=4, zerophase=True, **kwargs) -> PatchType:
+def pass_filter(
+    patch: PatchType, corners: int = 4, zerophase: bool = True, **kwargs
+) -> PatchType:
     """
     Apply a Butterworth pass filter (bandpass, highpass, or lowpass).

@@ -146,7 +148,9 @@ def pass_filter(patch: PatchType, corners=4, zerophase=True, **kwargs) -> PatchT


 @patch_function()
-def sobel_filter(patch: PatchType, dim: str, mode="reflect", cval=0.0) -> PatchType:
+def sobel_filter(
+    patch: PatchType, dim: str, mode: str = "reflect", cval: float | int = 0.0
+) -> PatchType:
     """
     Apply a Sobel filter.

@@ -200,7 +204,11 @@ def _create_size_and_axes(patch, kwargs, samples):
 @patch_function()
 @compose_docstring(sample_explanation=samples_arg_description)
 def median_filter(
-    patch: PatchType, samples=False, mode="reflect", cval=0.0, **kwargs
+    patch: PatchType,
+    samples: bool = False,
+    mode: str = "reflect",
+    cval: float = 0.0,
+    **kwargs,
 ) -> PatchType:
     """
     Apply 2-D median filter.
@@ -251,7 +259,7 @@ def median_filter(


 @patch_function()
-def notch_filter(patch: PatchType, q, **kwargs) -> PatchType:
+def notch_filter(patch: PatchType, q: float, **kwargs) -> PatchType:
     """
     Apply a second-order IIR notch digital filter on patch's data.

@@ -320,7 +328,12 @@ def notch_filter(patch: PatchType, q, **kwargs) -> PatchType:
 @patch_function()
 @compose_docstring(sample_explanation=samples_arg_description)
 def savgol_filter(
-    patch: PatchType, polyorder, samples=False, mode="interp", cval=0.0, **kwargs
+    patch: PatchType,
+    polyorder: int,
+    samples: bool = False,
+    mode: str = "interp",
+    cval: float = 0.0,
+    **kwargs,
 ) -> PatchType:
     """
     Applies Savgol filter along spenfied dimensions.
@@ -382,7 +395,12 @@ def savgol_filter(
 @patch_function()
 @compose_docstring(sample_explanation=samples_arg_description)
 def gaussian_filter(
-    patch: PatchType, samples=False, mode="reflect", cval=0.0, truncate=4.0, **kwargs
+    patch: PatchType,
+    samples: bool = False,
+    mode: str = "reflect",
+    cval: float = 0.0,
+    truncate: float = 4.0,
+    **kwargs,
 ) -> PatchType:
     """
     Applies a Gaussian filter along specified dimensions.

dascore/proc/hampel.py
Lines changed: 3 additions & 3 deletions

@@ -99,10 +99,10 @@ def hampel_filter(
     patch: PatchType,
     *,
     threshold: float = 10.0,
-    samples=False,
-    approximate=True,
+    samples: bool = False,
+    approximate: bool = True,
     **kwargs,
-):
+) -> PatchType:
     """
     A Hampel filter implementation useful for removing spikes in data.

dascore/proc/resample.py
Lines changed: 6 additions & 2 deletions

@@ -40,7 +40,7 @@ def _apply_scipy_decimation(patch, factor, ftype, axis):
 def decimate(
     patch: PatchType,
     filter_type: Literal["iir", "fir", None] = "iir",
-    copy=True,
+    copy: bool = True,
     **kwargs,
 ) -> PatchType:
     """
@@ -155,7 +155,11 @@ def interpolate(patch: PatchType, kind: str | int = "linear", **kwargs) -> Patch

 @patch_function()
 def resample(
-    patch: PatchType, window=None, interp_kind="linear", samples=False, **kwargs
+    patch: PatchType,
+    window=None,
+    interp_kind: str = "linear",
+    samples: bool = False,
+    **kwargs,
 ) -> PatchType:
     """
     Resample along a single dimension using Fourier Method and interpolation.

dascore/proc/wiener.py
Lines changed: 3 additions & 3 deletions

@@ -15,10 +15,10 @@
 def wiener_filter(
     patch: PatchType,
     *,
-    noise=None,
-    samples=False,
+    noise: float | None = None,
+    samples: bool = False,
     **kwargs,
-):
+) -> PatchType:
     """
     Apply a Wiener filter to reduce noise in the patch data.

dascore/transform/integrate.py
Lines changed: 4 additions & 3 deletions

@@ -16,6 +16,8 @@
     patch_function,
 )

+_TRAP_FUNC = getattr(np, "trapezoid" if hasattr(np, "trapezoid") else "trapz")
+

 def _quasi_mean(array):
     """Get a quasi mean value from an array. Works with datetimes."""
@@ -45,12 +47,11 @@ def _get_new_coords_and_array(patch, array, dims):
     ndims = len(patch.shape)
     for dxs_or_val, ax in zip(dxs_or_vals, axes):
         # Numpy 2/3 compat code
-        trap = getattr(np, "trapezoid", getattr(np, "trapz"))
         indexer = broadcast_for_index(ndims, ax, None, fill=slice(None))
         if is_array(dxs_or_val):
-            array = trap(array, x=dxs_or_val, axis=ax)[indexer]
+            array = _TRAP_FUNC(array, x=dxs_or_val, axis=ax)[indexer]
         else:
-            array = trap(array, dx=dxs_or_val, axis=ax)[indexer]
+            array = _TRAP_FUNC(array, dx=dxs_or_val, axis=ax)[indexer]
     array, coords = _get_new_coords_and_array(patch, array, dims)
     return array, coords
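NumPy 2.0 removed `np.trapz` in favor of `np.trapezoid`. The old per-iteration lookup, `getattr(np, "trapezoid", getattr(np, "trapz"))`, evaluates its fallback argument eagerly, so on NumPy 2 it raised `AttributeError` even though `trapezoid` exists; that appears to be the trapz issue this commit fixes. Resolving the function once at module level with `hasattr` never touches the missing name and avoids repeating the lookup inside the loop. A small usage sketch of the same pattern:

    import numpy as np

    # Prefer NumPy 2's np.trapezoid, fall back to np.trapz on NumPy 1.x,
    # and never touch the attribute that is missing.
    _TRAP_FUNC = getattr(np, "trapezoid" if hasattr(np, "trapezoid") else "trapz")

    x = np.linspace(0.0, 1.0, 101)
    y = x**2
    area = _TRAP_FUNC(y, x=x)  # ~1/3 on either NumPy major version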
