Skip to content

Commit cd3092e

Browse files
authored
Merge pull request #54 from JDAI-CV/fix_onnx2bnn_bn
Refuse the case that bin conv is not followed by a bn, fuse conv bias into bn
2 parents bcb72c3 + 2d29e6c commit cd3092e

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tools/onnx2bnn/OnnxConverter.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,17 @@ std::vector<std::string> OnnxConverter::Convert(
289289
expected_binary_conv_outputs.end());
290290
if (binary_conv) {
291291
binary_conv_outputs.push_back(node.output(0));
292+
bool precede_bn = false;
293+
for (const auto &node2 : model_proto_.graph().node()) {
294+
if (node2.op_type() == "BatchNormalization" &&
295+
node2.input(0) == node.output(0)) {
296+
precede_bn = true;
297+
break;
298+
}
299+
}
300+
if (!precede_bn) {
301+
throw std::invalid_argument("Binary convolutions should precede BatchNorm");
302+
}
292303
}
293304
AddConv(m(node.input(0)), strides, pads, dilations, group,
294305
ori_weight_name, bias_name, m(node.output(0)), binary_conv);
@@ -556,6 +567,13 @@ void OnnxConverter::CalculateCoeff(const ONNX_NAMESPACE::NodeProto &node,
556567
height *
557568
coeff_a_data[i];
558569
}
570+
if (node2.input_size() == 3) {
571+
const auto &bias = onnx_float_tensors_[node2.input(2)];
572+
573+
FORZ(i, coeff_b_data.size()) {
574+
coeff_b_data[i] += coeff_a_data[i] * bias.data[i];
575+
}
576+
}
559577
}
560578
{
561579
FORZ(i, coeff_a_data.size()) { coeff_a_data[i] *= -2; }

0 commit comments

Comments
 (0)