Skip to content

Commit 9a41f44

Browse files
BowenBao authored and facebook-github-bot committed
Improve ONNX Loop export (#20445)
Summary: ~~This is work in progress due to its dependency on multiple pending PRs.~~ - [x] ONNX: Relax constraint on subgraph input/output type & shape check. onnx/onnx#2009 - [x] PyTorch: Add infra to test_pytorch_onnx_caffe2.py to test ScriptModule models. #20256 This PR should partially resolve #17531. However, ideally we shouldn't need to put a cast (and reshape) node to help the conversion for the loop condition. - Added cast node for condition values before entering the loop node. The ONNX spec only accepts Bool type, while in PyTorch if the condition value is an output from another node it could potentially have any integral type. - Tidying up the exported ONNX loop subgraph input type & shape. According to the ONNX spec, input "M" is exported as a 0-d scalar tensor with type int64. Input "Cond" is exported as an incomplete tensor of type Bool without shape information. This is because throughout the iteration, the rank of the condition value is dynamic, either 0-d or 1-d, as long as it holds a single value. Pull Request resolved: #20445 Differential Revision: D15534188 Pulled By: houseroad fbshipit-source-id: d174e778529def05ee666afeee4b8fb27786e320
1 parent 4980b8b commit 9a41f44

File tree

3 files changed

+150
-9
lines changed

3 files changed

+150
-9
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,6 +1747,94 @@ def forward(self, x):
17471747
x = torch.randn(1, 2, 3)
17481748
self.run_model_test(DropoutModel(), train=False, input=x, batch_size=BATCH_SIZE)
17491749

1750+
def test_while(self):
    """Export a ScriptModule whose forward contains a plain while loop.

    The loop has a constant bound, so the exported ONNX Loop carries a
    statically known trip count.
    """
    class WhileModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            a = 0
            while a < 4:
                a += 1
            return x + a

    model = WhileModel()
    x = torch.zeros(1, 2, 3, dtype=torch.long)
    expected = model(x)
    self.run_model_test(model, train=False, input=(x,),
                        batch_size=BATCH_SIZE, example_outputs=(expected,))
1764+
1765+
def test_while_cond(self):
    """Export a while loop whose condition is a tensor computed from an input.

    The condition `b` is a comparison result (uint8/bool tensor), which
    exercises the cast-to-Bool handling on the ONNX Loop condition input.
    NOTE: statement order is significant — the scripted `a += ...` may
    update `a` in place before run_model_test re-uses it.
    """
    class WhileModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, a):
            b = (a < 4)
            while b:
                a += b.to(torch.long)
                b = (a < 4)
            return x + a

    model = WhileModel()
    x = torch.zeros(1, 2, 3, dtype=torch.long)
    a = torch.tensor([0], dtype=torch.long)
    expected = model(x, a)
    self.run_model_test(model, train=False, input=(x, a),
                        batch_size=BATCH_SIZE, example_outputs=(expected,))
1781+
1782+
def test_loop(self):
    """Export a ScriptModule with a for loop over a constant range."""
    class LoopModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            for i in range(5):
                x = x + i
            return x

    model = LoopModel()
    x = torch.zeros(1, 2, 3, dtype=torch.long)
    expected = model(x)
    self.run_model_test(model, train=False, input=(x,),
                        batch_size=BATCH_SIZE, example_outputs=(expected,))
1795+
1796+
def test_dynamic_loop(self):
    """Export a for loop whose trip count depends on the input shape.

    `range(x.size(2))` makes the Loop's "M" input a runtime value rather
    than a compile-time constant.
    """
    class LoopModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            for i in range(x.size(2)):
                x = x + i
            return x

    model = LoopModel()
    x = torch.zeros(1, 2, 3, dtype=torch.long)
    expected = model(x)
    self.run_model_test(model, train=False, input=(x,),
                        batch_size=BATCH_SIZE, example_outputs=(expected,))
1809+
1810+
def test_nested_loops(self):
    """Export nested for/while loops (Loop subgraphs inside Loop subgraphs)."""
    class NestedLoopsModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            for i in range(5):
                a = 0
                while a < 4:
                    a += 1
                for j in range(a):
                    x = x + j
                x = x + a
            return x

    model = NestedLoopsModel()
    x = torch.zeros(1, 2, 3, dtype=torch.long)
    expected = model(x)
    self.run_model_test(model, train=False, input=(x,),
                        batch_size=BATCH_SIZE, example_outputs=(expected,))
1828+
1829+
def test_select(self):
    """Export torch.select: take index 1 along dimension 0."""
    class SelectModel(torch.nn.Module):
        def forward(self, x):
            return torch.select(x, 0, 1)

    model = SelectModel()
    x = torch.randn(3, 2, 1)
    self.run_model_test(model, train=False, input=(x,), batch_size=BATCH_SIZE)
