[java] Sparse tensor support #10653

Conversation
```java
 * <p>Sparse tensors support a variety of formats, and the {@link #getValue} method returns a
 * different static inner class representing each type.
 */
public final class OnnxSparseTensor extends OnnxTensorLike {
```
The Onnx... classes map onto concepts in ONNX, and the Ort... classes map onto ONNX Runtime-specific concepts. I admit the line is a bit fuzzy here, because ONNX doesn't define sparse tensor formats, but it does define generic sparse tensors as a thing.
```java
/** Abstract base class for Java sparse tensors */
private abstract static class BaseSparseTensor<T extends Buffer> implements SparseTensor<T> {
```
Just curious, why do we need both the base class (which is not fully used, IMHO) and the interface? Can the base class serve as the interface? It seems we get a huge hierarchy just to review: OnnxValue -> OnnxTensorLike -> OnnxSparseTensor (essentially an OnnxValue anyway), which then contains the SparseTensor interface -> BaseSparseTensor -> format-specific tensors, which include common things that should be in a base class.

Then OnnxSparseTensor contains two separate things: a SparseTensor, and a native handle that lives its own life. The issue I see here as a coder is that I have an OnnxSparseTensor which is AutoCloseable. I can get the SparseTensor from it, and the OnnxSparseTensor can disappear from underneath it. Whereas if SparseTensor were AutoCloseable, that would never have a chance of happening.
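To make the concern concrete, here is a minimal sketch of the lifetime hazard being described. The names `Owner` and `View` are hypothetical stand-ins for `OnnxSparseTensor` and `SparseTensor`; the real classes differ, this only illustrates the escape pattern.

```java
// Hypothetical stand-ins: Owner plays the role of OnnxSparseTensor (AutoCloseable,
// owning native resources), View plays the role of a plain-Java view handed out by it.
final class Owner implements AutoCloseable {
    private java.nio.ByteBuffer data = java.nio.ByteBuffer.allocate(4).putInt(0, 42);
    private boolean closed = false;

    // Hands out a view; nothing ties the view's lifetime to the owner's.
    View getView() {
        return new View(this);
    }

    int read() {
        if (closed) {
            throw new IllegalStateException("use after close");
        }
        return data.getInt(0);
    }

    @Override
    public void close() {
        closed = true;
        data = null;
    }
}

final class View {
    private final Owner owner;
    View(Owner owner) { this.owner = owner; }
    int value() { return owner.read(); } // fails if the owner was closed
}

public class LifetimeHazard {
    // Returns true when the escaped view breaks after its owner is closed.
    public static boolean viewOutlivesOwner() {
        View escaped;
        try (Owner o = new Owner()) {
            escaped = o.getView();
        }
        // The owner is closed here, but the view still exists.
        try {
            escaped.value();
            return false;
        } catch (IllegalStateException e) {
            return true; // the owner disappeared from underneath the view
        }
    }

    public static void main(String[] args) {
        System.out.println(viewOutlivesOwner()); // prints "true"
    }
}
```

If the view type itself were AutoCloseable (or pinned the owner), this escape could not invalidate it silently.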
The interface + abstract base class is a fairly standard idiom in Java, though I guess it doesn't offer much flexibility here given I intend to seal the interface so it only has the provided implementations. I can refactor it into a public abstract class.
SparseTensor and OnnxSparseTensor represent different objects with different uses. The SparseTensor is purely Java-side; it holds no references to ORT objects or data, and is used as the way to prepare data in Java for supplying into ORT, or for working on the output if a sparse tensor is produced from a run call. SparseTensor and its implementations could easily live outside the OnnxSparseTensor class, and if Java had a package similar to scipy which provided sparse tensors it wouldn't exist at all; we'd just depend on the external definitions. OnnxSparseTensor holds the references to the ORT objects and memory, but the data can only be accessed from Java by making a SparseTensor and copying the data out, because there aren't any direct accessors into the ORT sparse tensor object. The OnnxSparseTensor has a reference to the buffers that came from the SparseTensor to make sure they live longer than the OnnxSparseTensor, though I'm not actually sure now if it's copying the memory into the OrtValue or referencing it.

OnnxTensorLike is a made-up class which exists because both OnnxTensor and OnnxSparseTensor can be inputs to a run call, but OnnxSequence and OnnxMap cannot. So I need something in the type hierarchy that lets me distinguish them statically.
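A simplified, hypothetical sketch of how such a shared supertype restricts run inputs at compile time. The class bodies and the `run` signature below are illustrative only, not the actual API:

```java
import java.util.Map;

// Simplified, hypothetical hierarchy mirroring the one described above:
// only OnnxTensorLike subclasses are statically valid run() inputs.
abstract class OnnxValue implements AutoCloseable {
    @Override public void close() {}
}

abstract class OnnxTensorLike extends OnnxValue {}

final class OnnxTensor extends OnnxTensorLike {}
final class OnnxSparseTensor extends OnnxTensorLike {}
final class OnnxSequence extends OnnxValue {}

public class RunSignature {
    // Accepts dense and sparse tensors; sequences and maps are rejected at compile time.
    static int run(Map<String, ? extends OnnxTensorLike> inputs) {
        return inputs.size();
    }

    public static void main(String[] args) {
        int n = run(Map.of("dense", new OnnxTensor(), "sparse", new OnnxSparseTensor()));
        System.out.println(n); // prints "2"
        // run(Map.of("seq", new OnnxSequence())); // would not compile
    }
}
```

The point is that the restriction lives in the type system rather than in a runtime check inside the run call.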
Oops, you're right: I'm missing the copy in OnnxSparseTensor.getDataBuffer(), OnnxSparseTensor.getIndicesBuffer() and OnnxSparseTensor.getInnerIndicesBuffer(). I'll fix that.
I've fixed the get methods so they copy the buffer. Once this has merged in I plan to add methods to both OnnxSparseTensor and OnnxTensor which let users copy the buffer into one they supply, and hopefully that will let me refactor and reduce the get logic a little (by pushing it into OnnxTensorLike).
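The kind of defensive copy being discussed can be sketched like this. The helper name `copyOf` is illustrative, not the actual method used in the PR:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class BufferCopy {
    // Illustrative helper: return an independent copy of the readable region of src.
    // duplicate() shares content but has its own position/limit, so src is untouched;
    // the byte order must be re-applied, since duplicate() resets it to BIG_ENDIAN.
    static ByteBuffer copyOf(ByteBuffer src) {
        ByteBuffer dup = src.duplicate().order(src.order());
        ByteBuffer out = ByteBuffer.allocate(dup.remaining()).order(dup.order());
        out.put(dup);
        out.flip();
        return out;
    }

    public static void main(String[] args) {
        ByteBuffer original = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN);
        original.putInt(7).putInt(9).flip();
        ByteBuffer copy = copyOf(original);
        copy.putInt(0, 1234); // mutating the copy...
        System.out.println(original.getInt(0)); // ...leaves the original at 7
    }
}
```

Returning a copy means callers can never alias the tensor's backing memory through the get methods.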
```c
OrtTensorTypeAndShapeInfo* info;
checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape((OrtValue*) handle, indicesFormat, &info));
```
All the JNI code needs revising, as the exception isn't necessarily propagated until the next call back into the JVM. So checkOrtStatus needs to be modified to return something that can be checked, and all the places where it is called need extra control flow to allow an early return, potentially several layers deep. I've not done that in this PR as it's going to be a large and confusing change which I'd like to keep separate from this one.

I think in most cases a failure in an ORT call will cause the remainder of the calls in that method to also fail (e.g. by returning a null pointer), so in practice not much else happens in that method when a call fails. However, that doesn't guarantee that we free any memory allocated before the call failed, which will make the abort control flow much trickier.
I think we are going to segfault if the OrtTensorTypeAndShapeInfo is a stray pointer that was not properly populated because something errored out. The only thing we can do is check for nullptr, which we do not do, since the prerequisite is a valid instance. And if it is not nullptr then we can only assume it is valid. The same applies to all the other APIs.
Ok, I'll work on a separate PR for the native fixes.
One question: does this code still copy memory for input/output? We should not copy, at least for input, and I believe there is a way to prevent that on output as well.
It does not copy on input if the buffers are direct. If they aren't direct then they need to be copied into direct ones, which are passed in to the OrtValue. For outputs we could wrap the bare pointers in direct byte buffers, but if those buffers are exposed out of the
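The input-side logic described here (zero-copy for direct buffers, copy-to-direct otherwise) can be sketched as follows. The helper name `ensureDirect` is hypothetical, not the PR's actual code:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class DirectInput {
    // Illustrative helper: pass a direct buffer through unchanged (zero-copy path);
    // otherwise copy the heap buffer's readable region into a fresh direct buffer.
    static ByteBuffer ensureDirect(ByteBuffer b) {
        if (b.isDirect()) {
            return b;
        }
        ByteBuffer direct = ByteBuffer.allocateDirect(b.remaining()).order(b.order());
        direct.put(b.duplicate().order(b.order()));
        direct.flip();
        return direct;
    }

    public static void main(String[] args) {
        ByteBuffer heap = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder());
        heap.putFloat(1.5f).flip();
        ByteBuffer direct = ensureDirect(heap);
        System.out.println(direct.isDirect());              // prints "true"
        System.out.println(ensureDirect(direct) == direct); // prints "true" (no copy)
    }
}
```

Only direct buffers have a stable native address that can be handed to JNI without copying, which is why heap buffers take the copy path.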
Well, we are doing just fine in C#, holding on to the native pointers received from ORT and destroying them when IDisposable is invoked, similar to AutoCloseable.
The direct byte buffer abstraction in Java is different from the native pointer abstraction in C#, and I think there is not a good solution here for outputs that has a zero-copy output path and also exposes the buffers to users. We could alternatively expose direct access into the OrtValue by implementing chunks of the buffer API on top of

Pre-allocation could work, but that would require a rewrite of all the output processing in both the Java and the native code.

I'd missed the update to
/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux CPU x64 NoContribops CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline

/azp run MacOS NoContribops CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows WebAssembly CI Pipeline, orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

Azure Pipelines successfully started running 7 pipeline(s).

Azure Pipelines successfully started running 7 pipeline(s).
I fixed the build error. Clang must do different exhaustiveness checking on enums than gcc or msvc, as the latter two complained I wasn't assigning a value in a switch which covered all the enum cases. I added a default branch which lines up with

Yes, it throws, and it does so soon. The API docs and the parameter annotations still need work. This feature has been available for a long time. Let's fix error reporting first, and then we can work on copy elimination. Both are very important: not so much right now for SparseTensors, but for dense Tensors in the first place.
/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux CPU x64 NoContribops CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline

/azp run MacOS NoContribops CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows WebAssembly CI Pipeline, orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

Azure Pipelines successfully started running 7 pipeline(s).

Azure Pipelines successfully started running 7 pipeline(s).

/azp run onnxruntime-binary-size-checks-ci-pipeline, onnxruntime-python-checks-ci-pipeline, ONNX Runtime Web CI Pipeline

Azure Pipelines successfully started running 3 pipeline(s).
Can we get this integrated now that the 1.11 release has happened? I can rebase it on master if necessary; then it'll be easier to work on the native binding improvements that we discussed.

I do not mind if we rebase it and merge it. What is your plan with regards to the error reporting work?

Once this has been merged in I'll start making smaller PRs for the different JNI files so it's not as unpleasant to review. The initial one will be bigger because it'll have the error handling changes in
Force-pushed from 339f78f to 9f8f33b.
Done.
/azp run MacOS CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-python-checks-ci-pipeline

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline

Azure Pipelines successfully started running 6 pipeline(s).

Azure Pipelines successfully started running 5 pipeline(s).

I wish
/azp run MacOS CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-python-checks-ci-pipeline

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline

/azp run orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

Azure Pipelines successfully started running 6 pipeline(s).

Azure Pipelines successfully started running 5 pipeline(s).

Azure Pipelines successfully started running 4 pipeline(s).

/azp run onnxruntime-binary-size-checks-ci-pipeline

Azure Pipelines successfully started running 1 pipeline(s).
**Description**: Adds support for creating and receiving sparse tensors in the ORT Java API. CSRC and COO tensors are tested as inputs, but there is no op which accepts a block sparse tensor to test. COO tensors are tested as outputs, but there is no op which emits a CSRC or block sparse tensor to test.

**Motivation and Context**: Request to expose ORT sparse tensor support in Java.

cc @yuslepukhin
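For readers unfamiliar with the COO layout this PR tests, here is a small illustration of the components a COO sparse tensor is built from: an index matrix (one row of coordinates per non-zero) and a parallel values array. This sketches the data layout only; it does not show the actual OnnxSparseTensor construction calls.

```java
import java.util.ArrayList;
import java.util.List;

public class CooLayout {
    // Illustrative: extract COO index pairs from a dense 2-D matrix; zeros are omitted.
    static List<long[]> cooIndices(float[][] dense) {
        List<long[]> indices = new ArrayList<>();
        for (int i = 0; i < dense.length; i++) {
            for (int j = 0; j < dense[i].length; j++) {
                if (dense[i][j] != 0f) {
                    indices.add(new long[] {i, j});
                }
            }
        }
        return indices;
    }

    // Illustrative: extract the values parallel to cooIndices, in row-major order.
    static List<Float> cooValues(float[][] dense) {
        List<Float> values = new ArrayList<>();
        for (float[] row : dense) {
            for (float v : row) {
                if (v != 0f) {
                    values.add(v);
                }
            }
        }
        return values;
    }

    public static void main(String[] args) {
        float[][] dense = {{0f, 2f, 0f}, {3f, 0f, 0f}};
        System.out.println(cooValues(dense));            // prints "[2.0, 3.0]"
        System.out.println(cooIndices(dense).get(0)[1]); // prints "1"
    }
}
```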