2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -126,7 +126,7 @@ PirInterpreter::PirInterpreter(const platform::Place& place,

std::stringstream ss;
ss << this
<< std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
<< std::chrono::high_resolution_clock::now().time_since_epoch().count();
BuildScope(*ir_block_, ss.str(), value_exe_info_.get());
}

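The one-line change above swaps a whole-second timestamp (`system_clock::to_time_t`) for a nanosecond-resolution tick count when building the suffix of the interpreter's scope name. Presumably this prevents two `PirInterpreter` instances built within the same second (possibly at a reused `this` address) from producing the same scope name. A minimal Python sketch of the failure mode, illustrative only and not Paddle code:

```python
import time

# At whole-second resolution, back-to-back timestamps are identical,
# so a name suffix built from them can repeat across rapid constructions.
t1, t2 = int(time.time()), int(time.time())
print(t1 == t2)  # almost always True

# A high-resolution tick count, the analogue of
# std::chrono::high_resolution_clock::now().time_since_epoch().count(),
# differs even between consecutive calls on sub-second clocks.
n1, n2 = time.time_ns(), time.time_ns()
print(n1 != n2)  # True wherever the clock resolves below one second
```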
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
@@ -50,6 +50,7 @@ extern PyTypeObject* p_string_tensor_type;

extern PyTypeObject* g_framework_scope_pytype;
extern PyTypeObject* g_ir_opresult_pytype;
extern PyTypeObject* g_ir_value_pytype;
extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_data_type_pytype;
extern PyTypeObject* g_place_pytype;
@@ -1521,6 +1522,8 @@ pir::Value CastPyArg2Value(PyObject* obj,
size_t arg_pos) {
if (PyObject_TypeCheck(obj, g_ir_opresult_pytype)) {
return ::pybind11::handle(obj).cast<pir::OpResult>();
} else if (PyObject_TypeCheck(obj, g_ir_value_pytype)) {
return ::pybind11::handle(obj).cast<pir::Value>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -80,6 +80,7 @@ namespace paddle {
namespace pybind {

PyTypeObject *g_ir_opresult_pytype = nullptr;
PyTypeObject *g_ir_value_pytype = nullptr;

void BindOpsAPI(pybind11::module *module);

@@ -410,6 +411,7 @@ void BindValue(py::module *m) {
when build network.

)DOC");
g_ir_value_pytype = reinterpret_cast<PyTypeObject *>(value.ptr());
value
.def(
"get_defining_op",
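Taken together, the two C++ changes register the Python type object behind `pir.Value` (`g_ir_value_pytype`, captured in `BindValue` above) and let `CastPyArg2Value` in `eager_utils.cc` accept it alongside `pir.OpResult`. An illustrative Python analogue of that dispatch, assuming `paddle.pir` exposes both classes as the Python-side changes below suggest:

```python
import paddle

def cast_to_value(obj):
    # Mirrors the C++ type checks: try OpResult first, then plain Value;
    # anything else is rejected with an explicit error, as in the C++ code.
    if isinstance(obj, paddle.pir.OpResult):
        return obj
    if isinstance(obj, paddle.pir.Value):
        return obj
    raise TypeError(
        f"argument must be OpResult or Value, got {type(obj).__name__}"
    )
```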
10 changes: 6 additions & 4 deletions python/paddle/base/executor.py
@@ -21,7 +21,7 @@

import numpy as np

from ..pir import OpResult, translate_to_new_ir
from ..pir import OpResult, Value, translate_to_new_ir
from . import compiler, core, framework, get_flags, set_flags, unique_name
from .data_feeder import convert_dtype
from .framework import (
@@ -513,7 +513,7 @@ def _add_pir_fetch_ops(program, fetch_list, fetch_var_name):
with paddle.static.program_guard(program):
for i, fetch_input in enumerate(fetch_list):
assert isinstance(
fetch_input, OpResult
fetch_input, (OpResult, Value)
), f"Wrong type for fetch_list[{i}]: {type(fetch_input)}"
paddle._pir_ops.fetch(fetch_input, fetch_var_name + str(i), i)

@@ -1956,7 +1956,9 @@ def _run_inference(self, exe, feed):
return exe.run(feed)

def _check_fetch_list(self, fetch_list):
is_fetch_var = lambda var: isinstance(var, (Variable, str, OpResult))
is_fetch_var = lambda var: isinstance(
var, (Variable, str, OpResult, Value)
)
is_tuple_list = lambda var: isinstance(var, (tuple, list))

if fetch_list is None:
@@ -1982,7 +1984,7 @@ def _check_fetch_list(self, fetch_list):
res.append(var)
else:
raise TypeError(
"Require fetch_list[{}] 's type shall be one of (Variable, str), but received {}.".format(
"Require fetch_list[{}] 's type shall be one of (OpResult, str), but received {}.".format(
i, type(var).__name__
)
)
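With `Value` added to the accepted fetch types, a PIR-mode program can fetch any `Value`, not only an `OpResult` produced by an op. A hedged end-to-end sketch (assumes the PIR API is active, for example under the guard used by the `test_with_pir_api` decorator that appears in the tests below):

```python
import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    y = x * 2.0  # under PIR this is an OpResult, which is also a Value

exe = paddle.static.Executor(paddle.CPUPlace())
(res,) = exe.run(
    main,
    feed={"x": np.ones([2, 3], dtype="float32")},
    fetch_list=[y],  # OpResult or Value both pass _check_fetch_list now
)
print(res)
```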
58 changes: 55 additions & 3 deletions python/paddle/tensor/manipulation.py
@@ -2037,6 +2037,7 @@ def split(x, num_or_sections, axis=0, name=None):
>>> print(out2.shape)
[3, 3, 5]
"""

input = x
dim = axis
if in_dynamic_mode():
@@ -2061,15 +2062,32 @@
else:
return _C_ops.split(input, num_or_sections, dim)
elif in_pir_mode():
if isinstance(dim, paddle.pir.OpResult):
dim.stop_gradient = True
if isinstance(dim, int):
assert len(input.shape) + dim >= 0, "(rank(x) + axis) must be >= 0"
dim = (len(input.shape) + dim) if dim < 0 else dim

input_shape = input.shape
if isinstance(num_or_sections, int):
dim = dim if dim >= 0 else dim + len(input.shape)
assert num_or_sections > 0, 'num_or_sections must be greater than 0.'
if isinstance(dim, int) and input_shape[dim] > 0:
assert input_shape[dim] % num_or_sections == 0, (
"The input's size along the split dimension "
"must be evenly divisible by Attr(num_or_sections). "
"But %d is not evenly divisible by %d. "
% (input_shape[dim], num_or_sections)
)
return _C_ops.split_with_num(input, num_or_sections, dim)
else:
dim = dim if dim >= 0 else dim + len(input.shape)
if isinstance(dim, int) and input_shape[dim] > 0:
assert (
len(num_or_sections) <= input_shape[dim]
), 'len(num_or_sections) must not be more than input.shape[dim].'
if paddle.utils._contain_var(num_or_sections):
num_or_sections = paddle.utils.get_int_tensor_list(
num_or_sections
)
return _C_ops.split(input, num_or_sections, dim)

else:
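The new PIR branch normalizes a negative `axis` and then enforces that a statically known split dimension is evenly divisible by an integer `num_or_sections`. A worked example of both rules in plain Python, with illustrative numbers only:

```python
# A tensor of shape [3, 9, 5] split along axis=1 into 3 sections
# requires 9 % 3 == 0; each resulting piece has shape [3, 3, 5].
input_shape, num_or_sections, axis = [3, 9, 5], 3, 1
assert input_shape[axis] % num_or_sections == 0
piece = list(input_shape)
piece[axis] //= num_or_sections
print(piece)  # [3, 3, 5]

# Negative axes are normalized exactly as the branch above does it:
axis = -2
axis = axis + len(input_shape) if axis < 0 else axis
print(axis)  # 1
```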
@@ -3603,7 +3621,21 @@ def expand(x, shape, name=None):
[[1, 2, 3],
[1, 2, 3]])
"""
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.expand(x, shape)
elif in_pir_mode():
if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient:
raise ValueError(
"When the data type of input 'x' for expand is bool, "
"you must set its stop_gradient to be False by "
"some_var.stop_gradient = True, supporting "
"some_var as the input."
)
if isinstance(shape, paddle.pir.OpResult):
shape.stop_gradient = True
elif isinstance(shape, (list, tuple)):
if paddle.utils._contain_var(shape):
shape = paddle.utils._convert_to_tensor_list(shape)
return _C_ops.expand(x, shape)
else:
if isinstance(shape, Variable):
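The PIR path of `expand` follows broadcast semantics: size-1 dimensions and missing leading dimensions are tiled out to the target shape, and a `shape` list that mixes ints with tensors is first converted to a tensor list. NumPy's `broadcast_to` is a convenient mental model and reproduces the docstring example above:

```python
import numpy as np

# expand of [1, 2, 3] to shape [2, 3] tiles the missing leading dimension.
x = np.array([1, 2, 3])
print(np.broadcast_to(x, (2, 3)))
# [[1 2 3]
#  [1 2 3]]
```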
@@ -3798,6 +3830,26 @@ def get_attr_shape(list_shape):
)
return out
elif in_pir_mode():
check_variable_and_dtype(
x,
'x',
[
'float16',
'float32',
'float64',
'int8',
'uint8',
'int16',
'int32',
'int64',
'bool',
'uint16',
],
'reshape',
)
check_type(
shape, 'shape', (list, tuple, paddle.pir.OpResult), 'reshape'
)
if isinstance(shape, (list, tuple)):
if paddle.utils._contain_var(shape):
new_shape = paddle.utils._convert_to_tensor_list(shape)
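The reshape hunk, truncated above, front-loads dtype and shape-type validation in PIR mode. A hedged sketch of what the new checks accept, assuming the PIR API is active:

```python
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name="x", shape=[2, 6], dtype="float32")
    y = paddle.reshape(x, shape=[3, 4])  # a list of ints passes check_type
    # paddle.reshape(x, shape=12)  # a bare int would raise TypeError
```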
23 changes: 21 additions & 2 deletions python/paddle/tensor/random.py
@@ -794,9 +794,28 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_or_pir_mode():
if in_dynamic_mode():
shape = paddle.utils.convert_shape_to_list(shape)
if in_pir_mode() and paddle.utils._contain_var(shape):
return _C_ops.uniform(
shape,
dtype,
float(min),
float(max),
seed,
_current_expected_place(),
)
elif in_pir_mode():
check_type(
shape, 'shape', (list, tuple, paddle.pir.OpResult), 'uniform/rand'
)
check_dtype(dtype, 'dtype', supported_dtypes, 'uniform/rand')
check_type(
min, 'min', (float, int, paddle.pir.OpResult), 'uniform/rand'
)
check_type(
max, 'max', (float, int, paddle.pir.OpResult), 'uniform/rand'
)
if paddle.utils._contain_var(shape):
shape = paddle.utils.get_int_tensor_list(
shape, _current_expected_place()
)
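In PIR mode, `uniform` now validates `shape`, `dtype`, `min`, and `max` up front and converts a `shape` list containing tensors into an int-tensor list. A hedged usage sketch that builds a tensor-valued dimension with `fill_constant`, as the tests below do (assumes the PIR API is active):

```python
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    d = paddle.tensor.fill_constant([1], "int64", 3)  # tensor-valued dim
    out = paddle.uniform(shape=[2, d], dtype="float32", min=-1.0, max=1.0)
    print(out.shape)  # first dim is 2; the second resolves at run time
```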
147 changes: 80 additions & 67 deletions test/legacy_test/test_concat_op.py
@@ -22,6 +22,7 @@
import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.pir_utils import test_with_pir_api


class TestConcatOp(OpTest):
@@ -591,76 +592,88 @@ def test_input_same_dtype():


class TestConcatAPI(unittest.TestCase):
@test_with_pir_api
def test_base_api(self):
paddle.enable_static()
x_1 = paddle.static.data(
shape=[None, 1, 4, 5], dtype='int32', name='x_1'
)
paddle.concat([x_1, x_1], 0)

input_2 = np.random.random([2, 1, 4, 5]).astype("int32")
input_3 = np.random.random([2, 2, 4, 5]).astype("int32")
x_2 = paddle.static.data(shape=[2, 1, 4, 5], dtype='int32', name='x_2')
x_3 = paddle.static.data(shape=[2, 2, 4, 5], dtype='int32', name='x_3')
positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1)
positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1)
out_1 = paddle.concat([x_2, x_3], axis=1)
out_2 = paddle.concat([x_2, x_3], axis=positive_1_int32)
out_3 = paddle.concat([x_2, x_3], axis=positive_1_int64)

exe = base.Executor(place=base.CPUPlace())
[res_1, res_2, res_3] = exe.run(
base.default_main_program(),
feed={"x_1": input_2, "x_2": input_2, "x_3": input_3},
fetch_list=[out_1, out_2, out_3],
)
np.testing.assert_array_equal(
res_1, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_2, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_3, np.concatenate((input_2, input_3), axis=1)
)
with paddle.static.program_guard(paddle.static.Program()):
x_1 = paddle.static.data(
shape=[None, 1, 4, 5], dtype='int32', name='x_1'
)
paddle.concat([x_1, x_1], 0)

input_2 = np.random.random([2, 1, 4, 5]).astype("int32")
input_3 = np.random.random([2, 2, 4, 5]).astype("int32")
x_2 = paddle.static.data(
shape=[2, 1, 4, 5], dtype='int32', name='x_2'
)
x_3 = paddle.static.data(
shape=[2, 2, 4, 5], dtype='int32', name='x_3'
)
positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1)
positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1)
out_1 = paddle.concat([x_2, x_3], axis=1)
out_2 = paddle.concat([x_2, x_3], axis=positive_1_int32)
out_3 = paddle.concat([x_2, x_3], axis=positive_1_int64)

exe = base.Executor(place=base.CPUPlace())
[res_1, res_2, res_3] = exe.run(
paddle.static.default_main_program(),
feed={"x_1": input_2, "x_2": input_2, "x_3": input_3},
fetch_list=[out_1, out_2, out_3],
)
np.testing.assert_array_equal(
res_1, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_2, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_3, np.concatenate((input_2, input_3), axis=1)
)

@test_with_pir_api
def test_api(self):
paddle.enable_static()
x_1 = paddle.static.data(
shape=[None, 1, 4, 5], dtype='int32', name='x_1'
)
paddle.concat([x_1, x_1], 0)

input_2 = np.random.random([2, 1, 4, 5]).astype("int32")
input_3 = np.random.random([2, 2, 4, 5]).astype("int32")
x_2 = paddle.static.data(shape=[2, 1, 4, 5], dtype='int32', name='x_2')
x_3 = paddle.static.data(shape=[2, 2, 4, 5], dtype='int32', name='x_3')
positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1)
positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1)
negative_int64 = paddle.tensor.fill_constant([1], "int64", -3)
out_1 = paddle.concat(x=[x_2, x_3], axis=1)
out_2 = paddle.concat(x=[x_2, x_3], axis=positive_1_int32)
out_3 = paddle.concat(x=[x_2, x_3], axis=positive_1_int64)
out_4 = paddle.concat(x=[x_2, x_3], axis=negative_int64)

exe = paddle.static.Executor(place=paddle.CPUPlace())
[res_1, res_2, res_3, res_4] = exe.run(
paddle.static.default_main_program(),
feed={"x_1": input_2, "x_2": input_2, "x_3": input_3},
fetch_list=[out_1, out_2, out_3, out_4],
)
np.testing.assert_array_equal(
res_1, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_2, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_3, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_4, np.concatenate((input_2, input_3), axis=1)
)
with paddle.static.program_guard(paddle.static.Program()):
x_1 = paddle.static.data(
shape=[None, 1, 4, 5], dtype='int32', name='x_1'
)
paddle.concat([x_1, x_1], 0)

input_2 = np.random.random([2, 1, 4, 5]).astype("int32")
input_3 = np.random.random([2, 2, 4, 5]).astype("int32")
x_2 = paddle.static.data(
shape=[2, 1, 4, 5], dtype='int32', name='x_2'
)
x_3 = paddle.static.data(
shape=[2, 2, 4, 5], dtype='int32', name='x_3'
)
positive_1_int32 = paddle.tensor.fill_constant([1], "int32", 1)
positive_1_int64 = paddle.tensor.fill_constant([1], "int64", 1)
negative_int64 = paddle.tensor.fill_constant([1], "int64", -3)
out_1 = paddle.concat(x=[x_2, x_3], axis=1)
out_2 = paddle.concat(x=[x_2, x_3], axis=positive_1_int32)
out_3 = paddle.concat(x=[x_2, x_3], axis=positive_1_int64)
out_4 = paddle.concat(x=[x_2, x_3], axis=negative_int64)

exe = paddle.static.Executor(place=paddle.CPUPlace())
[res_1, res_2, res_3, res_4] = exe.run(
paddle.static.default_main_program(),
feed={"x_1": input_2, "x_2": input_2, "x_3": input_3},
fetch_list=[out_1, out_2, out_3, out_4],
)
np.testing.assert_array_equal(
res_1, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_2, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_3, np.concatenate((input_2, input_3), axis=1)
)
np.testing.assert_array_equal(
res_4, np.concatenate((input_2, input_3), axis=1)
)

def test_imperative(self):
in1 = np.array([[1, 2, 3], [4, 5, 6]])
@@ -729,8 +742,8 @@ def setUp(self):
def set_program(self, use_base_api):
paddle.enable_static()
if use_base_api:
self.program = base.Program()
with base.program_guard(self.program):
self.program = paddle.static.Program()
with paddle.static.program_guard(self.program):
input = paddle.assign(self.x)
tensor_array = paddle.tensor.create_array(dtype='float32')
zero = paddle.tensor.fill_constant(
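The test rewrite wraps each body in a fresh `paddle.static.Program()` guard so that `@test_with_pir_api` can run it twice, once under the legacy IR and once under PIR, without the runs sharing a default program. An illustrative analogue of such a decorator, not the real implementation in `paddle.pir_utils`:

```python
import contextlib
import functools

@contextlib.contextmanager
def pir_guard_sketch():
    # Hypothetical stand-in for the real PIR guard: flip a flag, restore it.
    enabled = {"pir": False}
    enabled["pir"] = True
    try:
        yield
    finally:
        enabled["pir"] = False

def test_with_pir_api_sketch(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        fn(*args, **kwargs)           # first pass: legacy program
        with pir_guard_sketch():      # second pass: PIR enabled
            fn(*args, **kwargs)
    return wrapper
```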