
[java] Sparse tensor support#10653

Merged
yuslepukhin merged 18 commits into microsoft:main from Craigacp:sparse-tensor on Nov 22, 2022

Conversation

@Craigacp Craigacp (Contributor) commented Feb 24, 2022

Description:

Adds support for creating and receiving sparse tensors in the ORT Java API.

CSRC and COO tensors as inputs are tested, but there is no op which accepts a block sparse tensor to test. COO tensors are tested as outputs, but there is no op which emits a CSRC or block sparse tensor to test.
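To make the COO format concrete, here is a minimal sketch of the buffer layout a sparse tensor API of this kind consumes: a dense shape, an index buffer of nnz * ndims coordinates, and a value buffer of nnz entries. The class and method names below are illustrative stand-ins, not the ORT Java API itself.

```java
import java.nio.FloatBuffer;
import java.nio.LongBuffer;

// COO layout sketch for a 3x3 matrix with two non-zero entries:
// (0,1) -> 2.0f and (2,0) -> 5.0f. Indices are stored row-major as
// flattened (row, col) pairs; values align with the index pairs.
public class CooLayoutSketch {
    static final long[] DENSE_SHAPE = {3, 3};
    static final LongBuffer INDICES = LongBuffer.wrap(new long[] {0, 1, 2, 0});
    static final FloatBuffer VALUES = FloatBuffer.wrap(new float[] {2.0f, 5.0f});

    // Number of non-zero values stored.
    static int nnz() {
        return VALUES.remaining();
    }

    // Number of index entries: nnz * number of dimensions.
    static int indexEntries() {
        return INDICES.remaining();
    }
}
```

In the real API these buffers would be handed to the OnnxSparseTensor creation path described in this PR.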

Motivation and Context

  • Why is this change required? What problem does it solve? Request to expose ORT sparse tensor support in Java.

cc @yuslepukhin

* <p>Sparse tensors support a variety of formats, and the {@link #getValue} method returns a
* different static inner class representing each type.
*/
public final class OnnxSparseTensor extends OnnxTensorLike {
@yuslepukhin yuslepukhin (Member) Feb 25, 2022

OnnxSparseTensor

Just a remark. Onnx is a standard. Onnxruntime is an implementation. Technically, we deal with an ONNX Runtime implementation of tensors. #Resolved

@Craigacp Craigacp (Contributor, Author)

The Onnx... classes map onto concepts in ONNX, and the Ort... classes map onto ONNX Runtime specific concepts. I admit the line is a bit fuzzy here because ONNX doesn't define sparse tensor formats, but it does define generic sparse tensors as a thing.

}

/** Abstract base class for Java sparse tensors */
private abstract static class BaseSparseTensor<T extends Buffer> implements SparseTensor<T> {
@yuslepukhin yuslepukhin (Member) Feb 25, 2022

BaseSparseTensor

Just curious: why do we need both the base class (which is not fully used, IMHO) and the interface? Can the base class serve as the interface?
It seems we get a huge hierarchy just to review: OnnxValue -> OnnxTensorLike -> OnnxSparseTensor (essentially an OnnxValue anyway), which then contains the SparseTensor interface -> BaseSparseTensor -> format-specific tensors, which include common things that should be in a base class.
Then OnnxSparseTensor contains two separate things: a SparseTensor, and a native handle living its own life. The issue I see here as a coder is that I have an OnnxSparseTensor which is AutoCloseable. I can get the SparseTensor handle and the OnnxSparseTensor can disappear from underneath it. Whereas if SparseTensor were AutoCloseable, that would never have a chance of happening.

@Craigacp Craigacp (Contributor, Author)

The interface + abstract base class is a fairly standard idiom in Java, though I guess it doesn't offer much flexibility here given I intend to seal the interface so it only has the provided implementations. I can refactor it into a public abstract class.
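The idiom being discussed can be sketched as follows: the interface fixes the public contract, the abstract base class factors out state shared by every format, and each format is a small final subclass. The names below are simplified stand-ins, not the classes from this PR.

```java
import java.nio.Buffer;
import java.nio.FloatBuffer;

// Public contract: every sparse view exposes a shape and a value buffer.
interface SparseView<T extends Buffer> {
    long[] getShape();
    T getValues();
}

// Shared state and behavior live in the abstract base class.
abstract class AbstractSparseView<T extends Buffer> implements SparseView<T> {
    private final long[] shape;

    AbstractSparseView(long[] shape) {
        this.shape = shape.clone(); // defensive copy shared by all subclasses
    }

    @Override
    public long[] getShape() {
        return shape.clone();
    }
}

// One format-specific implementation; others (CSR, block sparse) would
// follow the same pattern with their extra index buffers.
final class CooFloatView extends AbstractSparseView<FloatBuffer> {
    private final FloatBuffer values;

    CooFloatView(long[] shape, FloatBuffer values) {
        super(shape);
        this.values = values;
    }

    @Override
    public FloatBuffer getValues() {
        return values;
    }
}
```

Collapsing the interface into a public abstract class, as proposed, removes one layer while keeping the final format-specific subclasses.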

@Craigacp Craigacp (Contributor, Author)

SparseTensor and OnnxSparseTensor represent different objects with different uses. SparseTensor is purely Java-side: it holds no references to ORT objects or data, and it is used to prepare data in Java for supplying into ORT, or to work on the output if a sparse tensor is produced from a run call. SparseTensor and its implementations could easily live outside the OnnxSparseTensor class, and if Java had a package similar to scipy which provided sparse tensors it wouldn't exist at all; we'd just depend on the external definitions. OnnxSparseTensor holds the references to the ORT objects and memory, but the data can only be accessed from Java by making a SparseTensor and copying the data out, because there aren't any direct accessors into the ORT sparse tensor object. The OnnxSparseTensor keeps a reference to the buffers that came from the SparseTensor to make sure they live longer than the OnnxSparseTensor, though I'm not actually sure now if it's copying the memory into the OrtValue or referencing it.

OnnxTensorLike is a made-up class which exists because both OnnxTensor and OnnxSparseTensor can be inputs to a run call, but OnnxSequence and OnnxMap cannot. So I need something in the type hierarchy that lets me distinguish them statically.
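The static distinction described here can be sketched with stand-in classes (these mirror, but are not, the real OnnxValue/OnnxTensorLike hierarchy): a shared supertype for tensor-like values lets a run method reject sequences and maps at compile time rather than at runtime.

```java
// Stand-in for OnnxValue: every value is closeable.
abstract class ValueSketch implements AutoCloseable {
    @Override
    public void close() {}
}

// Stand-in for OnnxTensorLike: the marker supertype for valid run inputs.
abstract class TensorLikeSketch extends ValueSketch {}

final class TensorSketch extends TensorLikeSketch {}

final class SparseSketch extends TensorLikeSketch {}

// Sequences (and maps) are values, but not tensor-like.
final class SequenceSketch extends ValueSketch {}

final class SessionSketch {
    // Passing a SequenceSketch here would be a compile-time error,
    // which is exactly the static guarantee the hierarchy provides.
    static boolean accepts(TensorLikeSketch input) {
        return input != null;
    }
}
```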

@Craigacp Craigacp (Contributor, Author) Feb 25, 2022

Oops, you're right, I'm missing the copy in OnnxSparseTensor.getDataBuffer(), OnnxSparseTensor.getIndicesBuffer(), and OnnxSparseTensor.getInnerIndicesBuffer(). I'll fix that.

@Craigacp Craigacp (Contributor, Author)

I've fixed the get methods so they copy the buffer. Once this has merged in I plan to add methods to both OnnxSparseTensor and OnnxTensor which let users copy the buffer into one they supply, and hopefully that will let me refactor and reduce the get logic a little (by pushing it into OnnxTensorLike).
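The planned "copy into a buffer the user supplies" accessor could look something like the sketch below; the method name and shape are hypothetical, but the key point stands: the caller owns the destination, so its lifetime is decoupled from the native tensor's.

```java
import java.nio.ByteBuffer;

// Sketch of a caller-supplied-buffer copy. The destination must have
// enough space remaining; the source's read position is left untouched.
final class BufferCopySketch {
    static void copyTo(ByteBuffer src, ByteBuffer dst) {
        if (dst.remaining() < src.remaining()) {
            throw new IllegalArgumentException("destination buffer too small");
        }
        // duplicate() shares the data but not the position/limit, so
        // draining the duplicate does not advance the original source.
        dst.put(src.duplicate());
    }
}
```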

}

OrtTensorTypeAndShapeInfo* info;
checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape((OrtValue*) handle, indicesFormat, &info));
@yuslepukhin (Member)

checkOrtStatus

Please remind me how checkOrtStatus works. It records a Java exception, but otherwise happily continues calling the native API. Should this somehow stop on error, since it may be feeding incorrect values onward?

@Craigacp Craigacp (Contributor, Author) Feb 25, 2022

All the JNI code needs revising as the exception isn't necessarily propagated until the next call back into the JVM. So checkOrtStatus needs to be modified to return something that can be checked and all the places where it is called need to be modified with extra control flow to allow an early return, potentially several layers deep. I've not done that in this PR as it's going to be a large and confusing change which I'd like to keep separate from this one.

I think in most cases that a failure in an ORT call will cause the remainder of the calls in that method to also fail (e.g. by returning a null pointer), and so in practice not much else happens in that method when a call fails. However that doesn't guarantee that we free any memory allocated before the call failed, which will make the abort control flow much trickier.
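The proposed change can be sketched as follows: a status check that returns whether the call succeeded, so callers can early-return instead of continuing with potentially invalid pointers. `OrtStatusStub` and the function names are stand-ins for illustration, not the real JNI code.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Stand-in for OrtStatus*: code 0 means success. */
typedef struct {
    int code;
} OrtStatusStub;

/* Sketch of a checkOrtStatus that reports success/failure to the caller.
 * The real version would also record a Java exception via JNIEnv. */
static bool checkStatus(const OrtStatusStub *status) {
    if (status != NULL && status->code != 0) {
        return false;
    }
    return true;
}

/* Sketch of the early-return control flow: stop after the first failed
 * native call rather than feeding a stray pointer to the next one. */
static int queryShape(const OrtStatusStub *first, const OrtStatusStub *second) {
    if (!checkStatus(first)) {
        return -1; /* abort: 'info' was never populated */
    }
    if (!checkStatus(second)) {
        return -1; /* any already-allocated resources must be freed here */
    }
    return 0;
}
```

As noted above, the hard part is not the check itself but freeing memory allocated before the failing call on every abort path.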

@yuslepukhin yuslepukhin (Member) Feb 25, 2022

I think we are going to SIGSEGV if OrtTensorTypeAndShapeInfo is some stray pointer that was not properly populated because something errored out. The only thing we can do is check for nullptr, which we do not do since the prerequisite is a valid instance. And if it is not nullptr then we can only assume it is valid. The same applies to all the other APIs.

@Craigacp (Contributor, Author)

Ok, I'll work on a separate PR for the native fixes.

@yuslepukhin (Member)

One question. Does this code still copy memory for input/output? We should not copy at least for input and if so I believe there is a way to prevent that on output as well.

@Craigacp (Contributor, Author)

It does not copy on input if the buffers are direct. If they aren't direct then they need to be copied into direct ones, which are passed in to the OrtValue. Then the OnnxSparseTensor holds a reference to the buffers so they don't get garbage collected.

For outputs we could wrap the bare pointers in direct byte buffers, but if those buffers are exposed out of the OnnxSparseTensor then they need to not live past the lifetime of the sparse tensor, and that's tricky to enforce. Currently they copy the direct buffer into another buffer before wrapping the copies in the COOTensor/CSRCTensor/BlockSparseTensor.
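The input-side policy described above (zero copy for direct buffers, copy-to-direct otherwise) can be sketched in a few lines; this is an illustration of the rule, not the actual OnnxSparseTensor code.

```java
import java.nio.ByteBuffer;

// Direct buffers are passed through untouched, since native code can
// address their memory directly (zero copy). Heap buffers are copied
// into a fresh direct buffer first.
final class DirectInputSketch {
    static ByteBuffer ensureDirect(ByteBuffer input) {
        if (input.isDirect()) {
            return input;
        }
        ByteBuffer direct = ByteBuffer.allocateDirect(input.remaining());
        direct.put(input.duplicate()); // duplicate keeps the caller's position intact
        direct.flip();                 // rewind so native code reads from the start
        return direct;
    }
}
```

The returned buffer is what would be handed to the OrtValue; the Java side then holds a reference to it so it outlives the native tensor and is not garbage collected prematurely.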

@yuslepukhin yuslepukhin (Member) commented Feb 28, 2022

Well, we are doing just fine in C#, holding on to the native pointers received from ORT and destroying them when IDisposable is invoked, similar to AutoCloseable.
If that does not work, here is an idea: if the output shapes are known, pre-allocate output buffers, create OrtValues, and pass them as outputs with the corresponding name mappings. You still need to hold on to the OrtValue pointers and destroy them on close.
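The pre-allocation idea amounts to the following shape: the caller allocates named output buffers up front and the runtime writes inference results into them, so no output copy is needed. `fillOutputs` below is a hypothetical stand-in for a session run call, not a real ORT API.

```java
import java.nio.FloatBuffer;
import java.util.Map;

// Sketch of caller-owned output memory: the map carries output names to
// pre-sized buffers, and the "run" fills them in place.
final class PreallocatedOutputSketch {
    static void fillOutputs(Map<String, FloatBuffer> outputs) {
        for (FloatBuffer buffer : outputs.values()) {
            while (buffer.hasRemaining()) {
                buffer.put(1.0f); // stand-in for real inference results
            }
            buffer.flip(); // ready for the caller to read
        }
    }
}
```

The lifetime question then becomes simple: the caller allocated the buffers, so the caller's close (or try-with-resources) decides when they die.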


In reply to: 1054582742

@Craigacp Craigacp (Contributor, Author) commented Feb 28, 2022

The direct byte buffer abstraction in Java is different from the native pointer abstraction in C#, and I think that there is not a good solution here for outputs that has a zero copy output path and that also exposes the buffers to users. We could alternatively expose direct access into the OrtValue by implementing chunks of the buffer API on top of OnnxSparseTensor (or in private lifetime limited objects inside OnnxSparseTensor), but that's a substantial development effort. If you want to use a sparse tensor output from one ORT session and pass it into another ORT session then that's zero copy, it's only if you want to get the values back out into Java that it triggers a copy. When the Java foreign memory API is completed (and ORT is updated to use that minimum Java version, which might involve forking the Java API from the Android one) then there will be better options here for managing the lifetimes of these chunks of memory.

Pre-allocation could work, but that would require a rewrite of all the output processing in both the Java and the native code. I'd missed the update to Run which switched the output pointer over so it was _Inout_, that seems to have changed after the Java API was initially integrated. Presumably it throws an exception if one of the supplied OrtValues is the wrong size or shape?


In reply to: 1054714133

@yuslepukhin (Member)

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux CPU x64 NoContribops CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline

@yuslepukhin (Member)

/azp run MacOS NoContribops CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows WebAssembly CI Pipeline, orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 7 pipeline(s).


@Craigacp Craigacp (Contributor, Author) commented Mar 1, 2022

I fixed the build error. Clang must do different exhaustiveness checking on enums than gcc or msvc, as the latter two complained that I wasn't assigning a value in a switch which covered all the enum cases. I added a default branch which lines up with ORT_SPARSE_UNDEFINED, which should make them happy. Plus I rebased after the merge of #10670.
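The fix described here has this shape: a switch over the sparse-format enum gains a default branch mapping to the UNDEFINED value, so compilers that cannot prove the switch exhaustive still see the result assigned on every path. The enum values below are local stand-ins mirroring OrtSparseFormat, not the real header.

```c
#include <assert.h>

/* Stand-in mirroring the OrtSparseFormat-style values. */
typedef enum {
    SPARSE_UNDEFINED = 0,
    SPARSE_COO = 1,
    SPARSE_CSRC = 2,
    SPARSE_BLOCK_SPARSE = 4
} SparseFormatStub;

/* Converts an integer from the native layer into the enum. The default
 * branch guarantees definite assignment even for unexpected inputs,
 * which satisfies gcc and msvc's flow analysis. */
static SparseFormatStub fromInt(int value) {
    switch (value) {
        case 1: return SPARSE_COO;
        case 2: return SPARSE_CSRC;
        case 4: return SPARSE_BLOCK_SPARSE;
        default: return SPARSE_UNDEFINED;
    }
}
```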

@yuslepukhin yuslepukhin (Member) commented Mar 1, 2022

Pre-allocation could work, but that would require a rewrite of all the output processing in both the Java and the native code. I'd missed the update to Run which switched the output pointer over so it was _Inout_, that seems to have changed after the Java API was initially integrated. Presumably it throws an exception if one of the supplied OrtValues is the wrong size or shape?

Yes, it throws, and it does so early. The API docs and the parameter annotations still need work. This feature has been available for a long time. Let's fix error reporting first and then we can work on copy elimination. Both are very important; not so much right now for sparse tensors, but for dense tensors in the first place.

@yuslepukhin (Member)

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux CPU x64 NoContribops CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline, MacOS CI Pipeline

@yuslepukhin (Member)

/azp run MacOS NoContribops CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows WebAssembly CI Pipeline, orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 7 pipeline(s).


@yuslepukhin (Member)

/azp run onnxruntime-binary-size-checks-ci-pipeline, onnxruntime-python-checks-ci-pipeline, ONNX Runtime Web CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 3 pipeline(s).

yuslepukhin previously approved these changes Mar 2, 2022
@yuslepukhin yuslepukhin dismissed their stale review March 10, 2022 19:22

We would like to postpone it

@Craigacp (Contributor, Author)

Can we get this integrated now the 1.11 release has happened? I can rebase it on master if necessary, then it'll be easier to work on the native binding improvements that we discussed.

@yuslepukhin (Member)

Can we get this integrated now the 1.11 release has happened? I can rebase it on master if necessary, then it'll be easier to work on the native binding improvements that we discussed.

I do not mind if we rebase it and merge it.

What is your plan with regards to error reporting work?

@Craigacp (Contributor, Author)

Can we get this integrated now the 1.11 release has happened? I can rebase it on master if necessary, then it'll be easier to work on the native binding improvements that we discussed.

I do not mind if we rebase it and merge it.

What is your plan with regards to error reporting work?

Once this has been merged in I'll start making smaller PRs for the different JNI files so it's not as unpleasant to review. The initial one will be bigger because it'll have the error handling changes in OrtJniUtil.c and things called from the session run method. All the rest should be smaller and simpler.

@Craigacp (Contributor, Author)

Done.

@yuslepukhin (Member)

/azp run MacOS CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-python-checks-ci-pipeline

@yuslepukhin (Member)

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 5 pipeline(s).

@Craigacp (Contributor, Author)

I wish clang had the same rules for deciding definite assignment through a switch as gcc and msvc. Clang is definitely right, in that the variable was always assigned, but someone needs to explain that to gcc. I've fixed the things gcc complained about, so you'll need to kick off the pipelines again.

@yuslepukhin (Member)

/azp run MacOS CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, ONNX Runtime Web CI Pipeline, onnxruntime-python-checks-ci-pipeline

@yuslepukhin (Member)

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux Nuphar CI Pipeline, Linux OpenVINO CI Pipeline

@yuslepukhin (Member)

/azp run orttraining-amd-gpu-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 6 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 5 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@yuslepukhin (Member)

/azp run onnxruntime-binary-size-checks-ci-pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@yuslepukhin yuslepukhin merged commit dd2c031 into microsoft:main Nov 22, 2022
@Craigacp Craigacp deleted the sparse-tensor branch November 29, 2022 14:55
simon-moo pushed a commit to simon-moo/onnxruntime that referenced this pull request Dec 26, 2022
