Skip to content

Commit f911361

Browse files
benjaminglass1pytorchmergebot
authored andcommitted
1 parent 08db735 commit f911361

File tree

2 files changed

+2
-34
lines changed

2 files changed

+2
-34
lines changed

test/test_maskedtensor.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Owner(s): ["module: masked operators"]
22

33
import torch
4-
import unittest
54
from torch.testing._internal.common_utils import (
6-
decorateIf,
75
TestCase,
86
run_tests,
97
make_tensor,
@@ -957,37 +955,6 @@ def test_unary_core(self, device, dtype, op, layout):
957955

958956
@ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES) # type: ignore[arg-type]
959957
@parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
960-
# FIXME:
961-
# Result is just wrong; production logic should be fixed
962-
@decorateIf(
963-
unittest.expectedFailure,
964-
lambda params: (
965-
params["op"].name == "add" and
966-
params["dtype"] in [torch.float16, torch.float32] and
967-
params["device"] == "cpu" and
968-
params["layout"] == torch.sparse_csr
969-
)
970-
)
971-
# Result is just wrong; production logic should be fixed
972-
@decorateIf(
973-
unittest.expectedFailure,
974-
lambda params: (
975-
params["op"].name == "sub" and
976-
params["dtype"] in [torch.float16, torch.float32] and
977-
params["device"] == "cpu" and
978-
params["layout"] == torch.sparse_csr
979-
)
980-
)
981-
# Result is just wrong; production logic should be fixed
982-
@decorateIf(
983-
unittest.expectedFailure,
984-
lambda params: (
985-
params["op"].name == "eq" and
986-
params["dtype"] == torch.float64 and
987-
params["device"] == "cpu" and
988-
params["layout"] == torch.sparse_csr
989-
)
990-
)
991958
def test_binary_core(self, device, dtype, op, layout):
992959
self._test_unary_binary_equality(device, dtype, op, layout)
993960

torch/masked/maskedtensor/binary.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,10 @@ def _binary_helper(fn, args, kwargs, inplace):
139139

140140
crow = data_args[0].crow_indices()
141141
col = data_args[0].col_indices()
142+
size = data_args[0].size()
142143
data_args[0] = data_args[0].values()
143144
v = fn(*data_args)
144-
result_data = torch.sparse_csr_tensor(crow, col, v)
145+
result_data = torch.sparse_csr_tensor(crow, col, v, size)
145146

146147
else:
147148
result_data = fn(*data_args)

0 commit comments

Comments
 (0)