@@ -106,8 +106,7 @@ def _apply_docstring_templates(func):
106106 sum = 'sum' ,
107107 prod = 'product' ,
108108 amax = 'maximum' ,
109- amin = 'minimum' ,
110- mean = 'mean' )[func .__name__ ],
109+ amin = 'minimum' )[func .__name__ ],
111110 'identity_uint8' : _reduction_identity (func .__name__ , torch .tensor (0 , dtype = torch .uint8 )),
112111 'identity_int32' : _reduction_identity (func .__name__ , torch .tensor (0 , dtype = torch .int32 )),
113112 'identity_float32' : _reduction_identity (func .__name__ , torch .tensor (0 , dtype = torch .float32 )),
@@ -173,13 +172,6 @@ def _reduction_identity(op_name: str, input: Tensor):
173172 return torch .tensor (torch .inf , dtype = dtype , device = device )
174173 elif torch .is_signed (input ) or dtype == torch .uint8 :
175174 return torch .tensor (torch .iinfo (dtype ).max , dtype = dtype , device = device )
176- elif op_name == 'mean' :
177- # Strictly speaking, the identity value of the mean operation
178- # is the mean of the input. Since the mean value depends on
179- # the dim argument and it may be a non-scalar tensor, we
180- # consider the identity value of the mean operation ambiguous.
181- # Moreover, the mean value of empty input is undefined.
182- return None
183175 raise NotImplementedError (f'identity of { op_name } on { dtype } input' )
184176
185177
@@ -340,38 +332,3 @@ def amin(input: Tensor,
340332 return torch .amin (mask_input , dim_ , bool (keepdim )).to (dtype = dtype )
341333 else :
342334 raise ValueError (f'masked amin expects strided tensor (got { input .layout } tensor)' )
343-
344-
345- @_apply_docstring_templates
346- def mean (input : Tensor ,
347- dim : DimOrDims = None ,
348- * ,
349- keepdim : Optional [bool ] = False ,
350- dtype : Optional [DType ] = None ,
351- mask : Optional [Tensor ] = None ) -> Tensor :
352- """\
353- {reduction_signature}
354-
355- {reduction_descr}
356-
357- By definition, the identity value of a mean operation is the mean
358- value of the tensor. If all elements of the input tensor along given
359- dimension(s) :attr:`dim` are masked-out, the identity value of the
360- mean is undefined. Due to this ambiguity, the elements of output
361- tensor with strided layout, that correspond to fully masked-out
362- elements, have ``nan`` values.
363-
364- {reduction_args}
365-
366- {reduction_example}
367-
368- """
369- if dtype is None :
370- dtype = input .dtype
371- if input .layout == torch .strided :
372- inmask = _input_mask (input , mask = mask )
373- count = sum (inmask .new_ones (input .shape , dtype = torch .int64 ), dim , keepdim = keepdim , mask = inmask )
374- total = sum (input , dim , keepdim = keepdim , dtype = dtype , mask = inmask )
375- return total / count
376- else :
377- raise ValueError (f'masked sum expects strided tensor (got { input .layout } tensor)' )
0 commit comments