Skip to content

Commit 4242dc7

Browse files
authored
allow chunking to drop non-dimensional coords (#532)
* allow chunking to drop non-dimensional coords * rabbit review
1 parent 0151905 commit 4242dc7

File tree

4 files changed

+85
-5
lines changed

4 files changed

+85
-5
lines changed

dascore/utils/coordmanager.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ def merge_coord_managers(
1818
coord_managers: Sequence[dc.CoordManager],
1919
dim: str,
2020
snap_tolerance: float | None = None,
21+
drop_conflicting: bool = False,
2122
) -> dc.CoordManager:
2223
"""
23-
Merger coordinate managers along a specified dimension.
24+
Merge coordinate managers along a specified dimension.
2425
2526
Parameters
2627
----------
@@ -34,6 +35,9 @@ def merge_coord_managers(
3435
start/end to be joined together. If they don't meet this requirement
3536
an [CoordMergeError](`dascore.exceptions.CoordMergeError`) is raised.
3637
If None, no checks are performed.
38+
drop_conflicting
39+
If True, drop conflicting (non-dimensional) coordinates, otherwise
40+
raise an exception if they occur.
3741
"""
3842

3943
def _get_dims(managers):
@@ -68,9 +72,15 @@ def _get_non_merge_coords(managers, non_merger_names):
6872
dims = managers[0].dim_map[coord_name]
6973
out[coord_name] = (dims, first)
7074
continue
75+
# Simply skip conflicting
76+
elif drop_conflicting:
77+
# These are non dimensional coords
78+
if not any(coord_name in x.dims for x in managers):
79+
continue
7180
msg = (
7281
f"Non merging coordinates {coord_name} are not equal. "
73-
"Coordinate managers cannot be merged."
82+
"Coordinate managers cannot be merged. Try using "
83+
"spool.chunk with conflict='drop'."
7484
)
7585
raise CoordMergeError(msg)
7686
return out

dascore/utils/patch.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,11 @@ def _maybe_step(df, dim):
392392
return get_middle_value(col)
393393
return None
394394

395-
def _get_new_coord(df, merge_dim, coords):
395+
def _get_new_coord(df, merge_dim, coords, drop_conflicting=False):
396396
"""Get new coordinates, also validate anticipated sampling."""
397-
new_coord = merge_coord_managers(coords, dim=merge_dim)
397+
new_coord = merge_coord_managers(
398+
coords, dim=merge_dim, drop_conflicting=drop_conflicting
399+
)
398400
expected_step = _maybe_step(df, merge_dim)
399401
if not pd.isnull(expected_step):
400402
new_coord = new_coord.snap(merge_dim)[0]
@@ -416,7 +418,10 @@ def _get_new_coord(df, merge_dim, coords):
416418
coords = [x.coords for x in patches]
417419
attrs = [x.attrs for x in patches]
418420
new_data = np.concatenate(datas, axis=axis)
419-
new_coord = _get_new_coord(df, merge_dim, coords)
421+
# Determine if conflicting non-dimensional coords should be dropped.
422+
conf = merge_kwargs.get("conflicts", None)
423+
drop_conf_coords = True if conf in {"drop", "keep_first"} else False
424+
new_coord = _get_new_coord(df, merge_dim, coords, drop_conf_coords)
420425
coord = new_coord.coord_map[merge_dim] if merge_dim in dims else None
421426
new_attrs = combine_patch_attrs(attrs, merge_dim, coord=coord, **merge_kwargs)
422427
patch = dc.Patch(data=new_data, coords=new_coord, attrs=new_attrs, dims=dims)

tests/test_core/test_patch_chunk.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,21 @@ def adjacent_spool_different_attrs(self, adjacent_spool_no_overlap):
270270
# since
271271
return dc.spool(out)
272272

273+
@pytest.fixture(scope="class")
274+
def patches_conflicting_private_coord(self, random_patch):
275+
"""Create two patches that have conflicting private coords."""
276+
dist_ax = random_patch.dims.index("distance")
277+
rand = np.random.RandomState(42)
278+
c1 = rand.random(random_patch.shape[dist_ax])
279+
c2 = rand.random(c1.shape)
280+
281+
time = random_patch.get_coord("time")
282+
p1 = random_patch.update_coords(_bad_coord=("distance", c1))
283+
p2 = random_patch.update_coords(
284+
_bad_coord=("distance", c2), time=time + time.coord_range()
285+
)
286+
return p1, p2
287+
273288
def test_merge_unequal_other(self, distance_adjacent):
274289
"""When distance values are not equal time shouldn't be merge-able."""
275290
with pytest.raises(CoordMergeError):
@@ -462,3 +477,24 @@ def test_chunk_patches_with_non_coord(self, random_patch):
462477
chunked = spool.chunk(time=None)
463478
# Since the time dims are NaN, this can't work.
464479
assert not len(chunked)
480+
481+
def test_merge_with_conflicting_private_coords(
482+
self,
483+
patches_conflicting_private_coord,
484+
):
485+
"""
486+
Private coords that conflict should be dropped and not block merge
487+
when conflict="drop".
488+
489+
Otherwise they should raise.
490+
"""
491+
p1, p2 = patches_conflicting_private_coord
492+
merged_spool = dc.spool([p1, p2]).chunk(time=None, conflict="drop")
493+
merge_patch = merged_spool[0]
494+
assert len(merged_spool) == 1
495+
# Since the private coords conflicted, they should have been dropped.
496+
coord_names = list(merge_patch.coords.coord_map)
497+
assert not any([x.startswith("_") for x in coord_names])
498+
# Without conflict drop this should raise.
499+
with pytest.raises(CoordMergeError, match="conflict"):
500+
dc.spool([p1, p2]).chunk(time=None)[0]

tests/test_utils/test_coordmanager_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
import numpy as np
78
import pytest
89

910
from dascore.core.coords import CoordArray, CoordMonotonicArray, CoordRange
@@ -23,6 +24,21 @@ def _get_offset_coord_manager(self, cm, from_max=True, **kwargs):
2324
new, _ = cm.update_from_attrs({attr_name: start + value})
2425
return new
2526

27+
@pytest.fixture(scope="class")
28+
def conflicting_non_dim_coords(self, cm_basic):
29+
"""Get two coord managers with conflicting non-dimensional coordinates."""
30+
dist_ax = cm_basic.dims.index("distance")
31+
rand = np.random.RandomState(42)
32+
c1 = rand.random(cm_basic.shape[dist_ax])
33+
c2 = rand.random(c1.shape)
34+
35+
time = cm_basic.get_coord("time")
36+
cm1 = cm_basic.update_coords(_bad_coord=("distance", c1))
37+
cm2 = cm1.update_coords(
38+
_bad_coord=("distance", c2), time=time + time.coord_range()
39+
)
40+
return cm1, cm2
41+
2642
def test_merge_simple(self, cm_basic):
2743
"""Ensure we can merge simple, contiguous, coordinates together."""
2844
cm1 = cm_basic
@@ -136,3 +152,16 @@ def test_slightly_different_dt(self, cm_dt_small_diff):
136152
cm = cm_dt_small_diff
137153
coord = cm.coord_map["time"]
138154
assert coord.sorted
155+
156+
def test_conflicting_non_dimensional_coords(self, conflicting_non_dim_coords):
157+
"""
158+
Ensure conflicting non-dimensional coords can be merged if drop_conflict=True,
159+
Otherwise raise.
160+
"""
161+
c1, c2 = conflicting_non_dim_coords
162+
163+
out = merge_coord_managers([c1, c2], dim="time", drop_conflicting=True)
164+
assert not any([x.startswith("_") for x in out.coord_map])
165+
166+
with pytest.raises(CoordMergeError, match="cannot be merged"):
167+
merge_coord_managers([c1, c2], dim="time", drop_conflicting=False)

0 commit comments

Comments
 (0)