Skip to content

Commit 38d660d

Browse files
committed
Symbolic function for torch.square
1 parent 5912316 commit 38d660d

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,14 @@ def forward(self, x):
12931293
dynamic_axes={'input_1': [0, 1, 2],
12941294
'output_1': [0, 1, 2]})
12951295

1296+
def test_square(self):
1297+
class Square(torch.nn.Module):
1298+
def forward(self, x):
1299+
return torch.square(x)
1300+
1301+
x = torch.randn(2, 3, 4)
1302+
self.run_test(Square(), x)
1303+
12961304
@skipIfUnsupportedMinOpsetVersion(9)
12971305
def test_arange_dynamic(self):
12981306
class ArangeModel(torch.nn.Module):

torch/onnx/symbolic_opset9.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,10 @@ def select(g, self, dim, index):
610610
return g.op("Gather", self, index, axis_i=dim)
611611

612612

613+
def square(g, self):
614+
return g.op("Mul", self, self)
615+
616+
613617
def squeeze(g, self, dim=None):
614618
if dim is None:
615619
return g.op("Squeeze", self)

0 commit comments

Comments
 (0)