
Commit 5acc664

Elias Ellison authored and facebook-github-bot committed
make magic methods work with casts too (#20654)
Summary: The previous implementation of magic methods extended from BuiltinOperators, but it should be able to work with other sugared values, such as casts. I also considered making CastValue and BuiltinOperators extend from a MagicMethod superclass and having them try the superclass's call before their own; however, not all builtin operators have corresponding magic methods (although there are workarounds for that), so I did it this way instead.

Pull Request resolved: #20654

Differential Revision: D15434469

Pulled By: eellison

fbshipit-source-id: 813fa00bf8b5b9ada46505075ebf984d8eee6aef
1 parent e6f22e1 commit 5acc664
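
For orientation, a rough sketch of what this enables in TorchScript (not part of the commit; the Duration class and describe function below are made-up illustrations): built-in casts such as float(x), and the implicit bool() applied to an if-condition, now dispatch to the corresponding dunder methods on user-defined script classes, mirroring the new test added in this change.

    import torch

    @torch.jit.script
    class Duration(object):  # hypothetical user class, for illustration only
        def __init__(self, seconds):
            # type: (float) -> None
            self.seconds = seconds

        def __float__(self):
            return self.seconds

        def __bool__(self):
            return self.seconds != 0.0

    @torch.jit.script
    def describe(d):
        # type: (Duration) -> float
        if d:                # the condition is cast via __bool__
            return float(d)  # the float() cast dispatches to __float__
        return 0.0

    print(describe(Duration(1.5)))  # 1.5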

3 files changed (+112, -40 lines)

test/test_jit.py

Lines changed: 49 additions & 2 deletions
@@ -5611,7 +5611,7 @@ def test_not_cast(x):
         self.checkScript(test_not_cast, (torch.tensor(1),))
         self.checkScript(test_not_cast, (torch.tensor(0),))
 
-        with self.assertRaisesRegex(RuntimeError, "expected"):
+        with self.assertRaisesRegex(RuntimeError, "Could not cast value of type Tuple\[Tensor, Tensor\]"):  # noqa: W605
             @torch.jit.script
             def test_mult(x, y):
                 return not(x, y)

@@ -5636,7 +5636,7 @@ def test_cast_float(x):
         self.checkScript(test_cast_float, (0.,))
         self.checkScript(test_cast_float, (-1.,))
 
-        with self.assertRaisesRegex(RuntimeError, "expected a bool, int, float, or Tensor"):
+        with self.assertRaisesRegex(RuntimeError, "Could not cast value of type Tuple\[int, int\] to bool"):  # noqa: W605
             @torch.jit.script
             def test_bad_conditional(x):
                 if (1, 2):

@@ -15605,6 +15605,53 @@ def _xor():  # noqa: E306
         def test():
             return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
 
+    def test_cast_overloads(self):
+        @torch.jit.script
+        class Foo(object):
+            def __init__(self, val):
+                # type: (float) -> None
+                self.val = val
+
+            def __int__(self):
+                return int(self.val)
+
+            def __float__(self):
+                return self.val
+
+            def __bool__(self):
+                return bool(self.val)
+
+            def __str__(self):
+                return str(self.val)
+
+        def test(foo):
+            # type: (Foo) -> Tuple[int, float, bool]
+            if foo:
+                pass
+            return int(foo), float(foo), bool(foo)
+
+        fn = torch.jit.script(test)
+        self.assertEqual(fn(Foo(0.5)), test(0.5))
+        self.assertEqual(fn(Foo(0.)), test(0.0))
+        # str has slightly different formatting
+        self.assertTrue("0.5" in (str(Foo(0.5))))
+        self.assertTrue("0." in (str(Foo(0.0))))
+
+        @torch.jit.script
+        class BadBool(object):
+            def __init__(self):
+                pass
+
+            def __bool__(self):
+                return (1, 2)
+
+        with self.assertRaisesRegex(RuntimeError, "expected a bool expression for condition"):
+            @torch.jit.script
+            def test():
+                if BadBool():
+                    print(1)
+                pass
+
     def test_init_compiled_first(self):
         @torch.jit.script  # noqa: B903
         class Foo(object):

torch/csrc/jit/script/compiler.cpp

Lines changed: 49 additions & 29 deletions
@@ -24,7 +24,6 @@ namespace torch {
 namespace jit {
 namespace script {
 
-using SugaredValuePtr = std::shared_ptr<SugaredValue>;
 using FunctionTable = std::unordered_map<std::string, Function&>;
 using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
 using AttributeMap = std::unordered_map<std::string, Const>;

@@ -145,6 +144,13 @@ static Value* asSimple(const SugaredValuePtr& value) {
   }
   return nullptr;
 }
+
+static std::shared_ptr<MagicMethod> makeMagic(
+    const std::string& name,
+    SugaredValuePtr base) {
+  return std::make_shared<MagicMethod>(name, base);
+}
+
 // we consider _N where N is a number, to be a non-meaningful name
 // and do not record it as a unique name. This allows python printing to
 // be able to export and import more consistently named graphs

@@ -388,17 +394,32 @@ struct Environment {
     if (!retval) {
       static std::unordered_map<std::string, SugaredValuePtr> globals = {
           {"print", std::make_shared<PrintValue>()},
-          {"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
-          {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
-          {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
-          {"str", std::make_shared<CastValue>(StringType::get(), prim::str)},
+          {"float",
+           makeMagic(
+               "__float__",
+               std::make_shared<CastValue>(FloatType::get(), prim::Float))},
+          {"int",
+           makeMagic(
+               "__int__",
+               std::make_shared<CastValue>(IntType::get(), prim::Int))},
+          {"bool",
+           makeMagic(
+               "__bool__",
+               std::make_shared<CastValue>(BoolType::get(), prim::Bool))},
+          {"str",
+           makeMagic(
+               "__str__",
+               std::make_shared<CastValue>(StringType::get(), prim::str))},
           {"getattr", std::make_shared<GetAttrValue>()},
           {"isinstance", std::make_shared<IsInstanceValue>()},
           // todo(zach): remove when we can correctly export torch.full via ONNX
           // or we have implicit conversion that can convert numbers to tensors
           {"_to_tensor",
            std::make_shared<CastValue>(TensorType::get(), prim::NumToTensor)},
-          {"len", std::make_shared<OperatorOverload>(aten::len, "__len__")},
+          {"len",
+           makeMagic(
+               "__len__",
+               std::make_shared<BuiltinFunction>(aten::len, at::nullopt))},
           {"hash", std::make_shared<BuiltinFunction>(aten::hash, at::nullopt)},
           {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
           {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},

@@ -1117,26 +1138,21 @@
 
   Value* emitCond(const Expr& cond) {
     Value* v = emitExpr(cond);
-    if (!v->type()->isSubtypeOf(BoolType::get())) {
-      Value* cast_v = emitBuiltinCall(
-          cond.get()->range(),
-          *v->owningGraph(),
-          prim::Bool,
-          c10::nullopt,
-          {v},
-          {},
-          /*required*/ false);
-      if (cast_v == nullptr) {
-        ErrorReport error(cond);
-        error
-            << "expected a bool, int, float, or Tensor expression for condition but found "
-            << v->type()->python_str();
-        throw error;
-      } else {
-        v = cast_v;
-      }
+    Value* out;
+    try {
+      auto bool_cast = environment_stack->getSugaredVar("bool", cond.range());
+      out = asSimple(bool_cast->call(cond.get()->range(), method, {v}, {}, 0));
+    } catch (...) {
+      throw ErrorReport(cond.range()) << "Could not cast value of type "
+                                      << v->type()->python_str() << " to bool";
+    }
+    // cast value not response for checking output type
+    if (!out->type()->isSubtypeOf(BoolType::get())) {
+      throw ErrorReport(cond)
+          << "expected a bool expression for condition but found "
+          << out->type()->python_str();
     }
-    return v;
+    return out;
   }
 
   void emitIfElseBlocks(Value* cond_value, const If& stmt) {

@@ -2349,8 +2365,10 @@
       const auto& inputs = tree->trees();
       auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
       auto neg_val =
-          asSimple(OperatorOverload(aten::neg, "__neg__")
-                       .call(tree->range(), method, named_values, {}, 0));
+          asSimple(makeMagic(
+                       "__neg__",
+                       std::make_shared<BuiltinFunction>(aten::neg, at::nullopt))
+                       ->call(tree->range(), method, named_values, {}, 0));
 
       // constant fold the input if possible
 

@@ -2437,8 +2455,10 @@
       auto kind = getNodeKind(tree->kind(), inputs.size());
       auto overload = getOperatorOverload(tree->kind(), inputs.size());
       auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
-      return asSimple(OperatorOverload(kind, overload)
-                          .call(tree->range(), method, named_values, {}, 0));
+      return asSimple(
+          makeMagic(
+              overload, std::make_shared<BuiltinFunction>(kind, at::nullopt))
+              ->call(tree->range(), method, named_values, {}, 0));
     }
     case TK_NOT: {
       Value* input = emitCond(Expr(tree->trees()[0]));
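
With this change, emitCond routes every condition through the same "bool" sugared value that an explicit bool(...) call uses, reporting the new "Could not cast value of type ... to bool" message when no cast exists. A small sketch of the user-visible effect, modeled on the test_bad_conditional case updated above (the bad_conditional name here is just for illustration):

    import torch

    def bad_conditional(x):
        # type: (int) -> int
        if (1, 2):  # a Tuple[int, int] condition has no bool cast and no __bool__
            return x
        return 0

    try:
        torch.jit.script(bad_conditional)
    except RuntimeError as e:
        # Expected to mention: Could not cast value of type Tuple[int, int] to bool
        print(e)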

torch/csrc/jit/script/sugared_value.h

Lines changed: 14 additions & 9 deletions
@@ -214,7 +214,6 @@ struct TORCH_API ClassValue : public SugaredValue {
   ClassTypePtr type_;
 };
 
-
 struct FunctionValue : public SugaredValue {
   FunctionValue(std::shared_ptr<Function> callee)
       : callee_(std::move(callee)) {}

@@ -232,6 +231,7 @@ struct FunctionValue : public SugaredValue {
     return std::make_shared<SimpleValue>(
         callee_->emit_call(*f.graph(), loc, inputs, attributes));
   }
+
  private:
   std::shared_ptr<Function> callee_;
 };

@@ -262,8 +262,6 @@ struct MethodValue : public SugaredValue {
   std::shared_ptr<Function> method_;
 };
 
-
-
 struct TORCH_API PrintValue : public SugaredValue {
   std::string kind() const override {
     return "print";

@@ -301,13 +299,19 @@ struct TORCH_API CastValue : public BuiltinFunction {
   TypePtr type_;
 };
 
+using SugaredValuePtr = std::shared_ptr<SugaredValue>;
+
 // builtins operators and functions that call a method if it exists
 // on a class type, like 'len(x)' and 'x + y'
-struct TORCH_API OperatorOverload : public BuiltinFunction {
-  OperatorOverload(c10::Symbol builtin_method, std::string desugared_name)
-      : BuiltinFunction(builtin_method, c10::nullopt),
+struct TORCH_API MagicMethod : public SugaredValue {
+  MagicMethod(std::string desugared_name, SugaredValuePtr base)
+      : base_value_(std::move(base)),
         desugared_name_(std::move(desugared_name)) {}
 
+  std::string kind() const override {
+    return desugared_name_;
+  }
+
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
       Function& m,

@@ -322,17 +326,18 @@ struct TORCH_API OperatorOverload : public BuiltinFunction {
             method->emit_call(*m.graph(), loc, inputs, attributes));
       } else {
         ErrorReport e(loc);
-        e << "Cannot call builtin operator " << symbol.toDisplayString()
-          << " on " << class_ptr->python_str() << " because it does not "
+        e << "Cannot call " << desugared_name_ << " on "
+          << class_ptr->python_str() << " because it does not "
           << " define a " << desugared_name_ << " method";
         throw e;
       }
     }
-    return BuiltinFunction::call(loc, m, inputs, attributes, n_binders);
+    return base_value_->call(loc, m, inputs, attributes, n_binders);
   }
 
  private:
+  SugaredValuePtr base_value_;
   std::string desugared_name_;
 };
 
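As the new MagicMethod struct above shows, dispatch works in two steps: if the first input is a class type, its desugared_name_ method is called; otherwise the call falls through to the wrapped base value (a CastValue or BuiltinFunction). A minimal sketch of what that means for len, assuming a made-up Bag class (not from the commit):

    import torch
    from typing import Tuple

    @torch.jit.script
    class Bag(object):  # hypothetical class, for illustration only
        def __init__(self, n):
            # type: (int) -> None
            self.n = n

        def __len__(self):
            return self.n

    @torch.jit.script
    def sizes(b):
        # type: (Bag) -> Tuple[int, int]
        # len(b) resolves to Bag.__len__ because Bag is a class type;
        # len([1, 2, 3]) falls back to the wrapped aten::len builtin.
        return len(b), len([1, 2, 3])

    print(sizes(Bag(4)))  # (4, 3)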