-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[android] Fix error messages; tensor creation method names with type #26219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
dzhulgakov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, more nitpicks of naming - I think it's better to follow pytorch naming convention that java's
| final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); | ||
| final IValue input = | ||
| IValue.tensor(Tensor.newTensor(new long[] {1}, Tensor.allocateByteBuffer(1))); | ||
| IValue.tensor(Tensor.newByteTensor(new long[] {1}, Tensor.allocateByteBuffer(1))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe use uint8 to be consistent with PyTorch's API? (torch.uint8/int8)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
java byte is signed,so byteBuffer for tensor content can be filled with signed bytes.
To support uint8 we need some workaround, I think with method names to force that it is unsigned.
I think that uint8 and int8 naming wil be unnatural for java world...But maybe as we expect users be familiar with pytorch it should be ok.
@dreiss , @ljk53 , what do you think about uint8, int8 naming in android java api ?
| inputTensorData[i] = i; | ||
| } | ||
| final Tensor inputTensor = Tensor.newTensor(inputTensorShape, inputTensorData); | ||
| final Tensor inputTensor = Tensor.newFloatTensor(inputTensorShape, inputTensorData); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similarly - maybe Float32?
| } | ||
|
|
||
| public static Tensor newTensor(long[] shape, double[] data) { | ||
| public static Tensor newDoubleTensor(long[] shape, double[] data) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Float64?
| public String toString() { | ||
| return String.format( | ||
| "Tensor_double64{shape:%s numel:%d}", Arrays.toString(shape), data.capacity()); | ||
| return String.format("Tensor(%s, dtype=torch.double)", Arrays.toString(shape)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.double is an alias for float64:
>>> torch.double
torch.float64
291c54c to
724ca95
Compare
|
@dzhulgakov |
724ca95 to
3c6d86d
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@IvanKobzarev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@IvanKobzarev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@IvanKobzarev merged this pull request in b07991f. |
Summary: At the moment it includes #26219 changes. That PR is landing at the moment, afterwards this PR will contain only javadocs. Applied all dreiss comments from previous version. Pull Request resolved: #26149 Differential Revision: D17490720 Pulled By: IvanKobzarev fbshipit-source-id: f340dee660d5ffe40c96b43af9312c09f85a000b
After offline discussion with @dzhulgakov :
In future we will introduce creation of byte signed and byte unsigned dtype tensors, but java has only signed byte - we will have to add some separation for it in method names ( java types and tensor types can not be clearly mapped) => Returning type in method names
fixes in error messages
non-static method Tensor.numel()
Change Tensor toString() to be more consistent with python
Update on Sep 16:
Type renaming on java side to uint8, int8, int32, float32, int64, float64