
Support split#6

Merged
Honry merged 5 commits into Honry:stable-diffusion from BruceDai:support_split
May 17, 2023

Conversation

@BruceDai

@fdwr @Honry PTAL, thanks.

}
ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == static_cast<int32_t>(num_outputs), "The size of outputs must be equal to 'num_outputs'.");
} else {
// w/o 'split' input for opset 13
Author

@fdwr From the explanation of the 'split' input of opset 13:

split: Optional length of each output. Values should be >= 0. Sum of the values must be equal to the dim value at 'axis' specified.

So what is the default 'split'? Is split = [input_shape[axis]] in this case, i.e. num_outputs = 1?
But I'm confused by the sample case of 1d_opset13, where split = [2, 2, 2], not split = [6]:

node_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)

node = onnx.helper.make_node(
    "Split",
    inputs=["input"], # w/o optional 'split' input
    outputs=["output_1", "output_2", "output_3"],
    axis=0,
)

expected_outputs = [
    np.array([1.0, 2.0]).astype(np.float32),
    np.array([3.0, 4.0]).astype(np.float32),
    np.array([5.0, 6.0]).astype(np.float32),
]
expect(
    node,
    inputs=[node_input],
    outputs=expected_outputs,
    name="test_split_equal_parts_1d_opset13",
    opset_imports=[onnx.helper.make_opsetid("", 13)],
)

Any thoughts? Thanks.


Let me try running 1d_opset13 locally and comparing to the DML EP...

@fdwr fdwr May 16, 2023

Ah, if the split lengths are not passed, then just divide the input size by the output count. See DML EP here:

https://github.com/microsoft/onnxruntime/blob/a7ad859e3ab60bddfcf2fefa96bfcb550f0fc04c/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L984-L989

So I'm guessing it would be:

const auto& output_defs = node.OutputDefs();
auto output_count = output_defs.size();
output_array = model_builder.GetBuilder().call<emscripten::val>("split", input, static_cast<int32_t>(output_count), options);

Author

Thanks. Please take another look at second new commit.

Comment on lines 39 to 54
const auto& dims = tensor.dims();
if (dims.empty() || dims[0] == 0) {
LOGS(logger, VERBOSE) << "The shape cannot be empty.";
return false;
}
if (dims.size() != 1) {
LOGS(logger, VERBOSE) << "The shape must be 1D.";
return false;
}
if (tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
LOGS(logger, VERBOSE) << "The shape element data type must be INT64.";
return false;
}
const int64_t* shape_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
shape = std::vector<int64_t>{shape_data, shape_data + dims[0]};
return true;

Suggested change
const auto& dims = tensor.dims();
if (dims.empty() || dims[0] == 0) {
LOGS(logger, VERBOSE) << "The shape cannot be empty.";
return false;
}
if (dims.size() != 1) {
LOGS(logger, VERBOSE) << "The shape must be 1D.";
return false;
}
if (tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
LOGS(logger, VERBOSE) << "The shape element data type must be INT64.";
return false;
}
const int64_t* shape_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
shape = std::vector<int64_t>{shape_data, shape_data + dims[0]};
return true;
const auto& dims = tensor.dims();
if (dims.size() != 1) {
LOGS(logger, VERBOSE) << "The shape tensor must be 1D.";
return false;
}
int64_t rank = dims[0];
if (tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
LOGS(logger, VERBOSE) << "The shape element data type must be INT64.";
return false;
}
const int64_t* shape_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
shape.assign(shape_data, shape_data + rank);
return true;
  • If we compare dims size != 1 first, then the empty check is not needed.
  • Dimensions of size 0 are legal in ONNX and represent scalar values. The WebNN EP accepts them too (I have some test cases in SimpleWebNN.html).
  • Can use assign rather than creating a new vector.

Author

Thanks. Please take another look at first new commit.

std::vector<int64_t> split;
const auto& initializers(model_builder.GetInitializerTensors());
const auto& split_tensor = *initializers.at(input_defs[1]->Name());
ORT_RETURN_IF_NOT(GetShapeByTensor(split_tensor, split, logger), "Cannot get split.");

🤔 It's not really reading a tensor "shape" here, but rather a series of split lengths. How about renaming it to something more generic like ReadInt64ArrayFromTensor?


Btw, the DML EP has a similar function here: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L132

The only difference is that it returns the int64 values already cast to int32, which might be useful for you since WebNN takes uint32 shapes; then every caller of this function won't need a separate std::transform.
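A minimal sketch of such a narrowing read helper, with a plain range check standing in for SafeInt (the name and signature are illustrative, not the DML EP's actual function):

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Illustrative stand-in for the proposed helper: read int64 values and
// return them already narrowed to int32, failing loudly on overflow
// (the real code would use SafeInt; a manual range check stands in here).
std::vector<int32_t> ReadAsInt32(const int64_t* data, size_t count) {
  std::vector<int32_t> out;
  out.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    if (data[i] < std::numeric_limits<int32_t>::min() ||
        data[i] > std::numeric_limits<int32_t>::max()) {
      throw std::out_of_range("value does not fit in int32");
    }
    out.push_back(static_cast<int32_t>(data[i]));
  }
  return out;
}
```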

Author

How about renaming it to something more generic like ReadInt64ArrayFromTensor?

Thanks. Please take another look at first new commit. I optimized it as ReadIntArrayFrom1DTensor().

} else {
if (helper.HasAttr("num_outputs")) {
const int64_t num_outputs = helper.Get("num_outputs", 1);
if (input_shape[axis] % num_outputs == 0) {

Potential division by zero. We should verify num_outputs > 0 with ORT_RETURN_IF_NOT.

ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == static_cast<int32_t>(1), "The size of outputs must be equal to 1.");
}
}
for (int64_t i = 0; i < output_array["length"].as<int32_t>(); i++) {

Suggested change
for (int64_t i = 0; i < output_array["length"].as<int32_t>(); i++) {
for (int64_t i = 0, count = output_array["length"].as<int32_t>(); i < count; i++) {

(minor) It's nice to assign the count to a temporary, which allows easy inspection in the debugger and avoids re-evaluating the keyed lookup and as<int32_t>() cast each time.

LOGS(logger, VERBOSE) << "The split must be a constant initializer.";
return false;
}
// Values should be >= 0.Sum of the values must be equal to the dim value at 'axis' specified.

Suggested change
// Values should be >= 0.Sum of the values must be equal to the dim value at 'axis' specified.
// Values should be >= 0. Sum of the values must be equal to the dim value at 'axis' specified.

}
} else {
const auto opset = node.SinceVersion();
if (opset == 18) {

Suggested change
if (opset == 18) {
if (opset >= 18) {

Just to make it slightly more future proof (new versions of ops are often registered that are functionally identical, only adding more data types like bfloat16 or float8...).

@fdwr fdwr left a comment

Oof, that looks like it took hours. Some comments - hopefully that clarifies your question. Thanks Bruce ☺.

@BruceDai
Author

@fdwr I've addressed your comments, please take another look. Thanks.

Comment on lines 63 to 70
if (std::is_same<T, int64_t>::value) {
array.assign(array_data, array_data + rank);
} else if (std::is_same<T, int32_t>::value) {
std::vector<int64_t> raw_array = std::vector<int64_t>{array_data, array_data + rank};
std::transform(raw_array.cbegin(), raw_array.cend(),
std::back_inserter(array),
[](int64_t dim) -> int32_t { return SafeInt<int32_t>(dim); });
}

Suggested change
if (std::is_same<T, int64_t>::value) {
array.assign(array_data, array_data + rank);
} else if (std::is_same<T, int32_t>::value) {
std::vector<int64_t> raw_array = std::vector<int64_t>{array_data, array_data + rank};
std::transform(raw_array.cbegin(), raw_array.cend(),
std::back_inserter(array),
[](int64_t dim) -> int32_t { return SafeInt<int32_t>(dim); });
}
if constexpr (std::is_same<T, int64_t>::value) {
array.assign(array_data, array_data + rank);
} else if constexpr (std::is_same<T, int32_t>::value) {
std::transform(array_data, array_data + rank,
std::back_inserter(array),
[](int64_t dim) -> int32_t { return SafeInt<int32_t>(dim); });
}
  • Since the template parameter is known at compile time, we should make the if statement constexpr.
  • No need to allocate a temporary vector, since we're just returning the final one.
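Assuming a plain static_cast in place of SafeInt, the if constexpr dispatch can be sketched standalone like this (the template name is illustrative, not the actual ORT helper):

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <type_traits>
#include <vector>

// Minimal sketch of the compile-time dispatch suggested above: the branch
// that doesn't match T is discarded at compile time, and no temporary
// vector is allocated (SafeInt replaced by static_cast for brevity).
template <typename T>
std::vector<T> CopyAs(const int64_t* data, size_t count) {
  std::vector<T> out;
  if constexpr (std::is_same_v<T, int64_t>) {
    out.assign(data, data + count);  // no conversion needed
  } else if constexpr (std::is_same_v<T, int32_t>) {
    std::transform(data, data + count, std::back_inserter(out),
                   [](int64_t v) { return static_cast<int32_t>(v); });
  }
  return out;
}
```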

Author

Updated with your suggestions. Thanks!

@fdwr fdwr left a comment

Minor perf comment. Otherwise looks good to me. Thanks for taking this one from me 😁.

@fdwr fdwr left a comment

🙌

@Honry Honry merged commit 291e2a1 into Honry:stable-diffusion May 17, 2023
Honry added a commit that referenced this pull request May 17, 2023
This reverts commit 291e2a1.
{"ConvTranspose", "convTranspose2d"},
{"Concat", "concat"},
{"ArgMax", "argMax1"},
{"ArgMin", "argMin1"},
Owner

Oops, I merged too fast, please remove this debug code in a follow-up. Thanks!

Author

Fixed with new PR #8, PTAL, thanks.


Ooh, I wondered about these, but didn't realize it was debug code.

Owner

@Honry Honry May 17, 2023

That's a trick used for falling back ops to the CPU EP.

Honry pushed a commit that referenced this pull request Aug 28, 2023
### Description
Release OrtEnv before the main function returns. Before this change, OrtEnv
is deleted when the C/C++ runtime destructs all global variables in ONNX
Runtime's core framework.
The callstack is like this:
```
  * frame #0: 0x00007fffee39f5a6 libonnxruntime.so.1.16.0`onnxruntime::Environment::~Environment(this=0x00007fffee39fbf2) at environment.h:20:7
    frame #1: 0x00007fffee39f614 libonnxruntime.so.1.16.0`std::default_delete<onnxruntime::Environment>::operator()(this=0x00007ffff4c30e50, __ptr=0x0000000005404b00) const at unique_ptr.h:85:2
    frame #2: 0x00007fffee39edca libonnxruntime.so.1.16.0`std::unique_ptr<onnxruntime::Environment, std::default_delete<onnxruntime::Environment>>::~unique_ptr(this=0x5404b00) at unique_ptr.h:361:17
    frame #3: 0x00007fffee39e2ab libonnxruntime.so.1.16.0`OrtEnv::~OrtEnv(this=0x00007ffff4c30e50) at ort_env.cc:43:1
    frame #4: 0x00007fffee39fa96 libonnxruntime.so.1.16.0`std::default_delete<OrtEnv>::operator()(this=0x00007fffefff8f78, __ptr=0x00007ffff4c30e50) const at unique_ptr.h:85:2
    frame #5: 0x00007fffee39f394 libonnxruntime.so.1.16.0`std::unique_ptr<OrtEnv, std::default_delete<OrtEnv>>::~unique_ptr(this=0x7ffff4c30e50) at unique_ptr.h:361:17
    frame #6: 0x00007ffff78574b5 libc.so.6`__run_exit_handlers + 261
    frame #7: 0x00007ffff7857630 libc.so.6`exit + 32
    frame #8: 0x00007ffff783feb7 libc.so.6`__libc_start_call_main + 135
    frame #9: 0x00007ffff783ff60 libc.so.6`__libc_start_main@@GLIBC_2.34 + 128
    frame #10: 0x0000000000abbdee node`_start + 46
```
After this change, OrtEnv will be deleted before the main function
returns, while Node.js is still alive.