Skip to content

Commit c406bf2

Browse files
nairbvfacebook-github-bot
authored andcommitted
error instead of crashing on attempt to subclass typed tensors (#20283)
Summary: #20052 typed tensors (e.g. torch.FloatTensor) can't be subclassed. Was causing crashes and other errors. Pull Request resolved: #20283 Differential Revision: D15278138 Pulled By: nairbv fbshipit-source-id: 8493eac4d34dfb76b054362bf0acec02146cd0e2
1 parent 1e35ef0 commit c406bf2

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

test/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11336,6 +11336,19 @@ def test_c10_layer_norm(self):
1133611336
weight), torch.tensor(bias), 1, epsilon, True)
1133711337
torch.testing.assert_allclose(expected_norm, actual_norm)
1133811338

11339+
def test_subclass_tensors(self):
11340+
# raise an error when trying to subclass FloatTensor
11341+
with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
11342+
class Foo1(torch.FloatTensor):
11343+
pass
11344+
11345+
# but allow subclassing Tensor:
11346+
class Foo2(torch.Tensor):
11347+
def foo(self):
11348+
return 5
11349+
f = Foo2()
11350+
self.assertEqual(f.foo(), 5)
11351+
1133911352
# Functions to test negative dimension wrapping
1134011353
METHOD = 1
1134111354
INPLACE_METHOD = 2

torch/csrc/tensor/python_tensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ static void py_initialize_tensor_type(PyTypeObject& type, const char* name, PyOb
159159
((PyObject*)&type)->ob_refcnt = 1;
160160
((PyObject*)&type)->ob_type = &metaclass;
161161
type.tp_basicsize = sizeof(PyTensorType);
162-
type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
162+
// Subclassing from torch.<ScalarType>Tensor isn't supported.
163+
// (Py_TPFLAGS_BASETYPE omitted). Subclassing torch.Tensor still allowed.
164+
type.tp_flags = Py_TPFLAGS_DEFAULT;
163165
type.tp_name = name;
164166
type.tp_new = Tensor_new;
165167
if (PyType_Ready(&type) < 0) {

0 commit comments

Comments
 (0)