Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 369b66d

Browse files
ZhennanQinzhreshold
authored andcommitted
Improve cached_op performance for static mode (#14785)
* Fix cached_op * try to fix ci * Fix CI * Fix ci
1 parent 5dd9fa2 commit 369b66d

File tree

4 files changed

+32
-21
lines changed

4 files changed

+32
-21
lines changed

src/executor/attach_op_execs_pass.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ class FComputeExExecutor : public OpExecutor {
261261
ExecType exec_type_;
262262
};
263263

264-
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
264+
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) {
265265
using nnvm::DTypeVector;
266266
using mxnet::ShapeVector;
267267
using nnvm::FMutateInputs;
@@ -302,6 +302,10 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
302302

303303
OpStatePtr state = fcreate_op_state[op](
304304
inode.source->attrs, vctx[i], ishape, itype);
305+
if (p_state) {
306+
CHECK_GT(p_state->size(), i);
307+
p_state->at(i) = state;
308+
}
305309
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
306310
op, "FStatefulComputeEx", vctx[i]);
307311
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
@@ -359,7 +363,7 @@ Graph AttachOpExecs(Graph g) {
359363
const auto& idx = g.indexed_graph();
360364
OpExecVector ret(idx.num_nodes());
361365
for (size_t i = 0; i < idx.num_nodes(); ++i) {
362-
CreateOpExecs(g, &ret, i);
366+
CreateOpExecs(g, &ret, nullptr, i);
363367
}
364368
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
365369
return g;

src/executor/exec_pass.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ class OpExecutor {
9898
*/
9999
using OpExecVector = std::vector<std::shared_ptr<OpExecutor> >;
100100

101+
/*!
102+
* \brief per node vector of operator states.
103+
* \note stored under attribute "op_states"
104+
*/
105+
using OpStateVector = std::vector<OpStatePtr>;
106+
101107
/*!
102108
* \brief per node context vector
103109
* \node stored under "context"
@@ -115,9 +121,10 @@ using DevMaskVector = std::vector<int>;
115121
*
116122
* \param g input graph
117123
* \param p_ret OpExecVector for input and output
124+
* \param p_state OpStateVector if it has.
118125
* \param i the id of the node
119126
*/
120-
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
127+
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i);
121128
/*!
122129
* \brief Attach OpExecutor to the graph attributes.
123130
*

src/imperative/cached_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
285285
CheckAndInferShape(&g, std::move(shape_inputs), true,
286286
{0, 0}, {0, 0},
287287
&contain_dynamic_shape);
288-
if (erase_result) {
288+
if (contain_dynamic_shape && erase_result) {
289289
g.attrs.erase("shape");
290290
g.attrs.erase("shape_inputs");
291291
}
@@ -603,7 +603,7 @@ void CachedOp::StaticInitExec(
603603
}
604604
} else {
605605
for (size_t i = start_nid; i < end_nid; ++i) {
606-
exec::CreateOpExecs(g, &state.execs, i);
606+
exec::CreateOpExecs(g, &state.execs, &state.op_states, i);
607607
}
608608
exec::AttachOpResources(g, state.execs, start_nid, end_nid);
609609

@@ -705,8 +705,10 @@ void CachedOp::StaticRunOps(
705705
arg_shapes.emplace_back(ndinput->shape());
706706
arg_dtypes.emplace_back(ndinput->dtype());
707707
}
708-
state.op_states[i] = createop[node.source->op()](
709-
node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
708+
if (!state.op_states[i]) {
709+
state.op_states[i] =
710+
createop[node.source->op()](node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
711+
}
710712
Imperative::Get()->InvokeOp(
711713
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
712714
dispatch_mode, state.op_states[i]);

src/imperative/imperative_utils.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -595,23 +595,21 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes,
595595
*contain_unknown = false;
596596
}
597597
nnvm::Graph& g = *p_g;
598-
if (use_inputs) {
599-
if (g.attrs.count("shape_inputs") &&
600-
g.GetAttr<mxnet::ShapeVector>("shape_inputs") == shapes) return true;
601-
} else if (g.attrs.count("shape")) {
598+
if (g.attrs.count("shape")) {
602599
const auto& prev_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
603-
CHECK_EQ(prev_shapes.size(), shapes.size());
604-
bool match = true;
605-
for (size_t i = 0; i < shapes.size(); ++i) {
606-
if (i == entry_range.first) {
607-
i = entry_range.second;
608-
if (i >= shapes.size()) break;
600+
if (prev_shapes.size() == shapes.size()) {
601+
bool match = true;
602+
for (size_t i = 0; i < shapes.size(); ++i) {
603+
if (i == entry_range.first) {
604+
i = entry_range.second;
605+
if (i >= shapes.size()) break;
606+
}
607+
if (shapes[i] == prev_shapes[i]) continue;
608+
match = false;
609+
break;
609610
}
610-
if (shapes[i] == prev_shapes[i]) continue;
611-
match = false;
612-
break;
611+
if (match) return true;
613612
}
614-
if (match) return true;
615613
}
616614
g.attrs.erase("shape");
617615
g.attrs.erase("shape_inputs");

0 commit comments

Comments
 (0)