55import xarray as xr
66from xarray .core .groupby import _consolidate_slices
77
8- from . import assert_identical , raises_regex
8+ from . import assert_equal , assert_identical , raises_regex
99
1010
1111def test_consolidate_slices ():
@@ -40,14 +40,14 @@ def test_multi_index_groupby_apply():
4040 {"foo" : (("x" , "y" ), np .random .randn (3 , 4 ))},
4141 {"x" : ["a" , "b" , "c" ], "y" : [1 , 2 , 3 , 4 ]},
4242 )
43- doubled = 2 * ds
44- group_doubled = (
43+ expected = 2 * ds
44+ actual = (
4545 ds .stack (space = ["x" , "y" ])
4646 .groupby ("space" )
4747 .apply (lambda x : 2 * x )
4848 .unstack ("space" )
4949 )
50- assert doubled . equals ( group_doubled )
50+ assert_equal ( expected , actual )
5151
5252
5353def test_multi_index_groupby_sum ():
@@ -58,7 +58,7 @@ def test_multi_index_groupby_sum():
5858 )
5959 expected = ds .sum ("z" )
6060 actual = ds .stack (space = ["x" , "y" ]).groupby ("space" ).sum ("z" ).unstack ("space" )
61- assert expected . equals ( actual )
61+ assert_equal ( expected , actual )
6262
6363
6464def test_groupby_da_datetime ():
@@ -78,15 +78,15 @@ def test_groupby_da_datetime():
7878 expected = xr .DataArray (
7979 [3 , 7 ], coords = dict (reference_date = reference_dates ), dims = "reference_date"
8080 )
81- assert actual . equals (expected )
81+ assert_equal (expected , actual )
8282
8383
8484def test_groupby_duplicate_coordinate_labels ():
8585 # fix for http://stackoverflow.com/questions/38065129
8686 array = xr .DataArray ([1 , 2 , 3 ], [("x" , [1 , 1 , 2 ])])
8787 expected = xr .DataArray ([3 , 3 ], [("x" , [1 , 2 ])])
8888 actual = array .groupby ("x" ).sum ()
89- assert expected . equals ( actual )
89+ assert_equal ( expected , actual )
9090
9191
9292def test_groupby_input_mutation ():
@@ -255,6 +255,59 @@ def test_groupby_repr_datetime(obj):
255255 assert actual == expected
256256
257257
258+ def test_groupby_drops_nans ():
259+ # GH2383
260+ # nan in 2D data variable (requires stacking)
261+ ds = xr .Dataset (
262+ {
263+ "variable" : (("lat" , "lon" , "time" ), np .arange (60.0 ).reshape ((4 , 3 , 5 ))),
264+ "id" : (("lat" , "lon" ), np .arange (12.0 ).reshape ((4 , 3 ))),
265+ },
266+ coords = {"lat" : np .arange (4 ), "lon" : np .arange (3 ), "time" : np .arange (5 )},
267+ )
268+
269+ ds ["id" ].values [0 , 0 ] = np .nan
270+ ds ["id" ].values [3 , 0 ] = np .nan
271+ ds ["id" ].values [- 1 , - 1 ] = np .nan
272+
273+ grouped = ds .groupby (ds .id )
274+
275+ # non reduction operation
276+ expected = ds .copy ()
277+ expected .variable .values [0 , 0 , :] = np .nan
278+ expected .variable .values [- 1 , - 1 , :] = np .nan
279+ expected .variable .values [3 , 0 , :] = np .nan
280+ actual = grouped .apply (lambda x : x ).transpose (* ds .variable .dims )
281+ assert_identical (actual , expected )
282+
283+ # reduction along grouped dimension
284+ actual = grouped .mean ()
285+ stacked = ds .stack ({"xy" : ["lat" , "lon" ]})
286+ expected = (
287+ stacked .variable .where (stacked .id .notnull ()).rename ({"xy" : "id" }).to_dataset ()
288+ )
289+ expected ["id" ] = stacked .id .values
290+ assert_identical (actual , expected .dropna ("id" ).transpose (* actual .dims ))
291+
292+ # reduction operation along a different dimension
293+ actual = grouped .mean ("time" )
294+ expected = ds .mean ("time" ).where (ds .id .notnull ())
295+ assert_identical (actual , expected )
296+
297+ # NaN in non-dimensional coordinate
298+ array = xr .DataArray ([1 , 2 , 3 ], [("x" , [1 , 2 , 3 ])])
299+ array ["x1" ] = ("x" , [1 , 1 , np .nan ])
300+ expected = xr .DataArray (3 , [("x1" , [1 ])])
301+ actual = array .groupby ("x1" ).sum ()
302+ assert_equal (expected , actual )
303+
304+ # test for repeated coordinate labels
305+ array = xr .DataArray ([0 , 1 , 2 , 4 , 3 , 4 ], [("x" , [np .nan , 1 , 1 , np .nan , 2 , np .nan ])])
306+ expected = xr .DataArray ([3 , 3 ], [("x" , [1 , 2 ])])
307+ actual = array .groupby ("x" ).sum ()
308+ assert_equal (expected , actual )
309+
310+
258311def test_groupby_grouping_errors ():
259312 dataset = xr .Dataset ({"foo" : ("x" , [1 , 1 , 1 ])}, {"x" : [1 , 2 , 3 ]})
260313 with raises_regex (ValueError , "None of the data falls within bins with edges" ):
0 commit comments