1837+
17501838
# a bit of metaprogramming to set up all the rnn tests
17511839

17521840

torch/csrc/jit/export.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
181181
return onnx::TensorProto_DataType_INT32;
182182
case at::kLong:
183183
return onnx::TensorProto_DataType_INT64;
184+
case at::kBool:
185+
return onnx::TensorProto_DataType_BOOL;
184186
default:
185187
AT_ERROR("unexpected tensor scalar type");
186188
}
@@ -206,19 +208,20 @@ void EncoderBase::EncodeValueInfo(
206208
onnx::ValueInfoProto* v,
207209
const Value* n) {
208210
v->set_name(n->uniqueName());
209-
onnx::TypeProto* t = v->mutable_type();
210-
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
211-
212-
onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
213211
if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
212+
onnx::TypeProto* t = v->mutable_type();
213+
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
214+
onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
214215
const std::vector<std::int64_t>& sizes = node_type->sizes();
215216
for (size_t i = 0; i < sizes.size(); i++) {
216217
shape->add_dim();
217218
shape->mutable_dim(i)->set_dim_value(sizes[i]);
218219
}
219220
tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
220-
} else {
221-
tensor_type->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
221+
} else if (BoolTypePtr node_type = n->type()->cast<BoolType>()) {
222+
onnx::TypeProto* t = v->mutable_type();
223+
onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
224+
tensor_type->set_elem_type(ATenTypeToOnnxType(at::kBool));
222225
}
223226
}
224227

torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,62 @@
33
namespace torch {
44
namespace jit {
55

6+
namespace onnx{
7+
using namespace ::c10::onnx;
8+
}
9+
10+
// Builds (but does not insert) an ONNX Cast node that converts `val`
// to the Bool tensor type required by the ONNX Loop condition input.
Node* CreateCastToBoolNode(Value* val, Graph* graph) {
  auto* cast = graph->create(onnx::Cast);
  cast->addInput(val);
  // 9 is onnx::TensorProto_DataType_BOOL in the ONNX proto enum.
  cast->i_(attr::to, 9);
  return cast;
}
16+
17+
// Reroutes `consumer_node`'s use of `cond_val` through a Bool cast:
//   before: cond_val -> consumer_node
//   after:  cond_val -> Cast -> consumer_node
// The cast is required because PyTorch comparison operators such as
// Greater/Less produce tensors of type torch.uint8, while the ONNX Loop
// condition input must be Bool.
Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
  Node* cast = CreateCastToBoolNode(cond_val, graph);
  cast->insertBefore(consumer_node);
  consumer_node->replaceInputWith(cond_val, cast->output());
  return cast;
}
29+
30+
// A cast is needed unless the condition value is already Bool-typed:
// either a dimensioned tensor whose scalar type is Bool, or the scalar
// Bool type itself.
bool IsCondCastRequired(Value* cond_val) {
  const auto& cond_type = cond_val->type();
  if (cond_type->isSubclass(TypeKind::DimensionedTensorType)) {
    const bool already_bool =
        cond_type->expect<DimensionedTensorType>()->scalarType() == c10::kBool;
    return !already_bool;
  }
  return !cond_type->isSubclass(TypeKind::BoolType);
}
37+
638
void FixupONNXLoops(Block* block) {
739
for (auto* node : block->nodes()) {
840
if (node->kind() == ::c10::onnx::Loop) {
9-
AT_ASSERT(node->blocks().size() == 1);
10-
auto* sub_block = node->blocks()[0];
11-
sub_block->insertInput(1, "cond");
41+
auto* loop_node = node;
42+
auto* graph = loop_node->owningGraph();
43+
44+
// add cast to condition input outside the loop.
45+
Value* cond_val = loop_node->inputs()[1];
46+
if (IsCondCastRequired(cond_val))
47+
InsertCastForCond(cond_val, graph, loop_node);
48+
49+
// Setup Loop input cond and i.
50+
TORCH_INTERNAL_ASSERT(loop_node->blocks().size() == 1);
51+
auto* sub_block = loop_node->blocks()[0];
52+
Value* cond = sub_block->insertInput(1, "cond");
53+
cond->setType(BoolType::create());
54+
55+
Value* i = sub_block->inputs()[0];
56+
i->setType(CompleteTensorType::fromNumberType(IntType::get()));
57+
58+
// add cast to condition input inside the loop.
59+
Value* next_cond_val = sub_block->outputs()[0];
60+
if (IsCondCastRequired(next_cond_val))
61+
InsertCastForCond(next_cond_val, graph, sub_block->return_node());
1262
}
1363
for (Block* block : node->blocks()) {
1464
FixupONNXLoops(block);

0 commit comments

Comments
 (0)