[API compatibility] add the param name for paddle.Tensor.copy_#74768
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
| "sure you are on the right way. " | ||
| "The expected arguments as follow: (" | ||
| "other, non_blocking)")); | ||
| PADDLE_ENFORCE_EQ( |
| common::errors::PreconditionNotMet( | ||
| "Must provide the `other: Tensor` params for paddle.Tensor.copy_")); | ||
|
|
||
| paddle::Tensor& src_tensor = CastPyArg2Tensor(other, 0); |
| EAGER_TRY | ||
| paddle::Tensor& src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0); | ||
| PyObject* other = nullptr; | ||
| bool blocking = false; |
There was a problem hiding this comment.
之前使用paddle并指定blocking的代码会挂掉,这个要兼容的话 就是在blocking后面加1个non_blocking参数,只有torch的non_blocking未指定关键字时被误当做blocking,这个影响应比较小。
other, blocking, non_blocking
There was a problem hiding this comment.
嗯?是指.copy_(a, blocking=True)这个调用会报错吗,如果是的话应该是因为之前的copy_里面只对args解析了,没有对kwargs进行解析所以会报错。如果是.copy_(a, True, non_blocking=True),这种情况,以non_blocking为主吗?
| nullptr}; | ||
| bool flag = PyArg_ParseTupleAndKeywords( | ||
| args, kwargs, "|Obb", kwlist, &other_tensor, &blocking, &non_blocking); | ||
| blocking = !blocking || non_blocking ? false : true; |
There was a problem hiding this comment.
这个blocking设置的逻辑是,默认采用blocking的方式执行,只blocking参数为False或者non_blocking设置为True时,就使用non_blocking的方式执行
|
/re-run all-failed |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #74768 +/- ##
==========================================
Coverage ? 0
==========================================
Files ? 0
Lines ? 0
Branches ? 0
==========================================
Hits ? 0
Misses ? 0
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
PR Category
User Experience
PR Types
New features
Description
对齐paddle.Tensor.copy_与torch.Tensor.copy_。
当前paddle.Tensor.copy_已经存在,并且在功能上与torch.Tensor.copy_对齐,但是在paddle内是基于pybind实现的方法,在参数处理上没有处理kwargs的逻辑并且没有返回值(torch里有返回值),本PR进行了补充与完善。
在PaConvert下自测,paddle.Tensor.copy_与torch.Tenso.copy_对齐。
pcard-71500