Description
🚀 Feature
Tweak Java API before 1.3 release
Motivation
We can update this later, but it's easier to do it now.
Module
No proposed changes.
IValue
Proposed change: consolidate factory methods.
These are currently named by the types they construct: bool, float64, etc. I propose consolidating these into:
- IValue.from for bool, double64, long64, string, tensor
- IValue.listFrom for boolList, doubleList, longList, tensorList
- IValue.tupleFrom
- IValue.dictLongKeyFrom
- IValue.dictStringKeyFrom
- IValue.optionalNull
Benefits:
- Fewer method names to remember.
- Feels more "torchy" to me.
Drawbacks:
- Unnecessary overloading can cause confusion.
- Some method names get longer (like "tupleFrom").
Alternatives:
- Could keep "tupleFrom" named "tuple". Same for dictLongKeyFrom.
- optionalNull could be replaced by IValue.from().
- listFrom could be replaced by more overloads of from, but I think it would be confusing, especially if we continued to allow varargs (IValue.from(1, 2) suddenly becomes a different type if you remove the second arg).
- Unfortunately, I don't think it is possible to unify dictLongKey and dictStringKey while keeping type safety.
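To make the overloading trade-off concrete, here is a rough sketch of the consolidated factories. `IValueSketch`, its `Kind` enum, and its fields are hypothetical stand-ins made up for illustration, not the real IValue class, and only a few of the proposed overloads are shown.

```java
import java.util.Arrays;

// Hypothetical stand-in for IValue; names and fields are assumptions for illustration.
final class IValueSketch {
    enum Kind { BOOL, LONG, DOUBLE, STRING, LONG_LIST, NULL }

    final Kind kind;
    final Object payload;

    private IValueSketch(Kind kind, Object payload) {
        this.kind = kind;
        this.payload = payload;
    }

    // One name for scalars, overloaded on the Java argument type.
    static IValueSketch from(boolean v) { return new IValueSketch(Kind.BOOL, v); }
    static IValueSketch from(long v)    { return new IValueSketch(Kind.LONG, v); }
    static IValueSketch from(double v)  { return new IValueSketch(Kind.DOUBLE, v); }
    static IValueSketch from(String v)  { return new IValueSketch(Kind.STRING, v); }

    // Lists keep their own name, so a varargs call never collides with a scalar overload.
    static IValueSketch listFrom(long... vs) {
        return new IValueSketch(Kind.LONG_LIST, Arrays.copyOf(vs, vs.length));
    }

    static IValueSketch optionalNull() { return new IValueSketch(Kind.NULL, null); }
}
```

With this split, `IValueSketch.listFrom(1, 2)` and `IValueSketch.listFrom(1)` both produce a list, while `IValueSketch.from(1)` always produces a scalar. Folding `listFrom` into a varargs `from` would make the one-argument call ambiguous to the reader, which is the confusion described above.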
Proposed non-change: getters.
These are currently getBool, getTensor, etc. The C++ API uses "toBool", "toTensor", etc.
Benefits:
- Getters are a very strong convention in Java.
- Kotlin allows these to be written as myValue.tensor, which is convenient.
Drawbacks:
- Divergence from C++ API.
Tensor
Proposed change: Consolidate factory methods.
These are currently newFloat32Tensor(shape, data), newInt64Tensor(shape, data), etc. I propose consolidating all of these except the uint8 one into Tensor.fromBlob(data, shape).
Benefits:
- Matches C++ API.
Drawbacks:
- The current API matches numpy.ndarray, which puts shape before data.
- Unnecessary overloading can cause confusion.
Alternatives:
- Just leave it as-is?
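A minimal sketch of what the consolidated factory could look like. `TensorSketch` is a hypothetical stand-in, not the real Tensor class. The string dtype label, the size check, and the int32 overload are assumptions added only to keep the example self-contained.

```java
import java.util.Arrays;

// Hypothetical stand-in for Tensor; not the real class.
final class TensorSketch {
    final String dtype;   // a plain label, only to keep the sketch small
    final Object data;
    final long[] shape;

    private TensorSketch(String dtype, Object data, int length, long[] shape) {
        // Assumed validation: the shape must account for every element.
        long expected = 1;
        for (long dim : shape) {
            expected *= dim;
        }
        if (expected != length) {
            throw new IllegalArgumentException(
                "shape " + Arrays.toString(shape) + " does not match " + length + " elements");
        }
        this.dtype = dtype;
        this.data = data;
        this.shape = shape.clone();
    }

    // One name for every element type; the Java array type selects the dtype.
    static TensorSketch fromBlob(float[] data, long[] shape) {
        return new TensorSketch("float32", data.clone(), data.length, shape);
    }

    static TensorSketch fromBlob(int[] data, long[] shape) {
        return new TensorSketch("int32", data.clone(), data.length, shape);
    }

    static TensorSketch fromBlob(long[] data, long[] shape) {
        return new TensorSketch("int64", data.clone(), data.length, shape);
    }
}
```

A call such as `TensorSketch.fromBlob(new float[] {1f, 2f, 3f, 4f}, new long[] {2, 2})` puts data before shape, which is the reverse of the current newFloat32Tensor(shape, data) order.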
Proposed change: UInt8 factory method
Java doesn't have unsigned types, so we can't use overloading to distinguish int8 from uint8. I propose creating Tensor.fromBlobUnsigned for uint8.
Benefits:
- Pretty simple.
Drawbacks:
- Doesn't match the C++ API, which specifies the dtype as the last parameter. (TBH, I think this is a minor deficiency in the C++ API, which could infer the dtype from the pointer type.)
Alternatives:
- Specify dtype as third argument?
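The reason a separate name is needed is that both int8 and uint8 data arrive as a Java byte[], so the overload cannot tell them apart. A small sketch of the idea, where `ByteTensorSketch` and its `get` accessor are hypothetical and exist only for illustration:

```java
// Hypothetical stand-in showing why uint8 needs its own factory name.
final class ByteTensorSketch {
    final boolean unsigned;
    final byte[] data;
    final long[] shape;

    private ByteTensorSketch(boolean unsigned, byte[] data, long[] shape) {
        this.unsigned = unsigned;
        this.data = data.clone();
        this.shape = shape.clone();
    }

    // byte[] is interpreted as signed int8 by default, matching Java's byte.
    static ByteTensorSketch fromBlob(byte[] data, long[] shape) {
        return new ByteTensorSketch(false, data, shape);
    }

    // Same argument types, different name: the bytes are interpreted as uint8.
    static ByteTensorSketch fromBlobUnsigned(byte[] data, long[] shape) {
        return new ByteTensorSketch(true, data, shape);
    }

    // Reads element i as an int, honoring the signedness chosen at construction.
    int get(int i) {
        return unsigned ? (data[i] & 0xFF) : data[i];
    }
}
```

The same raw byte `(byte) 200` reads back as -56 through `fromBlob` and as 200 through `fromBlobUnsigned`.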
Proposed non-change: data accessors
These are t.getDataAsFloatArray(), etc. They're verbose, but they all have to have distinct names because of the distinct return types. In C++, this would be t.data_ptr<float>(), but that's not possible in Java.
I can't think of anything obviously better. The "Array" in the name is because we might expose another format later (like java.nio.Buffer).
Alternatives:
- We could use overloading like float[] f = t.dataPtr((float)0), which matches the C++ API more closely, but is just gross.
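For comparison, here are both styles side by side in a rough sketch. `AccessorSketch` is hypothetical, and the type check and the copy on read are assumptions, not a statement about how the real accessors behave.

```java
// Hypothetical stand-in comparing the two accessor styles.
final class AccessorSketch {
    private final Object data;   // a float[] or a long[] in this sketch

    AccessorSketch(Object data) {
        this.data = data;
    }

    // Distinct names: Java cannot overload on return type alone.
    float[] getDataAsFloatArray() {
        if (!(data instanceof float[])) {
            throw new IllegalStateException("tensor does not hold float data");
        }
        return ((float[]) data).clone();
    }

    long[] getDataAsLongArray() {
        if (!(data instanceof long[])) {
            throw new IllegalStateException("tensor does not hold long data");
        }
        return ((long[]) data).clone();
    }

    // The alternative: a dummy argument whose only job is to select the overload.
    float[] dataPtr(float tag) {
        return getDataAsFloatArray();
    }

    long[] dataPtr(long tag) {
        return getDataAsLongArray();
    }
}
```

Note that in the dummy-argument style a bare `t.dataPtr(0)` silently resolves to the long overload, because an int literal widens to long before float. That is one more reason to prefer the verbose names.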
Proposed change: dtypes
Let's make these an enum instead of ints.
Benefits:
- Better type safety.
Drawbacks:
- Slightly larger code (though I think ProGuard and Redex can eliminate this).
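Roughly what the enum version could look like. `DTypeSketch`, the integer codes, and `fromCode` are made up for this sketch. The real constants and any native-side codes may differ.

```java
// Hypothetical dtype enum; the codes are arbitrary values chosen for illustration.
enum DTypeSketch {
    UINT8(1),
    INT8(2),
    INT64(3),
    FLOAT32(4);

    // Integer code, e.g. for passing across a native boundary (assumed).
    final int code;

    DTypeSketch(int code) {
        this.code = code;
    }

    // Maps a raw int back to the enum, rejecting values that are not a dtype.
    static DTypeSketch fromCode(int code) {
        for (DTypeSketch d : values()) {
            if (d.code == code) {
                return d;
            }
        }
        throw new IllegalArgumentException("unknown dtype code: " + code);
    }
}
```

A method that takes a `DTypeSketch` can no longer be handed an arbitrary int, which is the type-safety benefit. Any int that does come in from outside has to pass through `fromCode` first.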
Proposed change: shape
This is currently a mutable array. Let's change it to a method that returns a fresh copy.
Benefits:
- Don't need to worry about people mutating it.
Drawbacks:
- Slight runtime overhead.
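A minimal sketch of the defensive-copy approach, using a hypothetical `ShapeSketch` class in place of the real Tensor:

```java
// Hypothetical stand-in showing shape exposed through a copying method.
final class ShapeSketch {
    private final long[] shape;

    ShapeSketch(long... shape) {
        // Copy on the way in, so the caller's array cannot alias our state.
        this.shape = shape.clone();
    }

    // Copy on the way out: each call allocates a fresh array (the runtime overhead).
    long[] shape() {
        return shape.clone();
    }
}
```

Callers can scribble on the returned array freely. A later call to `shape()` still returns the original dimensions.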