Conversation
```
size_t element_size{0};
switch (data_type)
{
  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
emscripten::val view{emscripten::typed_memory_view(size / element_size, reinterpret_cast<const float*>(dest))};
emscripten::val view = emscripten::val::undefined();
emscripten::val desc = emscripten::val::object();
if (element_size == 2) {
```
It's better to use the data type as the variable for allocating the different types of buffer view.
```
#endif
AddOperand(name, operand);
mem_persist_buffers_.push_back(std::move(persist_buffer));
emscripten::val console = emscripten::val::global("console");
```
@huningxin, @fdwr, would you like to take a look at this PR?
js/common/lib/tensor.ts
Outdated
```
string: string;
bool: boolean;
float16: never; // hold on before we have a concret solution for float 16
float16: number; // hold on before we have a concret solution for float 16
```
Suggested change:
```
float16: number; // hold on before we have a concret solution for float 16
float16: number; // Keep using until we have a concrete solution for float16.
```
(minor typo concret)
```
auto data_type = tensor.data_type();
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 &&
    data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
```
It's worth putting this check into a little shared helper, like IsSupportedDataType(), since I see it repeated 4 times.
```
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
  wnn_outputs_.set(output,
                   emscripten::val::global("Float32Array").new_(static_cast<const int>(num_elements)));
```
Should this be static_cast<const float>( instead of int? Line 138 uses float.
I think the num_elements should be used as is, because it is the result of casting to size_t and emscripten::val::global("Float32Array").new_(num_elements) should be fine.
Yeah, that looks better. Granted, I wouldn't be too surprised if Float32Array's constructor actually took a float given all Javascript numbers are evidently float64's anyway 🙃.
```
emscripten::val::global("Float32Array").new_(static_cast<const float>(num_elements)));
break;
default:
  break;
```
I'd bail here on an unsupported data type rather than silently continuing, probably calling ORT_THROW, unless it's really okay to ignore that output. If not here, then there are some other places where throwing makes sense if they fall into the default clause.
I am very glad to see this change to support float16. However, I don't understand how a user can use JS code to deal with f16 input/output. Do they rely on any 3rd-party library to convert a float16 value? It's a similar problem for model output - what does a user expect to do with a given Uint16Array in JS?
@fs-eire: Note this is just into Wanming's private branch for now for demo purposes. I'm not that worried about connecting the data to JS since it can just be an
p.s. Yulong: And once webmachinelearning/webnn#373 |
```
switch (data_type) {
  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
    wnn_inputs_.set(input,
                    emscripten::val::global("Uint16Array").new_(static_cast<const uint16_t>(num_elements)));
```
Does emscripten::val::global("Uint16Array").new_() expect the num_elements in size_t? Because it is already cast in line 129, should it be used as is?
Actually this won't accept size_t in new_(), which throws a binding error; I will cast it to int32_t.
```
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
  wnn_inputs_.set(input,
                  emscripten::val::global("Float32Array").new_(static_cast<const float>(num_elements)));
```
ditto, no cast needed for num_elements.
Agreed. If this function takes an element count, then it should be size_t, not float.
```
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
  wnn_outputs_.set(output,
                   emscripten::val::global("Float32Array").new_(static_cast<const int>(num_elements)));
```
I think the num_elements should be used as is, because it is the result of casting to size_t and emscripten::val::global("Float32Array").new_(num_elements) should be fine.
```
switch (data_type) {
  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
    wnn_outputs_.set(output,
                     emscripten::val::global("Uint16Array").new_(static_cast<const uint16_t>(num_elements)));
```
```
emscripten::val desc = emscripten::val::object();
switch (data_type) {
  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
    view = emscripten::val{emscripten::typed_memory_view(size / 2,
    desc.set("type", emscripten::val("float16"));
    break;
  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
    view = emscripten::val{emscripten::typed_memory_view(size / 4,
```
Yes, I use a 3rd-party library to make the conversion. A simple test for fp16 and fp32 can be found in onnxruntime-web-simpletest. @fs-eire
@huningxin, @fdwr, thanks for your review. @zesongw is OOO these two days; I've addressed your comments, PTAL again! Thanks!
### Description
Release OrtEnv before main function returns. Before this change, OrtEnv
is deleted when C/C++ runtime destructs all global variables in ONNX
Runtime's core framework.
The callstack is like this:
```
* frame #0: 0x00007fffee39f5a6 libonnxruntime.so.1.16.0`onnxruntime::Environment::~Environment(this=0x00007fffee39fbf2) at environment.h:20:7
frame #1: 0x00007fffee39f614 libonnxruntime.so.1.16.0`std::default_delete<onnxruntime::Environment>::operator()(this=0x00007ffff4c30e50, __ptr=0x0000000005404b00) const at unique_ptr.h:85:2
frame #2: 0x00007fffee39edca libonnxruntime.so.1.16.0`std::unique_ptr<onnxruntime::Environment, std::default_delete<onnxruntime::Environment>>::~unique_ptr(this=0x5404b00) at unique_ptr.h:361:17
frame #3: 0x00007fffee39e2ab libonnxruntime.so.1.16.0`OrtEnv::~OrtEnv(this=0x00007ffff4c30e50) at ort_env.cc:43:1
frame #4: 0x00007fffee39fa96 libonnxruntime.so.1.16.0`std::default_delete<OrtEnv>::operator()(this=0x00007fffefff8f78, __ptr=0x00007ffff4c30e50) const at unique_ptr.h:85:2
frame #5: 0x00007fffee39f394 libonnxruntime.so.1.16.0`std::unique_ptr<OrtEnv, std::default_delete<OrtEnv>>::~unique_ptr(this=0x7ffff4c30e50) at unique_ptr.h:361:17
frame #6: 0x00007ffff78574b5 libc.so.6`__run_exit_handlers + 261
frame #7: 0x00007ffff7857630 libc.so.6`exit + 32
frame #8: 0x00007ffff783feb7 libc.so.6`__libc_start_call_main + 135
frame #9: 0x00007ffff783ff60 libc.so.6`__libc_start_main@@GLIBC_2.34 + 128
frame #10: 0x0000000000abbdee node`_start + 46
```
After this change, OrtEnv will be deleted before the main function
returns and nodejs is still alive.
### Description
Hello, we (@lixing-star) are the developers of the Loongson team. We add 128-bit (LSX) and 256-bit (LASX) vector optimization code for the LoongArch architecture.

[100% tests passed, 0 tests failed out of 7](https://cloud.a-boat.cn:2021/api/public/dl/6831z1Bi?inline=true)

### Development Environment
```
CPU: Loongson-3C5000L
uname -a: Linux localhost.localdomain 4.19.190-6.4.lns8.loongarch64 #1 SMP Thu Jul 14 12:08:04 CST 2022 loongarch64 loongarch64 loongarch64 GNU/Linux
```

### LoongArch Documents
- [LoongArch Reference Manual - Volume 1: Basic Architecture: This manual describes the basic part of the LoongArch architecture.](https://loongson.github.io/LoongArch-Documentation/LoongArch-Vol1-EN.html)
- [LoongArch ELF psABI: This manual describes the LoongArch ELF psABI.](https://loongson.github.io/LoongArch-Documentation/LoongArch-ELF-ABI-EN.html)
- [more](https://loongson.github.io/LoongArch-Documentation/README-EN.html)
### Description
Add [Lean Attention](https://arxiv.org/abs/2405.10480) and its integration with the MultiHeadAttention operator for LLMs on GPU. LeanAttention speeds up self-attention for the token-generation phase (decode phase) of decoder-only transformer models, especially on long context lengths.

- [x] Initial implementation of Lean Attention (by Srikant Bharadwaj)
- [x] Integration with MultiHeadAttention operator
- [x] Add parity tests
- [x] Add benchmark

#### Implementation Details
1. Lean Attention is enabled in the build for Linux, and disabled for Windows.
2. Lean Attention is disabled by default. Enable it through the CUDA provider option sdpa_kernel, or use the environment variable `ORT_ENABLE_LEAN_ATTENTION=1`.
3. It only works for token generation (sequence_length == 1, past_sequence_length > 0).
4. Like flash attention, it only works on Ampere or newer GPUs.

We can revisit (1) and (2) after comparing with the DecoderMaskedMultiHeadAttention and XQA kernels.

#### Benchmark
```
cd onnxruntime/test/python/transformers
/bin/bash benchmark_mha.sh lean
```

Example outputs on H100: Note that past and present do not share a buffer for MHA for now, so we can see low TFLOPS. The relative ratio will change after buffer sharing is enabled, but we expect the ordering (kernel A faster than kernel B) to remain the same. The common settings `sequence_length=1; causal=True; attn_bias=None; cuda_graph=False` are not shown in the table below.
| batch_size | past_sequence_length | num_heads | head_size | average_latency | tflops | kernel |
| -- | -- | -- | -- | -- | -- | -- |
| 1 | 512 | 16 | 64 | 0.000059 | 0.0178 | ort:flash |
| 1 | 512 | 16 | 64 | 0.000068 | 0.0155 | ort:efficient |
| 1 | 512 | 16 | 64 | 0.000065 | 0.0161 | ort:math |
| 1 | 512 | 16 | 64 | 0.000060 | 0.0176 | ort:lean |
| 1 | 512 | 32 | 128 | 0.000062 | 0.0674 | ort:flash |
| 1 | 512 | 32 | 128 | 0.000064 | 0.0661 | ort:efficient |
| 1 | 512 | 32 | 128 | 0.000067 | 0.0625 | ort:math |
| 1 | 512 | 32 | 128 | 0.000062 | 0.0678 | ort:lean |
| 1 | 1024 | 16 | 64 | 0.000061 | 0.0345 | ort:flash |
| 1 | 1024 | 16 | 64 | 0.000086 | 0.0244 | ort:efficient |
| 1 | 1024 | 16 | 64 | 0.000065 | 0.0322 | ort:math |
| 1 | 1024 | 16 | 64 | 0.000063 | 0.0332 | ort:lean |
| 1 | 1024 | 32 | 128 | 0.000075 | 0.1125 | ort:flash |
| 1 | 1024 | 32 | 128 | 0.000088 | 0.0951 | ort:efficient |
| 1 | 1024 | 32 | 128 | 0.000079 | 0.1068 | ort:math |
| 1 | 1024 | 32 | 128 | 0.000072 | 0.1171 | ort:lean |
| 1 | 2048 | 16 | 64 | 0.000069 | 0.0606 | ort:flash |
| 1 | 2048 | 16 | 64 | 0.000125 | 0.0336 | ort:efficient |
| 1 | 2048 | 16 | 64 | 0.000064 | 0.0655 | ort:lean |
| 1 | 2048 | 32 | 128 | 0.000098 | 0.1720 | ort:flash |
| 1 | 2048 | 32 | 128 | 0.000132 | 0.1270 | ort:efficient |
| 1 | 2048 | 32 | 128 | 0.000092 | 0.1828 | ort:lean |
| 1 | 4096 | 16 | 64 | 0.000076 | 0.1097 | ort:flash |
| 1 | 4096 | 16 | 64 | 0.000207 | 0.0406 | ort:efficient |
| 1 | 4096 | 16 | 64 | 0.000069 | 0.1209 | ort:lean |
| 1 | 4096 | 32 | 128 | 0.000140 | 0.2394 | ort:flash |
| 1 | 4096 | 32 | 128 | 0.000213 | 0.1575 | ort:efficient |
| 1 | 4096 | 32 | 128 | 0.000139 | 0.2419 | ort:lean |
| 1 | 8192 | 16 | 64 | 0.000104 | 0.1609 | ort:flash |
| 1 | 8192 | 16 | 64 | 0.000392 | 0.0428 | ort:efficient |
| 1 | 8192 | 16 | 64 | 0.000093 | 0.1809 | ort:lean |
| 1 | 8192 | 32 | 128 | 0.000212 | 0.3160 | ort:flash |
| 1 | 8192 | 32 | 128 | 0.000360 | 0.1866 | ort:efficient |
| 1 | 8192 | 32 | 128 | 0.000212 | 0.3162 | ort:lean |
| 1 | 16384 | 16 | 64 | 0.000139 | 0.2410 | ort:flash |
| 1 | 16384 | 16 | 64 | 0.000731 | 0.0459 | ort:efficient |
| 1 | 16384 | 16 | 64 | 0.000136 | 0.2465 | ort:lean |
| 1 | 16384 | 32 | 128 | 0.000361 | 0.3722 | ort:flash |
| 1 | 16384 | 32 | 128 | 0.000667 | 0.2014 | ort:efficient |
| 1 | 16384 | 32 | 128 | 0.000357 | 0.3765 | ort:lean |
| 1 | 32768 | 16 | 64 | 0.000210 | 0.3194 | ort:flash |
| 1 | 32768 | 16 | 64 | 0.001428 | 0.0470 | ort:efficient |
| 1 | 32768 | 16 | 64 | 0.000209 | 0.3211 | ort:lean |
| 1 | 32768 | 32 | 128 | 0.000659 | 0.4074 | ort:flash |
| 1 | 32768 | 32 | 128 | 0.001270 | 0.2114 | ort:efficient |
| 1 | 32768 | 32 | 128 | 0.000651 | 0.4123 | ort:lean |
| 1 | 65536 | 16 | 64 | 0.000355 | 0.3785 | ort:flash |
| 1 | 65536 | 16 | 64 | 0.002736 | 0.0491 | ort:efficient |
| 1 | 65536 | 16 | 64 | 0.000349 | 0.3845 | ort:lean |
| 1 | 65536 | 32 | 128 | 0.001251 | 0.4290 | ort:flash |
| 1 | 65536 | 32 | 128 | 0.002480 | 0.2165 | ort:efficient |
| 1 | 65536 | 32 | 128 | 0.001239 | 0.4333 | ort:lean |
| 4 | 512 | 16 | 64 | 0.000063 | 0.0665 | ort:flash |
| 4 | 512 | 16 | 64 | 0.000069 | 0.0607 | ort:efficient |
| 4 | 512 | 16 | 64 | 0.000066 | 0.0634 | ort:math |
| 4 | 512 | 16 | 64 | 0.000062 | 0.0674 | ort:lean |
| 4 | 512 | 32 | 128 | 0.000100 | 0.1677 | ort:flash |
| 4 | 512 | 32 | 128 | 0.000099 | 0.1703 | ort:efficient |
| 4 | 512 | 32 | 128 | 0.000108 | 0.1557 | ort:math |
| 4 | 512 | 32 | 128 | 0.000092 | 0.1818 | ort:lean |
| 4 | 1024 | 16 | 64 | 0.000077 | 0.1094 | ort:flash |
| 4 | 1024 | 16 | 64 | 0.000099 | 0.0850 | ort:efficient |
| 4 | 1024 | 16 | 64 | 0.000081 | 0.1038 | ort:math |
| 4 | 1024 | 16 | 64 | 0.000072 | 0.1161 | ort:lean |
| 4 | 1024 | 32 | 128 | 0.000143 | 0.2343 | ort:flash |
| 4 | 1024 | 32 | 128 | 0.000137 | 0.2447 | ort:efficient |
| 4 | 1024 | 32 | 128 | 0.000150 | 0.2245 | ort:math |
| 4 | 1024 | 32 | 128 | 0.000135 | 0.2496 | ort:lean |
| 4 | 2048 | 16 | 64 | 0.000096 | 0.1757 | ort:flash |
| 4 | 2048 | 16 | 64 | 0.000156 | 0.1078 | ort:efficient |
| 4 | 2048 | 16 | 64 | 0.000089 | 0.1892 | ort:lean |
| 4 | 2048 | 32 | 128 | 0.000223 | 0.3010 | ort:flash |
| 4 | 2048 | 32 | 128 | 0.000217 | 0.3101 | ort:efficient |
| 4 | 2048 | 32 | 128 | 0.000209 | 0.3209 | ort:lean |
| 4 | 4096 | 16 | 64 | 0.000137 | 0.2448 | ort:flash |
| 4 | 4096 | 16 | 64 | 0.000256 | 0.1312 | ort:efficient |
| 4 | 4096 | 16 | 64 | 0.000133 | 0.2530 | ort:lean |
| 4 | 4096 | 32 | 128 | 0.000389 | 0.3450 | ort:flash |
| 4 | 4096 | 32 | 128 | 0.000376 | 0.3574 | ort:efficient |
| 4 | 4096 | 32 | 128 | 0.000354 | 0.3794 | ort:lean |
| 4 | 8192 | 16 | 64 | 0.000210 | 0.3198 | ort:flash |
| 4 | 8192 | 16 | 64 | 0.000453 | 0.1480 | ort:efficient |
| 4 | 8192 | 16 | 64 | 0.000206 | 0.3260 | ort:lean |
| 4 | 8192 | 32 | 128 | 0.000725 | 0.3705 | ort:flash |
| 4 | 8192 | 32 | 128 | 0.000693 | 0.3874 | ort:efficient |
| 4 | 8192 | 32 | 128 | 0.000653 | 0.4114 | ort:lean |
| 4 | 16384 | 16 | 64 | 0.000355 | 0.3782 | ort:flash |
| 4 | 16384 | 16 | 64 | 0.000849 | 0.1581 | ort:efficient |
| 4 | 16384 | 16 | 64 | 0.000346 | 0.3874 | ort:lean |
| 4 | 16384 | 32 | 128 | 0.001395 | 0.3848 | ort:flash |
| 4 | 16384 | 32 | 128 | 0.001337 | 0.4017 | ort:efficient |
| 4 | 16384 | 32 | 128 | 0.001252 | 0.4288 | ort:lean |
| 4 | 32768 | 16 | 64 | 0.000647 | 0.4146 | ort:flash |
| 4 | 32768 | 16 | 64 | 0.001649 | 0.1628 | ort:efficient |
| 4 | 32768 | 16 | 64 | 0.000639 | 0.4204 | ort:lean |
| 4 | 32768 | 32 | 128 | 0.002721 | 0.3947 | ort:flash |
| 4 | 32768 | 32 | 128 | 0.002601 | 0.4128 | ort:efficient |
| 4 | 32768 | 32 | 128 | 0.002434 | 0.4411 | ort:lean |
| 4 | 65536 | 16 | 64 | 0.001231 | 0.4361 | ort:flash |
| 4 | 65536 | 16 | 64 | 0.003238 | 0.1658 | ort:efficient |
| 4 | 65536 | 16 | 64 | 0.001217 | 0.4412 | ort:lean |
| 4 | 65536 | 32 | 128 | 0.005357 | 0.4009 | ort:flash |
| 4 | 65536 | 32 | 128 | 0.005118 | 0.4196 | ort:efficient |
| 4 | 65536 | 32 | 128 | 0.004781 | 0.4492 | ort:lean |
| 16 | 512 | 16 | 64 | 0.000098 | 0.1724 | ort:flash |
| 16 | 512 | 16 | 64 | 0.000104 | 0.1616 | ort:efficient |
| 16 | 512 | 16 | 64 | 0.000118 | 0.1420 | ort:math |
| 16 | 512 | 16 | 64 | 0.000087 | 0.1926 | ort:lean |
| 16 | 512 | 32 | 128 | 0.000220 | 0.3062 | ort:flash |
| 16 | 512 | 32 | 128 | 0.000208 | 0.3237 | ort:efficient |
| 16 | 512 | 32 | 128 | 0.000237 | 0.2838 | ort:math |
| 16 | 512 | 32 | 128 | 0.000209 | 0.3216 | ort:lean |
| 16 | 1024 | 16 | 64 | 0.000136 | 0.2465 | ort:flash |
| 16 | 1024 | 16 | 64 | 0.000150 | 0.2235 | ort:efficient |
| 16 | 1024 | 16 | 64 | 0.000148 | 0.2266 | ort:math |
| 16 | 1024 | 16 | 64 | 0.000129 | 0.2611 | ort:lean |
| 16 | 1024 | 32 | 128 | 0.000367 | 0.3663 | ort:flash |
| 16 | 1024 | 32 | 128 | 0.000351 | 0.3829 | ort:efficient |
| 16 | 1024 | 32 | 128 | 0.000400 | 0.3357 | ort:math |
| 16 | 1024 | 32 | 128 | 0.000349 | 0.3853 | ort:lean |
| 16 | 2048 | 16 | 64 | 0.000209 | 0.3206 | ort:flash |
| 16 | 2048 | 16 | 64 | 0.000243 | 0.2762 | ort:efficient |
| 16 | 2048 | 16 | 64 | 0.000201 | 0.3338 | ort:lean |
| 16 | 2048 | 32 | 128 | 0.000671 | 0.4002 | ort:flash |
| 16 | 2048 | 32 | 128 | 0.000645 | 0.4163 | ort:efficient |
| 16 | 2048 | 32 | 128 | 0.000642 | 0.4185 | ort:lean |
| 16 | 4096 | 16 | 64 | 0.000360 | 0.3732 | ort:flash |
| 16 | 4096 | 16 | 64 | 0.000425 | 0.3162 | ort:efficient |
| 16 | 4096 | 16 | 64 | 0.000341 | 0.3933 | ort:lean |
| 16 | 4096 | 32 | 128 | 0.001292 | 0.4156 | ort:flash |
| 16 | 4096 | 32 | 128 | 0.001251 | 0.4291 | ort:efficient |
| 16 | 4096 | 32 | 128 | 0.001241 | 0.4327 | ort:lean |
| 16 | 8192 | 16 | 64 | 0.000666 | 0.4030 | ort:flash |
| 16 | 8192 | 16 | 64 | 0.000804 | 0.3339 | ort:efficient |
| 16 | 8192 | 16 | 64 | 0.000627 | 0.4283 | ort:lean |
| 16 | 8192 | 32 | 128 | 0.002541 | 0.4226 | ort:flash |
| 16 | 8192 | 32 | 128 | 0.002454 | 0.4376 | ort:efficient |
| 16 | 8192 | 32 | 128 | 0.002438 | 0.4405 | ort:lean |
| 16 | 16384 | 16 | 64 | 0.001292 | 0.4156 | ort:flash |
| 16 | 16384 | 16 | 64 | 0.001571 | 0.3417 | ort:efficient |
| 16 | 16384 | 16 | 64 | 0.001217 | 0.4411 | ort:lean |
| 16 | 16384 | 32 | 128 | 0.005042 | 0.4260 | ort:flash |
| 16 | 16384 | 32 | 128 | 0.004859 | 0.4420 | ort:efficient |
| 16 | 16384 | 32 | 128 | 0.004827 | 0.4449 | ort:lean |
| 16 | 32768 | 16 | 64 | 0.002537 | 0.4233 | ort:flash |
| 16 | 32768 | 16 | 64 | 0.003103 | 0.3461 | ort:efficient |
| 16 | 32768 | 16 | 64 | 0.002385 | 0.4501 | ort:lean |
| 16 | 32768 | 32 | 128 | 0.009961 | 0.4312 | ort:flash |
| 16 | 32768 | 32 | 128 | 0.009605 | 0.4472 | ort:efficient |
| 16 | 32768 | 32 | 128 | 0.009524 | 0.4510 | ort:lean |
| 16 | 65536 | 16 | 64 | 0.005019 | 0.4279 | ort:flash |
| 16 | 65536 | 16 | 64 | 0.006133 | 0.3502 | ort:efficient |
| 16 | 65536 | 16 | 64 | 0.004703 | 0.4566 | ort:lean |
| 16 | 65536 | 32 | 128 | 0.019746 | 0.4350 | ort:flash |
| 16 | 65536 | 32 | 128 | 0.019027 | 0.4515 | ort:efficient |
| 16 | 65536 | 32 | 128 | 0.018864 | 0.4554 | ort:lean |

### Motivation and Context
Enable float16: use Uint16Array in JS to pass float16 data to C++ and allocate the corresponding memory.