Skip to content
44 changes: 42 additions & 2 deletions torch/_masked/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def _apply_docstring_templates(func):
"""Decorator that applies docstring templates to function docstring
and returns the function instance.
"""

docstring_templates = dict(
reduction_signature='''\
{function_name}(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor''',
Expand Down Expand Up @@ -106,7 +105,8 @@ def _apply_docstring_templates(func):
sum='sum',
prod='product',
amax='maximum',
amin='minimum')[func.__name__],
amin='minimum',
mean='mean')[func.__name__],
'identity_uint8': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.uint8)),
'identity_int32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.int32)),
'identity_float32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.float32)),
Expand Down Expand Up @@ -172,6 +172,13 @@ def _reduction_identity(op_name: str, input: Tensor):
return torch.tensor(torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8:
return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
elif op_name == 'mean':
# Strictly speaking, the identity value of the mean operation
# is the mean of the input. Since the mean value depends on
# the dim argument and it may be a non-scalar tensor, we
# consider the identity value of the mean operation ambiguous.
# Moreover, the mean value of empty input is undefined.
return None
raise NotImplementedError(f'identity of {op_name} on {dtype} input')


Expand Down Expand Up @@ -332,3 +339,36 @@ def amin(input: Tensor,
return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
else:
raise ValueError(f'masked amin expects strided tensor (got {input.layout} tensor)')


@_apply_docstring_templates
def mean(input: Tensor,
         dim: DimOrDims = None,
         *,
         keepdim: Optional[bool] = False,
         dtype: Optional[DType] = None,
         mask: Optional[Tensor] = None) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.

{reduction_args}

{reduction_example}"""
    # Default the output dtype to the input dtype, mirroring the other
    # masked reductions in this module.
    if dtype is None:
        dtype = input.dtype
    if input.layout == torch.strided:
        inmask = _input_mask(input, mask=mask)
        # Mean = (masked sum of values) / (masked count of elements).
        # `count` sums a tensor of ones under the same mask, so fully
        # masked-out output positions get count == 0 and total == 0,
        # and 0/0 yields the documented ``nan``.
        count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=keepdim, mask=inmask)
        total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count
    else:
        # BUG FIX: the message previously said 'masked sum'; report the
        # correct operation name, consistent with amin/amax above.
        raise ValueError(f'masked mean expects strided tensor (got {input.layout} tensor)')
18 changes: 18 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11851,6 +11851,24 @@ def ref_pairwise_distance(input1, input2):
sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation
),
# OpInfo entry registering `_masked.mean` with the reduction test suite.
ReductionOpInfo(
'_masked.mean',
# NOTE(review): reference is gated on numpy >= 1.20.2 — presumably the
# minimum version where np.mean's masking support works for this
# reference wrapper; confirm against reference_reduction_numpy.
ref=reference_reduction_numpy(np.mean) if np.lib.NumpyVersion(np.__version__) >= '1.20.2' else None,
method_variant=None,
# NaNs in the input propagate to the output (no nan-skipping variant).
nan_policy='propagate',
supports_out=False,
# Integer inputs produce floating-point means.
promotes_int_to_float=True,
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
skips=(
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
# RuntimeError: undefined value tensor
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation
),
OpInfo(
"nn.functional.ctc_loss",
ref=_NOTHING,
Expand Down