@@ -3754,10 +3754,33 @@ def _regrid(
37543754 dx = dx .rechunk (chunks )
37553755
37563756 # Define the regridded chunksizes
3757- regridded_chunks = tuple (
3758- (regridded_sizes [i ],) if i in regridded_sizes else c
3759- for i , c in enumerate (dx .chunks )
3760- )
3757+ regridded_chunks = [] # The 'chunks' parameter to `map_blocks`
3758+ new_axis = [] # The 'new_axis' parameter to `map_blocks`
3759+ n = 0
3760+ for i , c in enumerate (dx .chunks ):
3761+ if i in regridded_sizes :
3762+ sizes = regridded_sizes [i ]
3763+ n_sizes = len (sizes )
3764+ regridded_chunks .extend (sizes )
3765+ if n_sizes > 1 :
3766+ new_axis .extend (range (n + 1 , n + n_sizes ))
3767+ n += n_sizes - 1
3768+ else :
3769+ regridded_chunks .extend (c )
3770+
3771+ n += 1
3772+
3773+ if new_axis :
3774+ # Update the axis identifiers.
3775+ #
3776+ # This is necessary when regridding changes the number of
3777+ # data dimensions (e.g. as happens when regridding a mesh
3778+ # topology axis to separate lat and lon axes).
3779+ axes = list (self ._axes )
3780+ for i in new_axis :
3781+ axes .insert (i , new_axis_identifier (tuple (axes )))
3782+
3783+ self ._axes = tuple (axes )
37613784
37623785 # Set the output data type
37633786 if method in ("nearest_dtos" , "nearest_stod" ):
@@ -3790,6 +3813,7 @@ def _regrid(
37903813 weights_dst_mask = weights_dst_mask ,
37913814 ref_src_mask = src_mask ,
37923815 chunks = regridded_chunks ,
3816+ new_axis = new_axis ,
37933817 meta = np .array ((), dtype = dst_dtype ),
37943818 )
37953819
0 commit comments