Skip to content

Commit b123328

Browse files
committed
2 parents c5c5e5e + 49c7503 commit b123328

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+727
-144
lines changed

.github/workflows/run_min_dep_tests.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ on:
1212
- '**.py'
1313
- '.github/workflows/run_min_dep_tests.yml'
1414

15+
env:
16+
# Ensure matplotlib doesn't try to show figures in CI
17+
MPLBACKEND: Agg
18+
QT_QPA_PLATFORM: offscreen
19+
1520
# Cancel previous runs when this one starts.
1621
concurrency:
1722
group: TestCodeMinDeps-${{ github.event.pull_request.number || github.run_id }}

.github/workflows/runtests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ on:
1515
env:
1616
# used to manually trigger cache reset. Just increment if needed.
1717
CACHE_NUMBER: 1
18+
# Ensure matplotlib doesn't try to show figures in CI
19+
MPLBACKEND: Agg
20+
QT_QPA_PLATFORM: offscreen
1821

1922
# Cancel previous runs when this one starts.
2023
concurrency:

dascore/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,14 @@ def map(self, func, iterables, **kwargs):
201201
"last": partial(np.take, indices=-1),
202202
}
203203
)
204+
205+
DIM_REDUCE_DOCS = """
206+
dim_reduce
207+
How to reduce the dimensional coordinate associated with the
208+
aggregated axis. Can be the name of any valid aggregator, a callable,
209+
"empty" (the default) which returns a length 1 partial coord, or
210+
"squeeze" which drops the coordinate. For dimensions with datetime
211+
or timedelta datatypes, if the operation fails it will automatically
212+
be applied to the coordinates converted to floats then the output
213+
converted back to the appropriate time type.
214+
"""

dascore/core/coordmanager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _get_dimensional_sorts(coords2sort):
370370
dim = self.dim_map[name][0]
371371
new_coord, indexer = coord.sort(reverse=reverse)
372372
new_coords[name] = new_coord
373-
indexes.append(self._get_indexer(self.dims.index(dim), indexer))
373+
indexes.append(self._get_indexer(self.get_axis(dim), indexer))
374374
# also sort related coords.
375375
_sort_related(name, dim, indexer, new_coords)
376376
return new_coords, tuple(indexes)
@@ -721,6 +721,23 @@ def __rich__(self) -> str:
721721
out.append(text)
722722
return Text.assemble(*out)
723723

724+
def get_axis(self: Self, dim: str) -> int:
725+
"""
726+
Get the axis corresponding to a Patch dimension. Raise error if not found.
727+
728+
Parameters
729+
----------
730+
self
731+
The Patch object.
732+
dim
733+
The dimension name.
734+
"""
735+
try:
736+
return self.dims.index(dim)
737+
except (ValueError, IndexError):
738+
msg = f"Patch has no dimension: {dim}. Its dimensions are: {self.dims}"
739+
raise CoordError(msg)
740+
724741
def __str__(self):
725742
return str(self.__rich__())
726743

