Skip to content

Commit 20f08d2

Browse files
Natalia Gimelsheinfacebook-github-bot
authored andcommitted
Revert D31838513: Strided masked reduction: mean.
Test Plan: revert-hammer Differential Revision: D31838513 Original commit changeset: 54b99ccf9821 fbshipit-source-id: 5480e8482c8770b41579ee085e158572b659c1f5
1 parent 2578de4 commit 20f08d2

File tree

2 files changed

+1
-62
lines changed

2 files changed

+1
-62
lines changed

torch/_masked/__init__.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ def _apply_docstring_templates(func):
106106
sum='sum',
107107
prod='product',
108108
amax='maximum',
109-
amin='minimum',
110-
mean='mean')[func.__name__],
109+
amin='minimum')[func.__name__],
111110
'identity_uint8': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.uint8)),
112111
'identity_int32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.int32)),
113112
'identity_float32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.float32)),
@@ -173,13 +172,6 @@ def _reduction_identity(op_name: str, input: Tensor):
173172
return torch.tensor(torch.inf, dtype=dtype, device=device)
174173
elif torch.is_signed(input) or dtype == torch.uint8:
175174
return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
176-
elif op_name == 'mean':
177-
# Strictly speaking, the identity value of the mean operation
178-
# is the mean of the input. Since the mean value depends on
179-
# the dim argument and it may be a non-scalar tensor, we
180-
# consider the identity value of the mean operation ambiguous.
181-
# Moreover, the mean value of empty input is undefined.
182-
return None
183175
raise NotImplementedError(f'identity of {op_name} on {dtype} input')
184176

185177

@@ -340,38 +332,3 @@ def amin(input: Tensor,
340332
return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
341333
else:
342334
raise ValueError(f'masked amin expects strided tensor (got {input.layout} tensor)')
343-
344-
345-
@_apply_docstring_templates
346-
def mean(input: Tensor,
347-
dim: DimOrDims = None,
348-
*,
349-
keepdim: Optional[bool] = False,
350-
dtype: Optional[DType] = None,
351-
mask: Optional[Tensor] = None) -> Tensor:
352-
"""\
353-
{reduction_signature}
354-
355-
{reduction_descr}
356-
357-
By definition, the identity value of a mean operation is the mean
358-
value of the tensor. If all elements of the input tensor along given
359-
dimension(s) :attr:`dim` are masked-out, the identity value of the
360-
mean is undefined. Due to this ambiguity, the elements of output
361-
tensor with strided layout, that correspond to fully masked-out
362-
elements, have ``nan`` values.
363-
364-
{reduction_args}
365-
366-
{reduction_example}
367-
368-
"""
369-
if dtype is None:
370-
dtype = input.dtype
371-
if input.layout == torch.strided:
372-
inmask = _input_mask(input, mask=mask)
373-
count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=keepdim, mask=inmask)
374-
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
375-
return total / count
376-
else:
377-
raise ValueError(f'masked sum expects strided tensor (got {input.layout} tensor)')

torch/testing/_internal/common_methods_invocations.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11070,24 +11070,6 @@ def ref_pairwise_distance(input1, input2):
1107011070
sample_inputs_func=sample_inputs_masked_reduction,
1107111071
gradcheck_wrapper=gradcheck_wrapper_masked_operation
1107211072
),
11073-
ReductionOpInfo(
11074-
'_masked.mean',
11075-
ref=reference_reduction_numpy(np.mean) if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
11076-
method_variant=None,
11077-
nan_policy='propagate',
11078-
supports_out=False,
11079-
promotes_int_to_float=True,
11080-
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
11081-
skips=(
11082-
# FIXME: sum reduces all dimensions when dim=[]
11083-
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
11084-
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
11085-
# RuntimeError: undefined value tensor
11086-
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
11087-
),
11088-
sample_inputs_func=sample_inputs_masked_reduction,
11089-
gradcheck_wrapper=gradcheck_wrapper_masked_operation
11090-
),
1109111073
OpInfo(
1109211074
"nn.functional.nll_loss",
1109311075
ref=_NOTHING,

0 commit comments

Comments
 (0)