-
Notifications
You must be signed in to change notification settings - Fork 28
Patch.where #550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Patch.where #550
Conversation
WalkthroughAdds Changes
Pre-merge checks and finishing touches✅ Passed checks (3 passed)
✨ Finishing touches
🧪 Generate unit tests
📜 Recent review detailsConfiguration used: CodeRabbit UI Review profile: CHILL Plan: Pro 📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used🧬 Code graph analysis (1)dascore/proc/basic.py (2)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (16)
🔇 Additional comments (4)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (4)
tests/test_proc/test_basic.py (3)
603-716: Good coverage for core scenarios; add a misaligned-coords case.Consider a test where
cond(orother) is a Patch with shifted/partially overlapping coords so alignment (union) is required; this will catch shape/coord sync bugs.I can draft a test using a time-shifted boolean Patch if helpful.
649-664: Use model_dump to compare attrs (avoid relying on dict).More robust and future-proof than iterating
__dict__.Apply:
- # Check that attributes are preserved (except possibly history) - for key in random_patch.attrs.__dict__: - if key != "history": - assert getattr(result.attrs, key) == getattr(random_patch.attrs, key) + # Check that attributes are preserved (except history) + assert ( + result.attrs.model_dump(exclude={"history"}) + == random_patch.attrs.model_dump(exclude={"history"}) + )
702-716: Also assert replacements match broadcastedotherPatch.Strengthen the test by checking the False-mask against
otherafter broadcasting.Apply:
for castable in [broadcastable_patch1, broadcastable_patch2]: result = random_patch.where(condition, other=castable) assert result.shape == random_patch.shape # Check that values where condition is True are preserved assert np.allclose(result.data[condition], random_patch.data[condition]) + # And where False they equal broadcasted `other` + bcast_other = castable.make_broadcastable_to(random_patch) + false_mask = ~condition + assert np.allclose(result.data[false_mask], bcast_other.data[false_mask])dascore/proc/basic.py (1)
653-669: Fix docstring example typos and clarity.Minor syntax error and variable naming issue in the examples.
Apply:
- >>> # Use another patch as condition - >>> other = patch.data.mean()).astype(bool) - >>> boolean_patch = patch.new(data=(patch.data > other)) + >>> # Use another patch as condition + >>> threshold = patch.data.mean() + >>> boolean_patch = patch.new(data=(patch.data > threshold))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
dascore/core/patch.py(1 hunks)dascore/proc/basic.py(2 hunks)tests/test_proc/test_basic.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
dascore/proc/basic.py (2)
dascore/utils/patch.py (2)
align_patch_coords(820-870)patch_function(180-286)dascore/core/patch.py (2)
Patch(28-443)data(237-239)
tests/test_proc/test_basic.py (2)
dascore/core/patch.py (6)
data(237-239)shape(242-244)coords(232-234)dims(212-214)attrs(227-229)dtype(252-254)dascore/proc/basic.py (1)
where(630-689)
dascore/core/patch.py (1)
dascore/proc/basic.py (1)
where(630-689)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (17)
- GitHub Check: test_code (windows-latest, 3.11)
- GitHub Check: test_code (windows-latest, 3.12)
- GitHub Check: test_code (macos-latest, 3.13)
- GitHub Check: test_code (ubuntu-latest, 3.13)
- GitHub Check: test_code (windows-latest, 3.10)
- GitHub Check: test_code (windows-latest, 3.13)
- GitHub Check: test_code (macos-latest, 3.12)
- GitHub Check: test_code (macos-latest, 3.10)
- GitHub Check: test_code (ubuntu-latest, 3.12)
- GitHub Check: test_code (ubuntu-latest, 3.11)
- GitHub Check: test_code (ubuntu-latest, 3.10)
- GitHub Check: test_code_min_deps (windows-latest, 3.12)
- GitHub Check: test_code_min_deps (windows-latest, 3.13)
- GitHub Check: test_code_min_deps (ubuntu-latest, 3.12)
- GitHub Check: test_code_min_deps (ubuntu-latest, 3.13)
- GitHub Check: test_code_min_deps (macos-latest, 3.13)
- GitHub Check: test_code_min_deps (macos-latest, 3.12)
🔇 Additional comments (2)
dascore/proc/basic.py (1)
22-25: Import ofalign_patch_coordsis appropriate.dascore/core/patch.py (1)
358-358: ExposePatch.where— package-level re-export confirmed.
dascore/proc/init.py doesfrom .basic import *anddascore/proc/basic.pydefinesdef where(...)(around line 630), so thewhere = dascore.proc.wherealias is valid.
dascore/proc/basic.py
Outdated
| def _get_array(possible_array): | ||
| """Get array from patch or array.""" | ||
| if isinstance(possible_array, dc.Patch): | ||
| _, aligned = align_patch_coords(patch, possible_array) | ||
| out = aligned.data | ||
| else: | ||
| out = possible_array | ||
| return np.asarray(out) | ||
|
|
||
| cond_array = _get_array(cond) | ||
| # Ensure condition is boolean | ||
| if not np.issubdtype(cond_array.dtype, np.bool_): | ||
| msg = "Condition must be a boolean array or patch with boolean data" | ||
| raise ValueError(msg) | ||
|
|
||
| other_array = _get_array(other) | ||
| # Use numpy.where to apply condition | ||
| new_data = np.where(cond_array, patch.data, other_array) | ||
| return patch.new(data=new_data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❓ Verification inconclusive
Alignment bug: ignoring the aligned base patch can mis-shape arrays.
When cond or other is a Patch with different/partially overlapping coords, align_patch_coords may change the base Patch’s shape (union). Current code discards the aligned base, risking shape mismatch in np.where.
Apply:
@patch_function()
def where(
patch: PatchType, cond: ArrayLike | PatchType, other: Any | PatchType = np.nan
) -> PatchType:
@@
- def _get_array(possible_array):
- """Get array from patch or array."""
- if isinstance(possible_array, dc.Patch):
- _, aligned = align_patch_coords(patch, possible_array)
- out = aligned.data
- else:
- out = possible_array
- return np.asarray(out)
-
- cond_array = _get_array(cond)
+ base = patch
+
+ def _as_array(x):
+ return np.asarray(x)
+
+ # Align cond to base if it's a Patch
+ cond_is_patch = isinstance(cond, dc.Patch)
+ other_is_patch = isinstance(other, dc.Patch)
+
+ if cond_is_patch:
+ base, cond = align_patch_coords(base, cond)
+ cond_array = cond.data
+ else:
+ cond_array = _as_array(cond)
@@
- other_array = _get_array(other)
+ # Align other to (possibly updated) base if it's a Patch
+ if other_is_patch:
+ base, other = align_patch_coords(base, other)
+ other_array = other.data
+ else:
+ other_array = _as_array(other)
+
+ # If base changed after aligning `other`, re-align cond if it was a Patch.
+ if cond_is_patch and cond_array.shape != base.data.shape:
+ base, cond = align_patch_coords(base, cond)
+ cond_array = cond.data
@@
- new_data = np.where(cond_array, patch.data, other_array)
- return patch.new(data=new_data)
+ new_data = np.where(cond_array, base.data, other_array)
+ return base.new(data=new_data)Add a test where cond is a boolean Patch with a shifted coord (e.g., time offset by one step) to ensure shapes/coords are union-aligned and where succeeds.
Alignment bug — update base when aligning cond/other Patches
align_patch_coords can return an updated base with unioned coords; the current code discards that base, which can cause shape mismatches and np.where errors.
@patch_function()
def where(
patch: PatchType, cond: ArrayLike | PatchType, other: Any | PatchType = np.nan
) -> PatchType:
@@
- def _get_array(possible_array):
- """Get array from patch or array."""
- if isinstance(possible_array, dc.Patch):
- _, aligned = align_patch_coords(patch, possible_array)
- out = aligned.data
- else:
- out = possible_array
- return np.asarray(out)
-
- cond_array = _get_array(cond)
+ base = patch
+
+ def _as_array(x):
+ return np.asarray(x)
+
+ # Align cond to base if it's a Patch
+ cond_is_patch = isinstance(cond, dc.Patch)
+ other_is_patch = isinstance(other, dc.Patch)
+
+ if cond_is_patch:
+ base, cond = align_patch_coords(base, cond)
+ cond_array = cond.data
+ else:
+ cond_array = _as_array(cond)
@@
- other_array = _get_array(other)
+ # Align other to (possibly updated) base if it's a Patch
+ if other_is_patch:
+ base, other = align_patch_coords(base, other)
+ other_array = other.data
+ else:
+ other_array = _as_array(other)
+
+ # If base changed after aligning `other`, re-align cond if it was a Patch.
+ if cond_is_patch and cond_array.shape != base.data.shape:
+ base, cond = align_patch_coords(base, cond)
+ cond_array = cond.data
@@
- new_data = np.where(cond_array, patch.data, other_array)
- return patch.new(data=new_data)
+ new_data = np.where(cond_array, base.data, other_array)
+ return base.new(data=new_data)Add a test where cond is a boolean Patch with a shifted coord (e.g., time offset by one step) to ensure union-alignment and that where succeeds.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _get_array(possible_array): | |
| """Get array from patch or array.""" | |
| if isinstance(possible_array, dc.Patch): | |
| _, aligned = align_patch_coords(patch, possible_array) | |
| out = aligned.data | |
| else: | |
| out = possible_array | |
| return np.asarray(out) | |
| cond_array = _get_array(cond) | |
| # Ensure condition is boolean | |
| if not np.issubdtype(cond_array.dtype, np.bool_): | |
| msg = "Condition must be a boolean array or patch with boolean data" | |
| raise ValueError(msg) | |
| other_array = _get_array(other) | |
| # Use numpy.where to apply condition | |
| new_data = np.where(cond_array, patch.data, other_array) | |
| return patch.new(data=new_data) | |
| base = patch | |
| def _as_array(x): | |
| return np.asarray(x) | |
| # Align cond to base if it's a Patch | |
| cond_is_patch = isinstance(cond, dc.Patch) | |
| other_is_patch = isinstance(other, dc.Patch) | |
| if cond_is_patch: | |
| base, cond = align_patch_coords(base, cond) | |
| cond_array = cond.data | |
| else: | |
| cond_array = _as_array(cond) | |
| # Ensure condition is boolean | |
| if not np.issubdtype(cond_array.dtype, np.bool_): | |
| msg = "Condition must be a boolean array or patch with boolean data" | |
| raise ValueError(msg) | |
| # Align other to (possibly updated) base if it's a Patch | |
| if other_is_patch: | |
| base, other = align_patch_coords(base, other) | |
| other_array = other.data | |
| else: | |
| other_array = _as_array(other) | |
| # If base changed after aligning `other`, re-align cond if it was a Patch. | |
| if cond_is_patch and cond_array.shape != base.data.shape: | |
| base, cond = align_patch_coords(base, cond) | |
| cond_array = cond.data | |
| # Use numpy.where to apply condition | |
| new_data = np.where(cond_array, base.data, other_array) | |
| return base.new(data=new_data) |
|
Cool! it's very useful! |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #550 +/- ##
=======================================
Coverage 99.92% 99.92%
=======================================
Files 126 126
Lines 10455 10472 +17
=======================================
+ Hits 10447 10464 +17
Misses 8 8
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Description
This PR adds the Patch.where method, which behaves like numpy or xarray's where.
Checklist
I have (if applicable):
Summary by CodeRabbit
New Features
Tests