dascore/core/coords.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import dascore as dc
2323
from dascore.compat import array, is_array
24-
from dascore.constants import dascore_styles
24+
from dascore.constants import _AGG_FUNCS, DIM_REDUCE_DOCS, dascore_styles
2525
from dascore.exceptions import CoordError, ParameterError
2626
from dascore.units import (
2727
Quantity,
@@ -889,6 +889,46 @@ def change_length(self, length: int) -> Self:
889889
msg = f"Coordinate type {self.__class__} does not implement change_length"
890890
raise NotImplementedError(msg)
891891

892+
@compose_docstring(dim_reduce=DIM_REDUCE_DOCS)
893+
def reduce_coord(self, dim_reduce="empty"):
894+
"""
895+
Get a reduced coordinate.
896+
897+
This is used to get a coordinate after aggregating along a dimension.
898+
899+
Parameters
900+
----------
901+
{dim_reduce}
902+
"""
903+
904+
def _maybe_handle_datatypes(func, data):
905+
"""Maybe handle the complexity of date times here."""
906+
try: # First try function directly
907+
out = func(data)
908+
# Fall back to floats and re-packing.
909+
except (TypeError, ValueError, np.core._exceptions.UFuncTypeError):
910+
float_data = dc.to_float(data)
911+
dfunc = dc.to_datetime64 if is_datetime64(data) else dc.to_timedelta64
912+
out = dfunc(func(float_data))
913+
return np.atleast_1d(out)
914+
915+
if dim_reduce == "empty":
916+
new_coord = self.update(shape=(1,), start=None, stop=None, data=None)
917+
elif dim_reduce == "squeeze":
918+
return None
919+
elif (func := _AGG_FUNCS.get(dim_reduce)) or callable(dim_reduce):
920+
func = dim_reduce if callable(dim_reduce) else func
921+
coord_data = self.data
922+
if dtype_time_like(coord_data):
923+
result = _maybe_handle_datatypes(func, coord_data)
924+
else:
925+
result = func(self.data)
926+
new_coord = self.update(data=result)
927+
else:
928+
msg = "dim_reduce must be 'empty', 'squeeze' or valid aggregator."
929+
raise ParameterError(msg)
930+
return new_coord
931+
892932

893933
class CoordPartial(BaseCoord):
894934
"""

dascore/core/patch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,6 @@ def T(self): # noqa: N802
253253
# This isnt a great name but keeps the numpy tradition.
254254
return self.transpose()
255255

256-
# --- Patch ufuncs
257-
258256
# --- basic patch functionality.
259257

260258
update = dascore.proc.update
@@ -283,6 +281,7 @@ def T(self): # noqa: N802
283281
make_broadcastable_to = dascore.proc.make_broadcastable_to
284282
apply_ufunc = dascore.utils.array.apply_ufunc
285283
get_patch_names = get_patch_names
284+
get_axis = dascore.proc.get_axis
286285

287286
def get_patch_name(self, *args, **kwargs) -> str:
288287
"""
@@ -396,6 +395,9 @@ def iresample(self, *args, **kwargs):
396395
velocity_to_strain_rate_edgeless = transform.velocity_to_strain_rate_edgeless
397396
dispersion_phase_shift = transform.dispersion_phase_shift
398397
tau_p = transform.tau_p
398+
hilbert = transform.hilbert
399+
envelope = transform.envelope
400+
phase_weighted_stack = transform.phase_weighted_stack
399401

400402
# --- Method Namespaces
401403
# Note: these can't be cached_property (from functools) or references

dascore/examples.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,26 @@ def forge_dts():
352352
return dc.spool(path)[0]
353353

354354

355+
@register_func(EXAMPLE_PATCHES, key="nd_patch")
356+
def nd_patch(dim_count=3, coord_lens=10):
357+
"""
358+
Make an N dimensional Patch.
359+
360+
Parameters
361+
----------
362+
dim_count
363+
The number of dimensions.
364+
coord_lens
365+
The length of the coordinates.
366+
"""
367+
ran = np.random.RandomState(42)
368+
dims = tuple(f"dim_{i + 1}" for i in range(dim_count))
369+
coords = {d: np.arange(coord_lens) for d in dims}
370+
shape = tuple(len(coords[d]) for d in dims)
371+
data = ran.randn(*shape)
372+
return dc.Patch(data=data, coords=coords, dims=dims)
373+
374+
355375
@register_func(EXAMPLE_PATCHES, key="ricker_moveout")
356376
def ricker_moveout(
357377
frequency=15,
@@ -392,6 +412,7 @@ def ricker_moveout(
392412

393413
def _ricker(time, delay):
394414
# shift time vector to account for different peak times.
415+
delay = 0 if not np.isfinite(delay) else delay
395416
new_time = time - delay
396417
f = frequency
397418
# get amplitude and exp term of ricker
@@ -407,7 +428,10 @@ def _ricker(time, delay):
407428
# iterate each distance channel and update data
408429
for ind, dist in enumerate(distance):
409430
dist_to_source = np.abs(dist - source_distance)
410-
time_delay = peak_time + (dist_to_source / velocity)
431+
with np.errstate(divide="ignore", invalid="ignore"):
432+
shift = dist_to_source / velocity
433+
actual_shift = shift if np.isfinite(shift) else 0
434+
time_delay = peak_time + actual_shift
411435
data[:, ind] = _ricker(time, time_delay)
412436

413437
coords = {"time": to_timedelta64(time), "distance": distance}
@@ -508,7 +532,7 @@ def delta_patch(
508532
# Get data with ones centered on selected dimensions.
509533
shape = patch.shape
510534
index = tuple(
511-
shape[patch.dims.index(dimension)] // 2 if dimension in used_dims else 0
535+
shape[patch.get_axis(dimension)] // 2 if dimension in used_dims else 0
512536
for dimension in patch.dims
513537
)
514538
data = np.zeros_like(patch.data)

dascore/proc/aggregate.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,19 @@
66

77
import numpy as np
88

9-
from dascore.constants import _AGG_FUNCS, PatchType
9+
from dascore.constants import _AGG_FUNCS, DIM_REDUCE_DOCS, PatchType
1010
from dascore.utils.array import _apply_aggregator
1111
from dascore.utils.docs import compose_docstring
1212
from dascore.utils.patch import patch_function
1313

14-
AGG_DOC_STR = """
14+
AGG_DOC_STR = f"""
1515
patch
1616
The input Patch.
1717
dim
1818
The dimension along which aggregations are to be performed.
1919
If None, apply aggregation to all dimensions sequentially.
2020
If a sequence, apply sequentially in order provided.
21-
dim_reduce
22-
How to reduce the dimensional coordinate associated with the
23-
aggregated axis. Can be the name of any valid aggregator, a callable,
24-
"empty" (the default) - which returns and empty coord, or "squeeze"
25-
which drops the coordinate. For dimensions with datetime or timedelta
26-
datatypes, if the operation fails it will automatically be applied
27-
to the coordinates converted to floats then the output converted back
28-
to the appropriate time type.
21+
{DIM_REDUCE_DOCS}
2922
"""
3023

3124
AGG_NOTES = """

dascore/proc/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def normalize(
301301
max - divide each sample by the maximum of the absolute value of the axis.
302302
bit - sample-by-sample normalization (-1/+1)
303303
"""
304-
axis = self.dims.index(dim)
304+
axis = self.get_axis(dim)
305305
data = self.data
306306
if norm in {"l1", "l2"}:
307307
order = int(norm[-1])
@@ -361,7 +361,7 @@ def standardize(
361361
standardized_distance = patch.standardize('distance')
362362
```
363363
"""
364-
axis = self.dims.index(dim)
364+
axis = self.get_axis(dim)
365365
data = self.data
366366
mean = np.mean(data, axis=axis, keepdims=True)
367367
std = np.std(data, axis=axis, keepdims=True)
@@ -413,7 +413,7 @@ def dropna(
413413
>>> # drop all distance labels that have all null values
414414
>>> out = patch.dropna("distance", how="all")
415415
"""
416-
axis = patch.dims.index(dim)
416+
axis = patch.get_axis(dim)
417417
func = np.any if how == "any" else np.all
418418
if include_inf:
419419
to_drop = ~np.isfinite(patch.data)

dascore/proc/coords.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dascore.core.coords import BaseCoord
1313
from dascore.exceptions import CoordError, ParameterError, PatchError
1414
from dascore.utils.docs import compose_docstring
15-
from dascore.utils.misc import get_parent_code_name
15+
from dascore.utils.misc import get_parent_code_name, iterate
1616
from dascore.utils.patch import patch_function
1717

1818

@@ -661,8 +661,11 @@ def squeeze(self: PatchType, dim=None) -> PatchType:
661661
If None, all length one dimensions are squeezed.
662662
"""
663663
coords = self.coords.squeeze(dim)
664-
axis = None if dim is None else self.coords.dims.index(dim)
665-
data = np.squeeze(self.data, axis=axis)
664+
if dim is None:
665+
axes = None
666+
else:
667+
axes = tuple(self.get_axis(x) for x in iterate(dim))
668+
data = np.squeeze(self.data, axis=axes)
666669
return self.new(data=data, coords=coords)
667670

668671

@@ -735,3 +738,24 @@ def add_distance_to(
735738
new_coords[f"{prefix}_distance"] = (dims, distance)
736739
out = patch.update_coords.func(patch, **new_coords)
737740
return out
741+
742+
743+
def get_axis(self: PatchType, dim: str) -> int:
744+
"""
745+
Get the axis corresponding to a Patch dimension. Raise error if not found.
746+
747+
Parameters
748+
----------
749+
self
750+
The Patch object.
751+
dim
752+
The dimension name.
753+
754+
Examples
755+
--------
756+
>>> import dascore as dc
757+
>>> patch = dc.get_example_patch()
758+
>>> axis = patch.get_axis("time")
759+
>>> assert axis == patch.get_axis("time")
760+
"""
761+
return self.coords.get_axis(dim)

0 commit comments

Comments
 (0)