
Redesign InPlaceAccumulator op#11842

Merged
pengwa merged 4 commits into training_dev/on_device_poc from aibhanda/accumulator
Jun 24, 2022
Conversation

@ashbhandare
Contributor

This PR makes the output buffer of the InPlaceAccumulatorV2 op optional and introduces an additional optional input, 'overwrite'.
It also adds op tests and onnxblock changes that use the new op.

@lgtm-com

lgtm-com bot commented Jun 14, 2022

This pull request introduces 1 alert when merging 184150c into fb88efb - view on LGTM.com

new alerts:

  • 1 for Commented-out code

memcpy(accumulation_buffer_data, updated_data, new_value->SizeInBytes());
} else {
// Copy from Add CPU kernel
ProcessBroadcastSpanFuncs funcs{
Contributor


Curious if it is possible to reuse the same ProcessBroadcastSpanFuncs instance for both the v1 and v2 kernels.

Contributor Author


Maybe; the one way I can think of is holding a static ProcessBroadcastSpanFuncs that both kernels use, but I'm not sure it is worth it. Let me know if there is a better way.

@@ -132,9 +132,9 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
feeds.insert(feeds.end(), gradients_.begin(), gradients_.end());
// TODO: consider maintaining this as ortvalue instead of bool
Contributor


If it is easy to implement the TODO, shall we do it in this PR?

Contributor Author


The question is whether we should do it at all: is it better to hold an OrtValue and update the underlying buffer by unwrapping it every step, or to wrap the bool into an OrtValue every step?

Contributor


OK, I did not realize this. How about holding two OrtValues directly, one True and one False? Never mind, let's refine it later.


test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
}

Contributor


Shall we also add >1D input data cases?

@ashbhandare force-pushed the aibhanda/accumulator branch from 184150c to 5a23f40 on June 16, 2022 22:55
@lgtm-com

lgtm-com bot commented Jun 17, 2022

This pull request introduces 1 alert when merging 5a23f40 into f63e28c - view on LGTM.com

new alerts:

  • 1 for Commented-out code

std::vector<std::vector<int64_t>> x_shapes = {
{4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2},
{4, 3, 2},
{4, 3, 2},
Contributor


Nit: we'd better avoid making this kind of change, to make it easier for us when merging back to the master branch.

Contributor Author


We'll be cleaning up anyway at merge time, so let's handle it then? I will make sure not to make such changes going forward.

@baijumeswani
Contributor

We also need to update the GetEvalModeOutputCount and GetTrainModeOutputCount to not include the bool tensors as outputs, right?

@ashbhandare
Contributor Author

We also need to update the GetEvalModeOutputCount and GetTrainModeOutputCount to not include the bool tensors as outputs, right?

This is already done for GetTrainModeOutputCount. The eval model should only have user outputs anyway, so no change is required there.

@ashbhandare force-pushed the aibhanda/accumulator branch from f9358ba to 594a0e4 on June 24, 2022 00:54
@pengwa pengwa merged commit c2fd5cc into training_dev/on_device_poc Jun 24, 2022
@pengwa pengwa deleted the aibhanda/accumulator branch June 24, 2022 09:11
@baijumeswani baijumeswani restored the aibhanda/accumulator branch July 5, 2022 17:50
@baijumeswani baijumeswani deleted the aibhanda/accumulator branch July 5, 2022 18:08

4 participants