Skip to content

Commit 67da5af

Browse files
bertmaher authored and facebook-github-bot committed
[te] Fix bugs with shift operators (#49271)
Summary: Pull Request resolved: #49271 Two things: 1. These throw exceptions in their constructor, which causes a segfault (*), so move the exceptions to ::make. 2. They technically support FP types but the rules are complicated so let's not bother. (*) The reason for the segfault: all Exprs including these inherit from KernelScopedObject, whose constructor adds the object to a list for destruction at the end of the containing KernelArena's lifetime. But if the derived-class constructor throws, the object is deleted even though it's still in the KernelArena's list. So when the KernelArena is itself deleted, it double-frees the pointer and dies. I've also fixed And, Or, and Xor in this diff. ghstack-source-id: 118594998 Test Plan: `buck test //caffe2/test:jit` Differential Revision: D25512052 fbshipit-source-id: f3ca16f208c427cd3d740e8971302d8d504240fb
1 parent 39a10fb commit 67da5af

File tree

4 files changed

+50
-54
lines changed

4 files changed

+50
-54
lines changed

test/test_jit_fuser_te.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,9 @@ def apply(fn):
477477
binary_ops = [
478478
operator.__and__,
479479
operator.__or__,
480-
operator.__xor__
480+
operator.__xor__,
481+
operator.__lshift__,
482+
operator.__rshift__,
481483
]
482484
devices = self.devices
483485
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
@@ -1292,11 +1294,6 @@ def apply(fn):
12921294
torch.lt,
12931295
torch.fmod,
12941296
torch.remainder,
1295-
1296-
# FIXME: segfaults on CPU backend
1297-
# operator.__rshift__,
1298-
# operator.__lshift__,
1299-
13001297
lambda x, y: y.type_as(x),
13011298
]
13021299
fp_only = [
@@ -1343,10 +1340,6 @@ def apply_with_scalar(fn, scalar):
13431340
torch.ge,
13441341
torch.lt,
13451342
torch.gt,
1346-
1347-
# FIXME: segfaults on CPU backend
1348-
# operator.__rshift__,
1349-
# operator.__lshift__,
13501343
]
13511344
devices = self.devices
13521345
# Maybe we should split this into separate tests to speed it up by

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,12 @@ class TensorExprFuser {
737737
"aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
738738
"aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
739739
};
740+
static const OperatorSet int_only_operator_set{
741+
"aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
742+
"aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
743+
"aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
744+
"aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
745+
};
740746
// clang-format on
741747

742748
for (const Value* v : node->inputs()) {
@@ -759,11 +765,20 @@ class TensorExprFuser {
759765
if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
760766
return false;
761767
}
768+
769+
// These operators have complicated casting rules for floats.
770+
if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
771+
return false;
772+
}
762773
} else if (node->isMemberOf(float_only_operator_set)) {
763774
// Check scalar operands of float-only ops.
764775
if (!v->type()->cast<FloatType>()) {
765776
return false;
766777
}
778+
} else if (node->isMemberOf(int_only_operator_set)) {
779+
if (!v->type()->cast<IntType>()) {
780+
return false;
781+
}
767782
}
768783
}
769784

torch/csrc/jit/tensorexpr/eval.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,14 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
422422

423423
if (expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kRshift) {
424424
switch (lhs_v.dtype().scalar_type()) {
425-
case ScalarType::Int:
426-
value_ = shift_binary_op<int>(lhs_v, rhs_v, expr_type);
425+
#define TYPE_CASE(Type, Name) \
426+
case ScalarType::Name: \
427+
value_ = shift_binary_op<Type>(lhs_v, rhs_v, expr_type); \
428+
break;
429+
AT_FORALL_INT_TYPES(TYPE_CASE);
430+
#undef TYPE_CASE
431+
case ScalarType::Bool:
432+
value_ = shift_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
427433
break;
428434
default:
429435
throw unsupported_dtype();

torch/csrc/jit/tensorexpr/ir.h

Lines changed: 24 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -179,69 +179,51 @@ class Mod : public BinaryOpNode<Mod> {
179179
: BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
180180
};
181181

182-
class And : public BinaryOpNode<And> {
182+
template <typename Op>
183+
class BitwiseOpNode : public BinaryOpNode<Op> {
183184
public:
184-
And(const Expr* lhs, const Expr* rhs)
185-
: BinaryOpNode(lhs, rhs, IRNodeType::kAnd) {
186-
if (!lhs->dtype().is_integral()) {
185+
BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type)
186+
: BinaryOpNode<Op>(lhs, rhs, type) {}
187+
188+
static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
189+
if (!lhs.dtype().is_integral()) {
187190
throw unsupported_dtype();
188191
}
189-
if (lhs->dtype() != rhs->dtype()) {
190-
throw malformed_input("bad dtype in And");
192+
if (lhs.dtype() != rhs.dtype()) {
193+
throw malformed_input("lhs/rhs dtype mismatch");
191194
}
195+
return BinaryOpNode<Op>::make(lhs, rhs);
192196
}
193197
};
194198

195-
class Or : public BinaryOpNode<Or> {
199+
class And : public BitwiseOpNode<And> {
200+
public:
201+
And(const Expr* lhs, const Expr* rhs)
202+
: BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
203+
};
204+
205+
class Or : public BitwiseOpNode<Or> {
196206
public:
197207
Or(const Expr* lhs, const Expr* rhs)
198-
: BinaryOpNode(lhs, rhs, IRNodeType::kOr) {
199-
if (!lhs->dtype().is_integral()) {
200-
throw unsupported_dtype();
201-
}
202-
if (lhs->dtype() != rhs->dtype()) {
203-
throw malformed_input("bad dtype in Or");
204-
}
205-
}
208+
: BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
206209
};
207210

208-
class Xor : public BinaryOpNode<Xor> {
211+
class Xor : public BitwiseOpNode<Xor> {
209212
public:
210213
Xor(const Expr* lhs, const Expr* rhs)
211-
: BinaryOpNode(lhs, rhs, IRNodeType::kXor) {
212-
if (!lhs->dtype().is_integral()) {
213-
throw unsupported_dtype();
214-
}
215-
if (lhs->dtype() != rhs->dtype()) {
216-
throw malformed_input("bad dtype in Xor");
217-
}
218-
}
214+
: BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
219215
};
220216

221-
class Lshift : public BinaryOpNode<Lshift> {
217+
class Lshift : public BitwiseOpNode<Lshift> {
222218
public:
223219
Lshift(const Expr* lhs, const Expr* rhs)
224-
: BinaryOpNode(lhs, rhs, IRNodeType::kLshift) {
225-
if (lhs->dtype().scalar_type() != ScalarType::Int) {
226-
throw unsupported_dtype();
227-
}
228-
if (lhs->dtype() != rhs->dtype()) {
229-
throw malformed_input("bad dtype in Lshift");
230-
}
231-
}
220+
: BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
232221
};
233222

234-
class Rshift : public BinaryOpNode<Rshift> {
223+
class Rshift : public BitwiseOpNode<Rshift> {
235224
public:
236225
Rshift(const Expr* lhs, const Expr* rhs)
237-
: BinaryOpNode(lhs, rhs, IRNodeType::kRshift) {
238-
if (lhs->dtype().scalar_type() != ScalarType::Int) {
239-
throw unsupported_dtype();
240-
}
241-
if (lhs->dtype() != rhs->dtype()) {
242-
throw malformed_input("bad dtype in Rshift");
243-
}
244-
}
226+
: BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
245227
};
246228

247229
class Max : public BinaryOpNode<Max> {

0 commit comments

Comments (0)