Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit dd4eaf5

Browse files
ZhennanQin authored and
pengzhao-intel committed
[MKLDNN]Fix reorder2default (#16602)
* Fix reorder2default (Change-Id: I74c87af9535f6264e6d1ea7eaed089a6480a3358)
* fix (Change-Id: I6d07b43b520a47e7c78bd4b4b6390f5fb95e6957)
* Fix (Change-Id: Id72f25c34291be4711f55569c6d61467edd6113d)
* Fix CI (Change-Id: I8c33a82555d5ace2d0b682c1e3eefa13f3a44768)
* Run CI (Change-Id: Ie8a6dab80ef91c0337cafbae4e3db277e0c7ebf7)
1 parent bde443e commit dd4eaf5

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

src/ndarray/ndarray.cc

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,11 +1635,13 @@ void NDArray::Save(dmlc::Stream *strm) const {
16351635
nd_cpu.WaitToRead();
16361636
save_data = nd_cpu.data();
16371637
} else {
1638+
#if MXNET_USE_MKLDNN == 1
1639+
// For mkldnn, a copy of *this can ensure no write access pending on *this.
1640+
nd_cpu = this->Copy(Context::CPU());
1641+
nd_cpu.WaitToRead();
1642+
#else
16381643
this->WaitToRead();
16391644
nd_cpu = *this;
1640-
#if MXNET_USE_MKLDNN == 1
1641-
if (nd_cpu.IsMKLDNNData())
1642-
nd_cpu = nd_cpu.Reorder2Default();
16431645
#endif
16441646
save_data = nd_cpu.data();
16451647
}
@@ -2024,15 +2026,18 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
20242026
TBlob dst(data, dshape, cpu::kDevMask, this->dtype_, 0); // NOLINT(*)
20252027

20262028
if (this->ctx().dev_mask() == cpu::kDevMask) {
2027-
this->WaitToRead();
2028-
RunContext rctx{this->ctx(), nullptr, nullptr, false};
2029-
NDArray src = *this;
2029+
Engine::Get()->PushAsync(
2030+
[&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
2031+
RunContext ctx{this->ctx(), nullptr, nullptr, false};
2032+
NDArray src = *this;
20302033
#if MXNET_USE_MKLDNN == 1
2031-
if (src.IsMKLDNNData())
2032-
src = this->Reorder2Default();
2034+
src = this->Reorder2Default();
20332035
#endif
2034-
ndarray::Copy<cpu, cpu>(src.data(), &dst,
2035-
Context::CPU(), Context::CPU(), rctx);
2036+
ndarray::Copy<cpu, cpu>(src.data(), &dst, Context::CPU(), Context::CPU(), ctx);
2037+
on_complete();
2038+
},
2039+
this->ctx(), {this->var()}, {}, FnProperty::kNormal, 0, "SyncCopyCPU2CPU");
2040+
this->WaitToWrite();
20362041
} else {
20372042
#if MXNET_USE_CUDA
20382043
Engine::Get()->PushAsync(

0 commit comments

Comments
 (0)