
Conversation

@xiaomengy
Contributor

Summary:
Add gelu activation forward on CPU in pytorch

Compared to the current Python-implemented version of gelu in BERT models, like

def gelu(self, x):
    return x * 0.5 * (1.0 + torch.erf(x / self.sqrt_two))

The torch.gelu function can reduce the forward time from 333ms to 112ms (with MKL) / 133ms (without MKL) for input size = [64, 128, 56, 56] on a devvm.
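
For context, a minimal sketch of how such a comparison could be timed (hypothetical benchmark code, not the harness used for the numbers above; assumes a build where torch.nn.functional.gelu is available):

import math
import timeit
import torch
import torch.nn.functional as F

def gelu_py(x):
    # eager Python reference, as in the BERT snippet above
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.rand(64, 128, 56, 56)

print("python erf:", timeit.timeit(lambda: gelu_py(x), number=10))
print("F.gelu:", timeit.timeit(lambda: F.gelu(x), number=10))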

Differential Revision: D15400974

@pytorchbot pytorchbot added module: cpu CPU specific problem (e.g., perf, algorithm) module: operators labels May 17, 2019
@soumith soumith requested a review from gchanan May 18, 2019 22:31
Contributor

@soumith soumith left a comment

looks pretty good!

You need to add documentation / stub function in functional.py
For example, see GLU : https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L954-L976

Contributor

@soumith soumith May 18, 2019

Contributor Author

Thanks for pointing this out.

Contributor

Once you make the changes above, gelu will be available as torch.nn.functional.gelu.
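
For reference, once the functional.py stub is in place the op is called like any other functional (a trivial usage sketch):

import torch
import torch.nn.functional as F

x = torch.randn(8, 16)
y = F.gelu(x)  # element-wise GELU, output has the same shape as x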

Contributor Author

Done

@fmassa
Member

fmassa commented May 19, 2019

Could you also add a GPU implementation for it? We try to keep parity between CPU and CUDA as much as possible.

Also, what's the story around JIT fusion for CPU? It's disabled by default in PyTorch, but the fuser is currently able to handle all the operations here inside a single fusion group, and it seems to bring significant speedups on CPU (the same also applies to CUDA):

import torch

# should we enable it by default?
torch._C._jit_override_can_fuse_on_cpu(True)

def gelu(x):
    sqrt_two = 1.4142135623730951
    return x * 0.5 * (1.0 + torch.erf(x / sqrt_two))

@torch.jit.script
def gelu2(x):
    sqrt_two = 1.4142135623730951
    return x * 0.5 * (1.0 + torch.erf(x / sqrt_two))

x = torch.rand(64, 128, 56, 56)

# compile gelu2
gelu2(x)

and gives

%timeit gelu(x)
> 92.6 ms ± 647 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit gelu2(x)
> 18.5 ms ± 518 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Member

Don't we need to check that the tensors are contiguous before dispatching to the MKL-optimized codepath?

Contributor

This needs a comment; I would expect something like:

  1. Describe whether MKL supports non-contiguous inputs or not.
  2a. If it doesn't, when is it worth making the tensor contiguous to do the op? Does it ever pay off, or should I just check contiguity in the MKL pass?
  2b. If it does, why don't I just pass in the tensor?
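
A rough Python-level sketch of the decision being discussed, just to illustrate the options; the real code is C++ inside ATen, and gelu_reference here is only a stand-in for the two paths:

import math
import torch

def gelu_reference(x):
    # same exact GELU either way; only the memory-layout handling differs
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_dispatch(x):
    if x.is_contiguous():
        return gelu_reference(x)          # 2b: hand the tensor straight to the fast (MKL-style) path
    # 2a: either pay for a copy to use the fast path, or fall back to a strided loop
    return gelu_reference(x.contiguous())

print(gelu_dispatch(torch.rand(50, 50)[:, ::2]).shape)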

Contributor Author

I will add the CUDA impl in the next PR. Thanks for the advice.

Member

@fmassa fmassa May 19, 2019

Can you also test non-contiguous tensors? Like

x = torch.rand(50, 50)[:, ::2]
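
A minimal sketch of such a test (assumes torch.nn.functional.gelu is available; it compares the op on a strided view against the eager erf formula):

import math
import torch
import torch.nn.functional as F

x = torch.rand(50, 50)[:, ::2]   # non-contiguous view
assert not x.is_contiguous()

expected = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
assert torch.allclose(F.gelu(x), expected)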

Contributor

It looks like this always converts the input to be contiguous, although it probably doesn't need to since it's now using TensorIterator.

Contributor Author

I added a test for non-contiguous inputs.

Contributor

why?

Contributor

is this .contiguous() still needed after your latest changes?

Contributor Author

I think so, since our current approach uses a for-loop to take advantage of auto-vectorization on the non-MKL path. With or without MKL, the performance is actually quite similar.

@pytorchbot pytorchbot added the module: nn Related to torch.nn label May 27, 2019
@xiaomengy
Contributor Author

For the current version on the same machine, the MKL path needs 109ms while the non-MKL path needs 112ms, so they are currently quite similar.

Contributor

@soumith soumith left a comment

ship when tests pass.

Contributor

That's not conflicting: you can use TensorIterator, where the specialization for a contiguous block can call v?CdfNorm if MKL is available. It's decoupled from using vec256.h.

But just do that in a follow-up diff instead of this one, as it's an independent unit of work anyway.

Contributor

where \Phi(x) is the cumulative distribution function of the Gaussian distribution.
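
For reference, the suggested wording corresponds to the exact (erf-based) GELU:

\mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right),
\qquad
\Phi(x) = \frac{1}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right).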

Contributor Author

Done.

Contributor

Contributor Author

Thanks! Done.

@pytorchbot pytorchbot added the module: docs Related to our documentation, both in docs/ and docblocks label Jun 2, 2019
Summary:
Pull Request resolved: pytorch#20665

Add gelu activation forward on CPU in pytorch

Compared to the current Python-implemented version of gelu in BERT models, like

  def gelu(self, x):
      return x * 0.5 * (1.0 + torch.erf(x / self.sqrt_two))

The torch.nn.functional.gelu function can reduce the forward time from 333ms to 109ms (with MKL) / 112ms (without MKL) for input size = [64, 128, 56, 56] on a devvm.

Reviewed By: zheng-xq

Differential Revision: D15400974

fbshipit-source-id: 78399123aef803376a2459d487d44557126070ac
@xiaomengy xiaomengy deleted the export-D15400974 branch June 2, 2019 16:23
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 2, 2019
@facebook-github-bot
Contributor

This pull request has been merged in 93ae040.

namespace {

template <typename T>
void GeluCUDAKernelImplInternal(const Tensor& X, Tensor* Y) {
Contributor

can you avoid passing Tensors as Tensor *? It's not standard C++ PyTorch code (is there any other example in the codebase that does this?). Depending on the use case, you can use const Tensor &, Tensor & or Tensor.

Contributor Author

I just wanted to make it clear that passing by pointer means it is an output variable or will be changed inside the function, purely for readability. I can change it to Tensor&.

https://google.github.io/styleguide/cppguide.html#Output_Parameters

Contributor

See: https://github.com/pytorch/pytorch/wiki/Writing-Python-in-cpp-(a-manifesto)

Also, if you look at any _out function (which translates to Python with an out= parameter), we use Tensor& already. Although note that this is kind of bogus, because reassigning the reference is almost never correct, but you should just follow the convention for now.

But the right way to think about this is you are already passing a (smart) pointer. Passing a pointer to a smart pointer is almost never what you want. And as noted in the link, const is essentially meaningless here, so trying to use static types for readability doesn't really work either (unless you implement ConstTensor).

@BramVanroy

Is this implemented in 1.2.0? I can find it in the documentation (https://pytorch.org/docs/stable/nn.functional.html), but I can't import it or find it in my installed library.

@cpuhrsch
Contributor

@BramVanroy - is this it?

@BramVanroy

@BramVanroy - is this it?

Odd. My IDE (PyCharm) underlines gelu in red and says "Cannot find reference 'gelu' in functional.pyi", but when I run the code it seems to import just fine.

from torch.nn.functional import gelu

@BramVanroy

BramVanroy commented Sep 28, 2019

PyTorch's current implementation is

def gelu(x):
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

rather than

def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

Correct? I've seen both being mentioned, but I'm not sure which one is implemented in PyTorch. IIRC Google's BERT originally uses the former, and OpenAI's GPT the latter.

@xiaomengy
Contributor Author

The PyTorch implementation is the original definition of GELU, which is x * P(X <= x) where X ~ N(0, 1). This is mathematically equivalent to 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0))). The tanh form is an approximation of GELU which may lead to better performance, as mentioned in https://arxiv.org/pdf/1606.08415.pdf. However, in our tests the performance improvement depends on how the tanh function is implemented, so we didn't use it by default.
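
A quick numerical check of this explanation (a sketch; assumes a PyTorch build that includes torch.nn.functional.gelu):

import math
import torch
import torch.nn.functional as F

x = torch.randn(64, 128)

exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

print(torch.allclose(F.gelu(x), exact))   # expected: True, F.gelu uses the exact erf form
print((F.gelu(x) - approx).abs().max())   # small but nonzero: tanh is only an approximation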

@BramVanroy

@BIT-silence Thank you for the reply. That clarifies things a lot. A final question: is there a reason that gelu doesn't have a Module equivalent that lives in nn (like nn.ReLU)?

@xiaomengy
Contributor Author

Actually, the main reason is that we didn't have enough time when adding it. We will consider adding the Module later.

@BramVanroy

Okay, thanks for the information. I wasn't sure whether there is in general a reason why some activation functions get a Module equivalent, but it seems it's mostly a time constraint.

@soumith
Contributor

soumith commented Sep 30, 2019

@BramVanroy generally we only keep functionals for layers which don't have learnable parameters. We used to add layers for all common functions, like nn.ReLU, but that's legacy.

@BramVanroy

@BramVanroy generally we only keep functionals for layers which don't have learnable parameters. We used to add layers for all common functions, like nn.ReLU, but that's legacy.

That makes sense. However, I do like that printing a Module gives you a nice overview of its submodules. If the activation is a Module, it is included (and it's clear to the user what it is). If it's a function, it's unfortunately not included, so you're left wondering whether (and where) an activation is taking place. Compare the two snippets below:

# as Module
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Linear(512, 1)
        self.activation = nn.ReLU()

    def forward(self, inputs):
        out = self.dense(inputs)
        out = self.activation(out)
        return out


net = Net()
print(net)

Prints

Net(
  (dense): Linear(in_features=512, out_features=1, bias=True)
  (activation): ReLU()
)

But when using functions, you don't see which activations are used nor where.

from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Linear(512, 1)
        self.activation = F.relu

    def forward(self, inputs):
        out = self.dense(inputs)
        out = self.activation(out)
        return out


net = Net()
print(net)

Prints

Net(
  (dense): Linear(in_features=512, out_features=1, bias=True)
)

@xiaomengy
Contributor Author

The module, torch.nn.GELU, is implemented in #28944.
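
Picking up the printing example from above, a minimal sketch using the module form (assumes a PyTorch version that includes torch.nn.GELU, i.e. after #28944 landed):

from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Linear(512, 1)
        self.activation = nn.GELU()   # shows up when printing the model

    def forward(self, inputs):
        return self.activation(self.dense(inputs))

print(Net())
# prints something like:
# Net(
#   (dense): Linear(in_features=512, out_features=1, bias=True)
#   (activation): GELU()
# )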

@BramVanroy

Great! It'll ship with 1.4 then, I presume?
