added initialization schemes in torch.nn.init #833
Conversation
torch/nn/init.py (outdated):

```python
        return tensor
    else:
        fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
        std = gain * np.sqrt(1.0 / fan_in)
```
Darn, looks like the build fails because …

@alykhantejani you can keep the scipy parts in the tests. Do the following: at the top of test_nn.py, add these lines: … Then, in the tests where you have scipy references, add the annotation: …

Then, here: https://github.com/pytorch/pytorch/blob/master/.travis.yml#L21 …
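The exact lines weren't captured in the page above, but the usual pattern for an optional test dependency (a sketch — the flag name, class name, and test name are assumptions, not taken from the PR) is a module-level availability flag plus a skip decorator:

```python
import unittest

# Hypothetical pattern: probe for the optional scipy dependency once at
# import time, then skip scipy-dependent tests when it is missing.
try:
    import scipy.stats  # noqa: F401
    TEST_SCIPY = True
except ImportError:
    TEST_SCIPY = False


class TestNNInit(unittest.TestCase):
    @unittest.skipIf(not TEST_SCIPY, "scipy is not available")
    def test_something_statistical(self):
        # a test that would use scipy.stats goes here
        pass
```

This way the test suite still passes on CI machines without scipy installed.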
torch/nn/init.py (outdated):

```python
        return tensor.normal_(0, std)


def kaiming_uniform(tensor, gain=1):
```
test/test_nn.py (outdated):

```python
        expected_std = gain * np.sqrt(2.0 / ((tensor_shape[1] + tensor_shape[0]) * receptive_field))
        assert self._is_normal(input_tensor, 0, expected_std)

    def test_kaiming_unifrom_errors_on_inputs_smaller_than_2d(self):
```
@soumith The …

Hmm, googling around, this seems to be a known issue.
Force-pushed from 6be9020 to 44b7f9c.
colesbury left a comment:
This looks pretty good. I think the most important thing is to remove the numpy dependency, since PyTorch does not currently require numpy.
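For reference, the fan computation doesn't actually need numpy; a numpy-free sketch operating on a plain shape tuple (the function name and standalone shape-tuple interface are mine, not the PR's):

```python
from functools import reduce
from operator import mul


def calculate_fan_in_and_fan_out(shape):
    # shape: tensor dimensions, e.g. (out_features, in_features) for Linear,
    # or (out_channels, in_channels, kH, kW) for Conv2d
    if len(shape) < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported.")
    if len(shape) == 2:
        fan_in, fan_out = shape[1], shape[0]
    else:
        # the initial value 1 makes reduce safe for the 2-d case (empty tail)
        receptive_field_size = reduce(mul, shape[2:], 1)
        fan_in = shape[1] * receptive_field_size
        fan_out = shape[0] * receptive_field_size
    return fan_in, fan_out
```

For a 3x3 conv with 32 input and 64 output channels, shape (64, 32, 3, 3), this gives fan_in = 288 and fan_out = 576.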
torch/nn/init.py (outdated):

```diff
@@ -1,0 +1,240 @@
+import numpy as np
```
torch/nn/init.py (outdated):

```python
    """Fills the input Tensor or Variable with values according to the method described in "Understanding the difficulty of training
    deep feedforward neural networks" - Glorot, X. and Bengio, Y., using a uniform distribution.

    The resulting tensor will have values sampled from U(-a, a) where a = gain * sqrt(2/(fan_in + fan_out))
```
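As a sanity check of that docstring formula, a minimal standalone computation of the bound a (a sketch, not the PR's code):

```python
import math


def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    # a = gain * sqrt(2 / (fan_in + fan_out)); weights are drawn from U(-a, a)
    return gain * math.sqrt(2.0 / (fan_in + fan_out))
```

For a square 100x100 layer with gain 1, the bound comes out to 0.1.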
torch/nn/init.py (outdated):

```diff
     num_input_fmaps = tensor.size(1)
     num_output_fmaps = tensor.size(0)
-    receptive_field_size = np.prod(tensor.numpy().shape[2:])
+    receptive_field_size = reduce(mul, (tensor.numpy().shape[2:]))
```
torch/nn/init.py (outdated):

```diff
         raise ValueError("Only tensors with 2 or more dimensions are supported.")

-    flattened_shape = (tensor.size(0), int(np.prod(tensor.numpy().shape[1:])))
+    flattened_shape = (tensor.size(0), int(reduce(mul, tensor.numpy().shape[1:])))
```
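One subtlety when swapping np.prod for reduce: np.prod of an empty sequence is 1, whereas reduce(mul, ()) raises a TypeError unless an initial value is given. A numpy-free sketch over a plain shape tuple (function name is mine):

```python
from functools import reduce
from operator import mul


def flattened_shape(shape):
    # numpy-free replacement for (shape[0], int(np.prod(shape[1:])));
    # the initial value 1 makes reduce safe when shape[1:] is empty
    return (shape[0], reduce(mul, shape[1:], 1))
```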
test/test_nn.py (outdated):

```diff
     if rows > cols:
-        assert np.allclose(np.dot(flattened_tensor.T, flattened_tensor), np.eye(cols) * gain ** 2,
-                           atol=1e-6)
+        assert torch.dist(torch.mm(flattened_tensor.t(), flattened_tensor),
```
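The property being asserted here — that a column-orthogonal matrix Q satisfies Q^T Q = gain^2 * I — can also be checked without numpy. A pure-Python sketch over a list-of-rows matrix (the helper name is mine, not the PR's):

```python
def max_abs_dev_from_scaled_identity(Q, gain=1.0):
    # Largest entry-wise deviation of Q^T Q from gain^2 * I,
    # where Q is a list of rows (rows >= cols assumed).
    rows, cols = len(Q), len(Q[0])
    dev = 0.0
    for i in range(cols):
        for j in range(cols):
            dot = sum(Q[r][i] * Q[r][j] for r in range(rows))
            target = gain ** 2 if i == j else 0.0
            dev = max(dev, abs(dot - target))
    return dev
```

A rotation matrix, for example, should give a deviation near machine epsilon.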
torch/nn/init.py (outdated):

```python
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        receptive_field_size = reduce(mul, (tensor.numpy().shape[2:]))
```
torch/nn/init.py (outdated):

```python
    if isinstance(tensor, Variable):
        uniform(tensor.data, a=a, b=b)
        return tensor
    else:
```
I've addressed most of the review comments, but I still think there are the following open questions: …
test/test_nn.py (outdated):

```python
        fan_in = input_tensor.size(1)
        if input_tensor.dim() > 2:
            fan_in *= input_tensor[0][0].numel()
```
Regarding fan_out: fan_in stabilizes the forward activations, while fan_out stabilizes the gradients on the backward pass, so which one works better depends on the architecture.
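In formula terms, the two choices only differ in which fan the std formula from the diff above divides by (a sketch — the mode/string-argument naming follows this discussion, not necessarily the final API):

```python
import math


def kaiming_normal_std(fan_in, fan_out, gain=1.0, mode="fan_in"):
    # std = gain * sqrt(1 / fan), as in the diff above;
    # "fan_in" preserves the magnitude of activations in the forward pass,
    # "fan_out" preserves the magnitude of gradients in the backward pass
    fan = fan_in if mode == "fan_in" else fan_out
    return gain * math.sqrt(1.0 / fan)
```

For a layer with fan_in = 4 and fan_out = 16, this gives std 0.5 in fan_in mode and 0.25 in fan_out mode.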
String arguments sound good to me.

@szagoruyko @apaszke I've now added a …

@pytorchbot add to whitelist

Thank you!
Co-authored-by: Ryan Spring <[email protected]>
Moved the torch.nn.init TestNNInit TestCase to test_nn.py. Once merged we can probably close #101.