
Add support for gradient clipping #11697

Merged
baijumeswani merged 9 commits into training_dev/on_device_poc from bmeswani/adamw_grad_clipping on Jun 22, 2022
Conversation

@baijumeswani
Contributor

@baijumeswani baijumeswani commented Jun 1, 2022

This pull request adds support for gradient clipping and also integrates with the AdamWOptimizer changes introduced in #11506.

To generate the ONNX model for the AdamW optimizer, users can simply do:

```python
adamw = onnxblock.optim.AdamW()
with onnxblock.onnx_model() as accessor:
    output_names = adamw(training_model.parameters())
```

and to add gradient clipping:

```python
adamw = onnxblock.optim.AdamW(clip_grad=onnxblock.optim.ClipGradNorm(2.5))
with onnxblock.onnx_model() as accessor:
    output_names = adamw(training_model.parameters())
```

The gradient clipping subgraph looks like:
[image: gradient clipping subgraph]
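For context, clip-by-global-norm (the behavior a `ClipGradNorm(2.5)` block exports) can be sketched in plain NumPy. The helper below is an illustrative reimplementation, not the onnxblock code:

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    # Compute the global L2 norm across all gradient tensors.
    total_norm = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    # Scale every gradient down proportionally if the global norm
    # exceeds max_norm; otherwise leave the gradients untouched.
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1.0:
        grads = [g * scale for g in grads]
    return grads

grads = [np.array([3.0, 4.0]), np.array([0.0])]  # global norm = 5.0
clipped = clip_grad_norm(grads, max_norm=2.5)
# global norm after clipping is ~2.5
```

Note the clipping is applied to the global norm over all parameters, not per-tensor, which matches the usual `ClipGradNorm`-style semantics.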

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from ab805bd to 74b1bc6 Compare June 1, 2022 06:14
@lgtm-com

lgtm-com bot commented Jun 1, 2022

This pull request fixes 3 alerts when merging 74b1bc6 into 1c316d0 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from 74b1bc6 to 57f6dfb Compare June 1, 2022 16:30
@lgtm-com

lgtm-com bot commented Jun 1, 2022

This pull request fixes 3 alerts when merging 57f6dfb into 1c316d0 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@lgtm-com

lgtm-com bot commented Jun 1, 2022

This pull request fixes 3 alerts when merging 79c49e8 into 1c316d0 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@pengwa
Contributor

pengwa commented Jun 2, 2022

SequenceConstruct in ORT currently makes a copy of every tensor in its inputs. Here in the optimizer graph we have all the gradients as inputs, so we have to construct a sequence and then feed it into the optimizer.

An alternative would be to pass the sequences of gradients, parameters, and momentums directly as inputs to the optimizer graph. This is feasible if we manage the sequences in Step().
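The alternative described here (feeding the gradient and parameter sequences directly as optimizer-graph inputs, with Step() managing them on the caller side) can be sketched in Python. All names below are illustrative stand-ins, not the actual ORT training API:

```python
import numpy as np

def optimizer_graph_run(feeds):
    # Stand-in for executing the optimizer ONNX graph; here, plain SGD
    # over the sequence inputs. In the real setup this would be a
    # session run over the exported optimizer model.
    lr = feeds["lr"]
    return [p - lr * g for p, g in zip(feeds["params"], feeds["grads"])]

class Step:
    """Illustrative Step() that manages the sequences itself."""
    def __init__(self, params, lr=0.1):
        self.params = params
        self.lr = lr

    def __call__(self, grads):
        # Feed the sequences of params/grads directly as graph inputs,
        # so no SequenceConstruct node (and its per-tensor copies) is
        # needed inside the optimizer graph.
        feeds = {"params": self.params, "grads": grads, "lr": self.lr}
        self.params = optimizer_graph_run(feeds)
        return self.params

step = Step([np.array([1.0, 2.0])])
updated = step([np.array([0.5, 0.5])])
```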

@lgtm-com

lgtm-com bot commented Jun 6, 2022

This pull request fixes 3 alerts when merging d5b31c9 into 1c316d0 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from d5b31c9 to cc36334 Compare June 7, 2022 05:21
@lgtm-com

lgtm-com bot commented Jun 7, 2022

This pull request fixes 3 alerts when merging cc36334 into 1c316d0 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from cc36334 to 9eb06b0 Compare June 8, 2022 21:59
@lgtm-com

lgtm-com bot commented Jun 8, 2022

This pull request fixes 3 alerts when merging 9eb06b0 into 1c316d0 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from 6d54c92 to ae60521 Compare June 9, 2022 04:54
@lgtm-com

lgtm-com bot commented Jun 9, 2022

This pull request fixes 3 alerts when merging ae60521 into 540935a - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@lgtm-com

lgtm-com bot commented Jun 10, 2022

This pull request fixes 3 alerts when merging 26e7df2 into 540935a - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from 26e7df2 to e1391e7 Compare June 15, 2022 21:36
@lgtm-com

lgtm-com bot commented Jun 15, 2022

This pull request fixes 3 alerts when merging e1391e7 into f63e28c - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from e1391e7 to 7657940 Compare June 16, 2022 23:52
@lgtm-com

lgtm-com bot commented Jun 17, 2022

This pull request fixes 3 alerts when merging 7657940 into f63e28c - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

```cpp
    const std::vector<std::shared_ptr<onnxruntime::IExecutionProviderFactory>>& provider_factories) {
  std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>> execution_providers;
  for (const auto& factory : provider_factories) {
    execution_providers.emplace_back(std::move(factory->CreateProvider()));
```
Contributor

Since we are creating the provider here, the same instance of the provider will be shared among the training, eval, and optimizer sessions, right?

For inference scenarios we don't share EP instances among inference sessions, but in ORTModule we do... Just wondering, do we know of any implications of sharing the provider instance?

Contributor Author

@baijumeswani baijumeswani Jun 17, 2022

I am not sure if there are any guidelines around sharing the provider among different inference sessions. If there are, please let me know.

I think from an API design viewpoint, we should think of TrainingSession as the training-world equivalent of InferenceSession. Extending that further, all components inside the TrainingSession should share the same instance of the provider and, by extension, the allocator (per provider). But if this is not the expected usage for the provider, we can change it.

@pengwa, @ashbhandare please provide any insight that may be relevant.
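The sharing being discussed can be pictured with a minimal sketch; the classes below are illustrative stand-ins for the real ORT types, not the actual API:

```python
class Provider:
    """Stand-in for an execution provider that owns an allocator."""
    def __init__(self):
        self.allocations = 0

    def allocate(self, nbytes):
        # All sessions holding this provider allocate through the
        # same allocator instance.
        self.allocations += 1
        return bytearray(nbytes)

class Session:
    """Stand-in for a training/eval/optimizer session."""
    def __init__(self, provider):
        self.provider = provider  # shared reference, not a copy

# One provider instance shared by every session inside the
# TrainingSession, so allocations all flow through one allocator.
provider = Provider()
train, eval_, optim = Session(provider), Session(provider), Session(provider)
train.provider.allocate(16)
optim.provider.allocate(16)
```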

@baijumeswani baijumeswani force-pushed the bmeswani/adamw_grad_clipping branch from 7657940 to 83c24a6 Compare June 18, 2022 00:38
@lgtm-com

lgtm-com bot commented Jun 18, 2022

This pull request fixes 3 alerts when merging 83c24a6 into a3ec2d6 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

```cpp
const auto& tensor_seq = feed.Get<TensorSeq>();
if (tensor_seq.Size() != std::size_t{0}) {
  feed_locations[i] = tensor_seq.Get(0).Location().device;
}
```
Contributor

Shall we fix the output too?

Contributor Author

Yes, we should. Let me get this in a follow-up pull request so that this PR can stay focused on the on-device training use case. I will create a work item and complete it in another PR.

```cpp
    : named_parameters_{parameters},
      module_{std::make_unique<Module>(model_identifiers.train_model, named_parameters_,
                                       session_options, session_env, providers, model_identifiers.eval_model)},
      optimizer_{model_identifiers.optim_model.has_value()
```
Contributor

nit: put it in the body?

Contributor Author

Any reason to put it in the body as opposed to the initializer list?

@lgtm-com

lgtm-com bot commented Jun 20, 2022

This pull request fixes 3 alerts when merging 4e0d0e7 into a3ec2d6 - view on LGTM.com

fixed alerts:

  • 2 for Unused global variable
  • 1 for Unused import

```cpp
           tensor_location.mem_type == OrtMemTypeCPUOutput) {
  memset(p_tensor->MutableDataRaw(), 0, p_tensor->SizeInBytes());
} else if (tensor_location.device.Type() == OrtDevice::GPU) {
  // Use a tensor on cpu and copy it over to the desired device using
```
Contributor

Is there a more performance-efficient way to do this? Given this happens during initialization it may be OK, but I am wondering whether we can use cudaMemset here?

Contributor Author

@baijumeswani baijumeswani Jun 22, 2022

utils.cc is part of the onnxruntime_session target. This target is currently not linked against the CUDA libraries, because onnxruntime_session should not care about which providers are supported; it should be provider-agnostic.

We could probably add a target_link_libraries for onnxruntime_session against the CUDA libraries to work around this, but that might not be the right solution. Instead, we could add a method on the execution providers for memset that performs the memset on the device.

I think this should be done separately in another PR where the focus can be only on this functionality.

Contributor

@pengwa pengwa left a comment

LGTM

@baijumeswani baijumeswani merged commit fac8dae into training_dev/on_device_poc Jun 22, 2022
@baijumeswani baijumeswani deleted the bmeswani/adamw_grad_clipping branch June 22, 2022 17:27