Making mixed precision work with all optimizers #7654
Conversation
Force-pushed from 936e2bd to 5becd3d
python/mxnet/optimizer.py
Outdated
| """ | ||
| weight_master_copy = None | ||
| if self.multi_precision and weight.dtype == numpy.float16: | ||
| weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32) |
Use weight.astype(float32) instead.
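For context, a minimal sketch of what the suggested change amounts to (assumes MXNet's NDArray API; the shape is illustrative). astype returns a copy on the same context, so the explicit ctx= argument becomes unnecessary:

    import mxnet as mx
    import numpy

    weight = mx.nd.ones((4,), dtype=numpy.float16)

    # instead of: array(weight, ctx=weight.context, dtype=numpy.float32)
    weight_master_copy = weight.astype(numpy.float32)  # fp32 copy on the same context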
python/mxnet/optimizer.py
Outdated
# Wrapper for mixed precision
weight_master_copy = state[0]
original_state = state[1]
grad32 = array(grad, ctx=grad.context, dtype=numpy.float32)
Same here: use grad.astype.
python/mxnet/optimizer.py
Outdated
original_state = state[1]
grad32 = array(grad, ctx=grad.context, dtype=numpy.float32)
self.update(index, weight_master_copy, grad32, original_state)
weight[:] = weight_master_copy.astype(weight.dtype)
Use nd.cast(weight_master_copy, dtype=weight.dtype, out=weight) to avoid a copy.
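To illustrate the difference (a sketch with illustrative values, not the PR's final code): .astype allocates a temporary fp16 array that is then copied into weight, whereas nd.cast with out= writes the result directly into weight's buffer:

    import mxnet as mx
    import numpy

    weight = mx.nd.zeros((4,), dtype=numpy.float16)
    weight_master_copy = mx.nd.ones((4,), dtype=numpy.float32)

    # allocates a temporary fp16 array, then copies it into weight:
    weight[:] = weight_master_copy.astype(weight.dtype)

    # writes the cast result straight into weight, no temporary:
    mx.nd.cast(weight_master_copy, dtype=weight.dtype, out=weight)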
python/mxnet/optimizer.py
Outdated
    The state associated with the weight.
    """

def create_mp_state(self, index, weight):
Use the full name create_state_multi_precision.
python/mxnet/optimizer.py
Outdated
| """ | ||
|
|
||
| def create_mp_state(self, index, weight): | ||
| """Creates auxiliary state for a given weight, including FP32 master |
"including FP32 master copy if necessary." -> "including fp32 high precision copy if original weight is fp16"
python/mxnet/optimizer.py
Outdated
| """ | ||
| raise NotImplementedError() | ||
|
|
||
| def update_mp(self, index, weight, grad, state): |
Likewise, use the full name update_multi_precision.
python/mxnet/optimizer.py
Outdated
    return momentum

def update(self, index, weight, grad, state):
def create_state(self, index, weight):
There should be a _create_state_impl function that both create_mp_state and create_state use. And create_state should keep the original behavior (always multi_precision=False).
I don't think _create_state_impl is necessary, since this is basically create_state.
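Putting the threads above together, a hedged sketch of the shape the API could take after the renames, condensed from the diff hunks in this review (the multi_precision flag is assumed to be set in __init__; this is not necessarily the merged code verbatim):

    import mxnet as mx
    import numpy

    class Optimizer(object):
        def __init__(self, multi_precision=False):
            self.multi_precision = multi_precision

        def create_state(self, index, weight):
            # original single-precision behavior, unchanged
            raise NotImplementedError()

        def update(self, index, weight, grad, state):
            raise NotImplementedError()

        def create_state_multi_precision(self, index, weight):
            """Creates auxiliary state, including an fp32 high precision
            copy if the original weight is fp16."""
            if self.multi_precision and weight.dtype == numpy.float16:
                weight_master_copy = weight.astype(numpy.float32)
                return (weight_master_copy,
                        self.create_state(index, weight_master_copy))
            return self.create_state(index, weight)

        def update_multi_precision(self, index, weight, grad, state):
            """Updates the weight and keeps the fp32 master copy in sync."""
            if self.multi_precision and weight.dtype == numpy.float16:
                weight_master_copy, original_state = state
                grad32 = grad.astype(numpy.float32)
                self.update(index, weight_master_copy, grad32, original_state)
                # write the fp16 result back without an intermediate copy
                mx.nd.cast(weight_master_copy, dtype=weight.dtype, out=weight)
            else:
                self.update(index, weight, grad, state)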
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
for dtype in [np.float16, np.float32]:
This is exploding exponentially...
And that is fine - previously there was a list of test cases where it was really hard to make sure that all possible combinations of parameters were tested (and all possible combinations should be tested). This design guarantees that.
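For illustration, a sketch of how those option lists expand into an exhaustive grid (compare_optimizer here is a stand-in for whatever check the test actually runs):

    import itertools
    import numpy as np

    rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
    wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
    mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]

    for dtype in [np.float16, np.float32]:
        for rg, wd, mp in itertools.product(rg_options, wd_options, mp_options):
            kwargs = {}
            for opt in (rg, wd, mp):
                kwargs.update(opt)
            # 2 * 3 * 4 * 3 = 72 combinations in total, e.g.:
            # compare_optimizer(opt1(**kwargs), opt2(**kwargs), shape, dtype)
            print(dtype, kwargs)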
Force-pushed from 6bd2368 to df3c841
Force-pushed from df3c841 to 9cab44a
I think you need to merge in master.
I did (and it fixed a lot of errors) but I still got an unrelated segfault :-(.
@piiswrong Passed!
Thanks
Commits:
* Making mixed precision work with all optimizers
* Restart CI
* Restart CI
No description provided.