
Commit 5ac3df7

jerryzh168 authored and facebook-github-bot committed
Minor fix and turn off fold_convbn (#27403)
Summary: Pull Request resolved: #27403

In the fold_convbn pass we need to recompute the parameters (weight, bias) of the conv, update the conv's attributes, and update the access of bias in the conv: if the original conv has no bias, the `self.bias` access is inlined and replaced by the constant node `None = prim::Constant()`, and we need to change this to a `GetAttr[name="bias"]` access for the fold to work. Since there is also ongoing work on handling constants, we'll fix this pass properly after that is done.

Test Plan: .

Imported from OSS

Differential Revision: D18182918

fbshipit-source-id: bba510bc41ab58e0eb76f7b77335b6e3ffe2862d
1 parent d690521 commit 5ac3df7
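For context: the pass folds each BatchNorm into the preceding conv by recomputing the conv's weight and bias (computeUpdatedConvWeightAndBias in the C++ change below). The following is a minimal Python sketch of the standard folding arithmetic, assuming the usual BatchNorm semantics; the function and variable names here are illustrative, not the actual implementation:

import torch

def fold_conv_bn(conv_w, conv_b, bn_mean, bn_var, bn_eps, bn_gamma, bn_beta):
    # Scale each output channel of the conv weight by
    # gamma / sqrt(running_var + eps) and fold the BN shift into the bias.
    if conv_b is None:
        # A conv built with bias=False is treated as having a zero bias,
        # which is why the pass must register a real "bias" parameter
        # on the conv module afterwards.
        conv_b = torch.zeros_like(bn_mean)
    scale = bn_gamma / torch.sqrt(bn_var + bn_eps)
    new_w = conv_w * scale.reshape(-1, 1, 1, 1)
    new_b = (conv_b - bn_mean) * scale + bn_beta
    return new_w, new_b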

File tree

3 files changed: +20 additions, -3 deletions

test/test_jit.py

Lines changed: 9 additions & 0 deletions
@@ -1301,6 +1301,11 @@ def test_quant_fusion(self):
         FileCheck().run(input_str, graph)

     @_tmp_donotuse_dont_inline_everything
+    @unittest.skip("Temporarily turn off fold_convbn tests until \
+constants are handled properly; this test should not be passing \
+because bias is not handled properly, the reason it passes is because the \
+parameters of bn are initialized to default values and the recomputed bias \
+for conv is zero, which is equivalent to no bias")
     def test_foldbn_trivial(self):
         # Test trivial case
         class TestModule(torch.nn.Module):
@@ -1338,6 +1343,8 @@ def forward(self, x):
         self.assertAlmostEqual(eager(x), scripted(x), delta=1e-5)

     @_tmp_donotuse_dont_inline_everything
+    @unittest.skip("Temporarily turn off fold_convbn tests until \
+constants are handled properly")
     def test_foldbn_trivial_nobias(self):
         # Test trivial case
         class TestModule(torch.nn.Module):
@@ -1375,6 +1382,8 @@ def forward(self, x):
         self.assertAlmostEqual(eager(x), scripted(x), delta=1e-5)

     @_tmp_donotuse_dont_inline_everything
+    @unittest.skip("Temporarily turn off fold_convbn tests until \
+constants are handled properly")
     def test_foldbn_in_submodule(self):
         # Test that we find Conv-BN patterns in submodules
         class SubModule(torch.nn.Module):
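The first skip message above can be checked numerically: with a freshly constructed BatchNorm (gamma=1, beta=0, running_mean=0, running_var=1), the recomputed bias for a bias-free conv comes out as (approximately) zero, so the test passes whether or not bias is actually handled. Using the folding arithmetic sketched earlier:

import torch

# BN parameters at their default initialization values.
gamma, beta = torch.ones(3), torch.zeros(3)
mean, var, eps = torch.zeros(3), torch.ones(3), 1e-5

scale = gamma / torch.sqrt(var + eps)           # ~1.0 per channel
new_b = (torch.zeros(3) - mean) * scale + beta  # conv bias treated as zero
print(new_b)  # tensor([0., 0., 0.]) -- indistinguishable from no bias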

torch/csrc/jit/passes/quantization.cpp

Lines changed: 6 additions & 2 deletions
@@ -883,8 +883,12 @@ graph(%self, %x):
     GRAPH_UPDATE("Deleting ", *matched_bn);

     auto new_w_b = computeUpdatedConvWeightAndBias(params);
-    params.conv_w.set_data(std::get<0>(new_w_b));
-    params.conv_b.set_data(std::get<1>(new_w_b));
+    conv_submodule.set_parameter("weight", std::get<0>(new_w_b));
+    if (conv_submodule.find_parameter("bias")) {
+      conv_submodule.set_parameter("bias", std::get<1>(new_w_b));
+    } else {
+      conv_submodule.register_parameter("bias", std::get<1>(new_w_b), false);
+    }
   }

   // Perform planned rewritings
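To illustrate why a missing bias is a problem (per the summary), here is a small hypothetical repro; the module name is made up for the example. Once the conv's forward is inlined during scripting, a bias=False conv's `self.bias` access is baked into the graph as `None = prim::Constant()` rather than a `prim::GetAttr[name="bias"]` node, so registering a recomputed bias parameter on the module does not reconnect it to the graph:

import torch
import torch.nn as nn

class NoBiasConv(nn.Module):
    def __init__(self):
        super(NoBiasConv, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1, bias=False)

    def forward(self, x):
        return self.conv(x)

scripted = torch.jit.script(NoBiasConv())
# Inspect the inlined graph: the bias argument of the convolution shows
# up as a None constant, not as a GetAttr on the conv submodule.
print(scripted.graph)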

torch/quantization/_quantize_script.py

Lines changed: 5 additions & 1 deletion
@@ -132,7 +132,11 @@ def quantize_script(model, qconfig_dict, run_fn, run_args, inplace=False):
     if not inplace:
         model = model.copy()
     scripted_qconfig_dict = {k: script_qconfig(v) for k, v in qconfig_dict.items()}
-    torch._C._jit_pass_fold_convbn(model._c)
+    # We are not going to run the fold_convbn pass right now
+    # since it does not work correctly; we will
+    # revisit after constants are properly handled in
+    # the JIT.
+    # torch._C._jit_pass_fold_convbn(model._c)
     prepare_script(model, scripted_qconfig_dict, True)
     run_fn(model._c._get_method('forward'), *run_args)
     # When we mutating graph we didn't create a new ClassType
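For reference, a rough usage sketch of quantize_script matching the signature in the hunk above. The model, the qconfig dict key, and the calibration data are placeholder assumptions, and this is a sketch of the post-training calibration flow this file implements, not a guaranteed-stable API:

import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization._quantize_script import quantize_script

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x)

# run_fn is handed the scripted forward method followed by *run_args,
# matching the run_fn(model._c._get_method('forward'), *run_args) call above.
def calibrate(forward_method, data):
    for x in data:
        forward_method(x)

data = [torch.randn(1, 3, 4, 4) for _ in range(2)]
scripted = torch.jit.script(M())
quantized = quantize_script(scripted, {'': default_qconfig},
                            run_fn=calibrate, run_args=[data])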
