@@ -598,3 +598,177 @@ def test_dist_coord_roll(self, random_patch):
598598 random_patch .coords .get_array ("distance" )[0 ]
599599 == rolled_patch .coords .get_array ("distance" )[value ]
600600 )
601+
602+
603+ class TestWhere :
604+ """Tests for the where method of Patch."""
605+
606+ def test_where_with_boolean_array (self , random_patch ):
607+ """Test where with a boolean array condition."""
608+ condition = random_patch .data > random_patch .data .mean ()
609+ result = random_patch .where (condition )
610+
611+ # Check that the result has the same shape
612+ assert result .shape == random_patch .shape
613+
614+ # Check that values where condition is True are preserved
615+ assert np .allclose (result .data [condition ], random_patch .data [condition ])
616+
617+ # Check that values where condition is False are NaN
618+ assert np .all (np .isnan (result .data [~ condition ]))
619+
620+ def test_where_with_other_value (self , random_patch ):
621+ """Test where with a replacement value."""
622+ condition = random_patch .data > 0
623+ other_value = - 999
624+ result = random_patch .where (condition , other = other_value )
625+
626+ # Check that values where condition is True are preserved
627+ assert np .allclose (result .data [condition ], random_patch .data [condition ])
628+
629+ # Check that values where condition is False are replaced
630+ assert np .all (result .data [~ condition ] == other_value )
631+
632+ def test_where_with_patch_condition (self , random_patch ):
633+ """Test where with another patch as condition."""
634+ boolean_data = (random_patch .data > random_patch .data .mean ()).astype (bool )
635+ boolean_patch = random_patch .new (data = boolean_data )
636+ result = random_patch .where (boolean_patch , other = 0 )
637+
638+ # Check that the result has the same shape
639+ assert result .shape == random_patch .shape
640+
641+ # Check that values where condition is True are preserved
642+ true_mask = boolean_data
643+ assert np .allclose (result .data [true_mask ], random_patch .data [true_mask ])
644+
645+ # Check that values where condition is False are 0
646+ false_mask = ~ boolean_data
647+ assert np .all (result .data [false_mask ] == 0 )
648+
649+ def test_where_preserves_metadata (self , random_patch ):
650+ """Test that where preserves patch metadata."""
651+ condition = random_patch .data > 0
652+ result = random_patch .where (condition )
653+
654+ # Check that coordinates are preserved
655+ assert result .coords == random_patch .coords
656+
657+ # Check that dimensions are preserved
658+ assert result .dims == random_patch .dims
659+
660+ # Check that attributes are preserved (except history)
661+ assert result .attrs .model_dump (
662+ exclude = {"history" }
663+ ) == random_patch .attrs .model_dump (exclude = {"history" })
664+
665+ def test_where_non_boolean_condition_raises (self , random_patch ):
666+ """Test that non-boolean condition raises ValueError."""
667+ non_boolean_condition = random_patch .data # Not boolean
668+
669+ with pytest .raises (ValueError , match = "Condition must be a boolean array" ):
670+ random_patch .where (non_boolean_condition )
671+
672+ def test_where_broadcasts_condition (self , random_patch ):
673+ """Test that condition can be broadcast to patch shape."""
674+ # Create a condition that can be broadcast to the full shape
675+ # Create a boolean array that matches the first dimension
676+ condition = np .ones (random_patch .shape [0 ], dtype = bool )
677+ condition [0 ] = False # Make first element False
678+
679+ # This should broadcast across the second dimension
680+ result = random_patch .where (condition [:, np .newaxis ], other = - 1 )
681+ assert result .shape == random_patch .shape
682+
683+ # Check that first row is all -1 and others are preserved
684+ assert np .all (result .data [0 , :] == - 1 )
685+ assert np .allclose (result .data [1 :, :], random_patch .data [1 :, :])
686+
687+ def test_where_with_broadcastable_patch_other (self , random_patch ):
688+ """Test where with a broadcastable patch as other parameter."""
689+ # Get the actual dimensions of the patch to create the right broadcasting
690+ broadcastable_patch1 = random_patch .mean ("distance" ).squeeze ()
691+ broadcastable_patch2 = random_patch .mean ("time" ).squeeze ()
692+
693+ # Create condition
694+ condition = random_patch .data > random_patch .data .mean ()
695+
696+ for castable in [broadcastable_patch1 , broadcastable_patch2 ]:
697+ result = random_patch .where (condition , other = castable )
698+ assert result .shape == random_patch .shape
699+ # Check that values where condition is True are preserved
700+ assert np .allclose (result .data [condition ], random_patch .data [condition ])
701+ # Check that values where condition is False come from the broadcasted other
702+ false_mask = ~ condition
703+ # The exact values depend on how the broadcasting worked
704+ assert np .all (
705+ ~ np .isnan (result .data [false_mask ])
706+ ) # Should have valid values
707+
708+ def test_where_with_misaligned_coords (self , random_patch ):
709+ """Test where with condition patch having misaligned coordinates."""
710+ # Create a subset of the original patch with partial overlap
711+ time_coord = random_patch .coords .get_array ("time" )
712+ # Take only part of the time coordinates to create a partial overlap
713+ partial_time = time_coord [10 :20 ] # Use a subset
714+
715+ # Create a boolean condition patch with partial time coordinates
716+ shifted_patch = random_patch .new (
717+ coords = {
718+ "time" : partial_time ,
719+ "distance" : random_patch .coords .get_array ("distance" ),
720+ },
721+ data = (random_patch .data [:, 10 :20 ] > random_patch .data [:, 10 :20 ].mean ()),
722+ )
723+
724+ # This should work with coordinate alignment (union)
725+ result = random_patch .where (shifted_patch , other = 0 )
726+
727+ # The result should have the union of coordinates and correct shape
728+ assert result is not None
729+ assert isinstance (result .data , np .ndarray )
730+ # After alignment, coords should have the overlapping range
731+ result_time = result .coords .get_array ("time" )
732+ partial_time_len = len (partial_time )
733+ assert len (result_time ) == partial_time_len
734+
735+ def test_where_both_cond_and_other_misaligned (self , random_patch ):
736+ """Test where with both condition and other patches having misaligned coords."""
737+ # Create condition patch with partial time overlap (first part)
738+ time_coord = random_patch .coords .get_array ("time" )
739+ cond_time = time_coord [5 :15 ] # indices 5-14
740+
741+ condition_patch = random_patch .new (
742+ coords = {
743+ "time" : cond_time ,
744+ "distance" : random_patch .coords .get_array ("distance" ),
745+ },
746+ data = (random_patch .data [:, 5 :15 ] > random_patch .data [:, 5 :15 ].mean ()),
747+ )
748+
749+ # Create other patch with different partial time overlap (shifted range)
750+ other_time = time_coord [8 :18 ] # indices 8-17, overlaps with condition
751+ other_patch = random_patch .new (
752+ coords = {
753+ "time" : other_time ,
754+ "distance" : random_patch .coords .get_array ("distance" ),
755+ },
756+ data = random_patch .data [:, 8 :18 ] * 0.5 , # Use different values
757+ )
758+
759+ # This should work with coordinate alignment handling both patches
760+ result = random_patch .where (condition_patch , other = other_patch )
761+
762+ # The result should have coordinates that are the intersection of all three
763+ assert result is not None
764+ assert isinstance (result .data , np .ndarray )
765+
766+ # After alignment, the time coordinate should be the intersection
767+ result_time = result .coords .get_array ("time" )
768+ # The intersection of [5:15], [8:18], and full range should be [8:15]
769+ expected_overlap_len = 7 # indices 8, 9, 10, 11, 12, 13, 14
770+ assert len (result_time ) == expected_overlap_len
771+
772+ # Verify the actual time values match the expected overlap
773+ expected_time_values = time_coord [8 :15 ]
774+ assert np .array_equal (result_time , expected_time_values )
0 commit comments