Skip to content

Commit f5617b0

Browse files
kshitij12345 authored and facebook-github-bot committed
[testing] Add Opinfo for torch.frac and minor fixes (#52660)
Summary: Reference : #42515 Pull Request resolved: #52660 Reviewed By: ailzhang Differential Revision: D26618151 Pulled By: mruberry fbshipit-source-id: cf0df38e46f44d3afff6e0015af5a840c661aa0e
1 parent 312b297 commit f5617b0

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

test/test_torch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7030,7 +7030,6 @@ def tmp(dtype, device):
70307030
('eig', 'with_eigvec', _new_t((10, 10)), lambda t, d: [True],
70317031
1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma, onlyOnCPUAndCUDA]),
70327032
('sign', '', _small_3d, lambda t, d: []),
7033-
('frac', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
70347033
]
70357034

70367035
# Creates and decorates a generic test and adds it to the class.

test/test_unary_ufuncs.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,10 @@
5252
-math.pi + .00001, math.pi - .00001,
5353
-math.pi, math.pi,
5454
-math.pi - .00001, math.pi + .00001)
55-
_large_float_vals = (-501, 501,
56-
-1001.2, 1001.2,
57-
-13437.7, 13437.7,
58-
-4988429.2, 4988429.2,
59-
-1e20, 1e20)
55+
_large_float16_vals = (-501, 501,
56+
-1001.2, 1001.2,
57+
-13437.7, 13437.7)
58+
_large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20)
6059
_float_extremals = (float('inf'), float('-inf'), float('nan'))
6160
_medium_length = 812
6261
_large_size = (1029, 917)
@@ -146,7 +145,11 @@ def generate_numeric_tensors_hard(device, dtype, *,
146145
return ()
147146

148147
if dtype.is_floating_point:
149-
vals = _large_float_vals
148+
if dtype is torch.float16:
149+
# float16 has smaller range.
150+
vals = _large_float16_vals
151+
else:
152+
vals = _large_float_vals
150153
elif dtype.is_complex:
151154
vals = tuple(complex(x, y) for x, y in chain(product(_large_float_vals, _large_float_vals),
152155
product(_float_vals, _large_float_vals),
@@ -1726,7 +1729,6 @@ def _medium_2d(dtype, device):
17261729

17271730
# TODO: all these should be replaced with OpInfos
17281731
torch_op_tests = [
1729-
_TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)),
17301732
_TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma',
17311733
refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
17321734
ref_backend='scipy'),

torch/testing/_internal/common_methods_invocations.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,6 +1730,15 @@ def sample_inputs_masked_select(op_info, device, dtype, requires_grad):
17301730
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half),
17311731
sample_inputs_func=sample_inputs_diag,
17321732
test_inplace_grad=False),
1733+
UnaryUfuncInfo('frac',
1734+
ref=lambda x: np.modf(x)[0],
1735+
dtypes=floating_types_and(torch.bfloat16, torch.float16),
1736+
dtypesIfCPU=floating_types_and(torch.bfloat16, torch.float16),
1737+
dtypesIfCUDA=floating_types_and(torch.float16),
1738+
assert_autodiffed=True,
1739+
# Reference for disabling extremals
1740+
# https://github.com/pytorch/pytorch/issues/51948
1741+
handles_extremals=False),
17331742
SpectralFuncInfo('fft.fft',
17341743
aten_name='fft_fft',
17351744
ref=np.fft.fft,
@@ -2967,8 +2976,6 @@ def method_tests():
29672976
('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
29682977
('rsqrt', torch.rand(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)),
29692978
('rsqrt', uniform_scalar(1e-2 * (1 + 1j), requires_grad=True), NO_ARGS, 'complex_scalar', (True,)),
2970-
('frac', (S, S, S), NO_ARGS, '', (True,)),
2971-
('frac', (), NO_ARGS, 'scalar', (True,)),
29722979
('fmod', (S, S, S), (1.5,), '', (True,)),
29732980
('fmod', (), (1.5,), 'scalar', (True,)),
29742981
('fmod', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),

0 commit comments

Comments (0)