Skip to content

Commit 0a3fb45

Browse files
Brennan Vincentfacebook-github-bot
authored andcommitted
allow passing Python built-in types as dtypes (#21215)
Summary: Another simple bit of syntax that NumPy supports and we don't. Support int, float, and bool. ```python >>> torch.randn((2,3), dtype=float) tensor([[-0.1752, -0.3240, -0.6148], [ 0.1861, 1.6472, 0.1687]], dtype=torch.float64) ``` A bit confusingly, Python's "float" actually means double, but nothing we can do about that. Pull Request resolved: #21215 Differential Revision: D15697012 Pulled By: umanwizard fbshipit-source-id: 9a38d960a610b8e67023486b0c9265edd3c22246
1 parent b647804 commit 0a3fb45

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

test/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11789,6 +11789,19 @@ def test_T(self):
1178911789
scalar = torch.tensor(5)
1179011790
self.assertEqual(scalar, scalar.T)
1179111791

11792+
def test_python_types(self):
11793+
a1 = torch.randn((1, 2), dtype=torch.float64)
11794+
a2 = torch.randn((1, 2), dtype=float)
11795+
self.assertEqual(a1.dtype, a2.dtype)
11796+
11797+
b1 = torch.arange(10, 20, dtype=torch.int64)
11798+
b2 = torch.arange(10, 20, dtype=int)
11799+
self.assertEqual(b1.dtype, b2.dtype)
11800+
11801+
c1 = torch.tensor([True, False], dtype=torch.bool)
11802+
c2 = torch.tensor([True, False], dtype=bool)
11803+
self.assertEqual(c1.dtype, c2.dtype)
11804+
1179211805
# Functions to test negative dimension wrapping
1179311806
METHOD = 1
1179411807
INPLACE_METHOD = 2

torch/csrc/Dtype.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ inline bool THPDtype_Check(PyObject *obj) {
1818
return Py_TYPE(obj) == &THPDtypeType;
1919
}
2020

21+
inline bool THPPythonScalarType_Check(PyObject *obj) {
22+
return obj == (PyObject*)(&PyFloat_Type) ||
23+
obj == (PyObject*)(&PyBool_Type) ||
24+
#if PY_MAJOR_VERSION == 2
25+
obj == (PyObject*)(&PyInt_Type) ||
26+
#endif
27+
obj == (PyObject*)(&PyLong_Type);
28+
}
29+
2130
PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name);
2231

2332
void THPDtype_init(PyObject *module);

torch/csrc/utils/python_arg_parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ bool FunctionParameter::check(PyObject* obj) {
169169
case ParameterType::BOOL: return PyBool_Check(obj);
170170
case ParameterType::STORAGE: return isStorage(obj);
171171
case ParameterType::PYOBJECT: return true;
172-
case ParameterType::SCALARTYPE: return THPDtype_Check(obj);
172+
case ParameterType::SCALARTYPE: return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
173173
case ParameterType::LAYOUT: return THPLayout_Check(obj);
174174
case ParameterType::MEMORY_FORMAT: return THPMemoryFormat_Check(obj);
175175
case ParameterType::DEVICE:

torch/csrc/utils/python_arg_parser.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,21 @@ inline at::ScalarType PythonArgs::scalartype(int i) {
337337
return (scalartype == at::ScalarType::Undefined) ?
338338
torch::tensors::get_default_scalar_type() : scalartype;
339339
}
340-
return reinterpret_cast<THPDtype*>(args[i])->scalar_type;
340+
PyObject *obj = args[i];
341+
if (obj == (PyObject*)&PyFloat_Type) {
342+
return at::ScalarType::Double;
343+
}
344+
if (obj == (PyObject*)&PyBool_Type) {
345+
return at::ScalarType::Bool;
346+
}
347+
if (obj == (PyObject*)&PyLong_Type
348+
#if PY_MAJOR_VERSION == 2
349+
|| obj == (PyObject*)&PyInt_Type
350+
#endif
351+
) {
352+
return at::ScalarType::Long;
353+
}
354+
return reinterpret_cast<THPDtype*>(obj)->scalar_type;
341355
}
342356

343357
inline c10::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {

0 commit comments

Comments
 (0)