[Kernel C API] Implementation of variable ops RFC. #49717
copybara-service[bot] merged 11 commits into tensorflow:master
Conversation
tensorflow/c/kernels.cc
Outdated
    return tf_tensor;
  }

  tensorflow::Status EnsureSparseVariableAccess(TF_OpKernelContext* ctx,
I am moving these to the c_api_experimental file (I ran into some build issues). I wanted to get the PR started in order to get feedback.
tensorflow/c/kernels.cc
Outdated
      var->tensor()->shape(), &tmp, attr));
  tensorflow::Status s;
  TF_Tensor *tf_tmp = TF_TensorFromTensor(tmp, &s);
  TF_Tensor *tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
I think TF_DeleteTensor() is needed at the end of the function; TF_TensorFromTensor allocates a new TF_Tensor struct, which will leak if TF_DeleteTensor() is never invoked.
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
*status = tensorflow::Status::OK();
if (!src.IsInitialized()) {
*status = FailedPrecondition(
"attempt to use a tensor with an uninitialized value");
return nullptr;
}
if (src.NumElements() == 0) {
return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
}
if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) {
*status = InvalidArgument(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
src.shape().DebugString(),
"). Please file a bug at "
"https://github.com/tensorflow/tensorflow/issues/new, "
"ideally with a "
"short code snippet that reproduces this error.");
return nullptr;
}
const string str =
src.scalar<tensorflow::ResourceHandle>()().SerializeAsString();
TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
std::memcpy(TF_TensorData(t), str.c_str(), str.size());
return t;
}
Tensor tensor;
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return new TF_Tensor{new tensorflow::TensorInterface(std::move(tensor))};
}
tensorflow/c/kernels.cc
Outdated
      context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
  tensorflow::Status s;
  TF_Tensor *tf_tmp = TF_TensorFromTensor(tmp, &s);
  TF_Tensor *tf_tensor = TF_TensorFromTensor(*tensor, &s);
Do we need to check the two statuses?
Maybe we can return the status back to the caller of PrepareUpdateVariable().
fc79d57 to cc6a3f1
tensorflow/c/kernels.h
Outdated
  //     &total_size)).
  TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape(
      TF_OpKernelConstruction* ctx,
      const char* attr_name, int64_t* values,
nit: Rename values to dims and max_vals to num_dims to be consistent with TF_GraphGetTensorShape.
tensorflow/c/kernels.cc
Outdated
    return tf_tensor;
  }

  tensorflow::Status EnsureSparseVariableAccess(TF_OpKernelContext* ctx,
Is it possible to share code with the existing impl for EnsureSparseVariableAccess in training_op_helpers.h? Same comment for other helpers below.
Thanks @saxenasaurabh for the review. I have used the helper functions where possible, like LookupResource. With EnsureSparseVariableAccess, there is a Device dependency, which we are passing in as copy functors. I do agree there is duplicated code that could be avoided. One possibility is to refactor the core helper functions to remove this dependency and then adopt them here. Maybe we can do that as a follow-up cleanup, since it would require invasive changes in the core that could break things. What do you think?
Sounds good. Please try to clean this up as a follow-up. That would give test coverage for free as well.
Makes sense, will do.
tensorflow/c/kernels.cc
Outdated
  tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
  TF_Tensor *tf_tmp = TF_TensorFromTensor(tmp, &s);
  TF_Tensor *tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
  copyFunc(ctx, tf_tensor, tf_tmp);
We expect the plugin to call TF_DeleteTensor in copyFunc. This is more in line with the rest of TF, where the tensors are released in Compute.
tensorflow/c/kernels.h
Outdated
      TF_OpKernelContext* ctx,
      int input,
      bool lock_held,
      bool isVariantType,
Do we need to expose isVariantType in the API, or can it be inferred from the tensor? I believe this is equivalent to TF_TensorType(tensor) == TF_VARIANT.
@saxenasaurabh thanks! (Sorry for the delay.) You are right, we can probably skip isVariantType in the API. Currently our metal-plugin release uses this API. If it's not too much of a concern, can we keep it this way?
  std::set<tensorflow::string> colocation_constraints;
  };

  struct TF_VariableInputLockHolder {
Could this be in c_api_experimental as well?
The implementation of the Variable Ops RFC:
https://github.com/tensorflow/community/blob/master/rfcs/20210504-kernel-extension-variable-ops.md
@penpornk , @reedwm , @saxenasaurabh , @jzhoulon