
Add the pytorch implementation of the OpenAI GeLU approximation #21344

@jlamypoirier


Feature request

Add support for the PyTorch implementation of OpenAI's tanh approximation of the GeLU function, added in PyTorch 1.12. This implementation is equivalent to gelu_new and gelu_fast but much faster. It could be exposed as a separate activation function, for example gelu_new_python, to avoid disrupting existing models.
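
A minimal sketch of what such an activation could look like, wrapping the fused PyTorch kernel so it can be registered in transformers.activations.ACT2FN (the class name NewGELUPytorch and the key gelu_new_python are placeholders; the final naming is up to the maintainers):

import torch
from torch import nn


class NewGELUPytorch(nn.Module):
    """
    Tanh approximation of GeLU backed by the fused PyTorch kernel
    (requires torch >= 1.12). Numerically matches NewGELUActivation /
    FastGELUActivation up to rounding errors.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return nn.functional.gelu(input, approximate="tanh")


# Hypothetical registration under a new key, leaving existing keys untouched:
# ACT2FN["gelu_new_python"] = NewGELUPytorch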

Motivation

Many transformer models use OpenAI's tanh approximation of the GeLU through the activation functions gelu_new or gelu_fast. These implementations are extremely slow (despite their names) because they consist of multiple separate operations/kernels (8 and 9, respectively).
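
For reference, the formula these activations compute is the one below (this matches NewGELUActivation in transformers.activations); in eager mode every arithmetic operation launches its own kernel, which is where the slowness comes from:

import math
import torch


def gelu_new_reference(x: torch.Tensor) -> torch.Tensor:
    # OpenAI/GPT-2 tanh approximation of GeLU. Each multiply, add, pow and
    # tanh below is a separate kernel launch in eager PyTorch.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))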

Since version 1.12, PyTorch supports a single-kernel C/CUDA implementation through the argument approximate='tanh' (https://pytorch.org/docs/stable/generated/torch.nn.GELU.html). This implementation is 6-10x faster than what currently exists in transformers, and is numerically equal up to rounding errors.
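
For illustration, the fused version is available both as a functional call and as a module, so it could serve as a drop-in replacement for the existing activation modules (the tensor shape below is arbitrary, just for the example):

import torch

x = torch.randn(8, 1024)

# Functional form (torch >= 1.12)
y_functional = torch.nn.functional.gelu(x, approximate="tanh")

# Module form, interchangeable with the existing activation modules
act = torch.nn.GELU(approximate="tanh")
y_module = act(x)

assert torch.allclose(y_functional, y_module)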

When benchmarking the inference speed of the SantaCoder models, I found that using the PyTorch implementation yielded an end-to-end speedup of ~15-20%.

I also benchmarked the speed and accuracy using the following code (on an A100-80GB):


import time

import torch

from transformers.activations import NewGELUActivation, FastGELUActivation

dtype = torch.float32
eps = torch.finfo(dtype).eps

# 2**30 elements (~4 GiB in float32) so that the kernels dominate the timing.
x = torch.empty([2**30], device="cuda", dtype=dtype).normal_()
torch.cuda.synchronize()
t0 = time.perf_counter()

# Fused PyTorch kernel for the tanh approximation (PyTorch >= 1.12).
y0 = torch.nn.functional.gelu(x, approximate="tanh")
torch.cuda.synchronize()
t1 = time.perf_counter()

# Existing multi-kernel implementations in transformers.
y1 = NewGELUActivation()(x)
torch.cuda.synchronize()
t2 = time.perf_counter()

y2 = FastGELUActivation()(x)
torch.cuda.synchronize()
t3 = time.perf_counter()

# Exact (erf-based) GeLU for comparison.
y3 = torch.nn.functional.gelu(x)
torch.cuda.synchronize()
t4 = time.perf_counter()

print(f"Torch tanh: {1000 * (t1 - t0):.3f} ms")
print(f"New: {1000 * (t2 - t1):.3f} ms")
print(f"Fast: {1000 * (t3 - t2):.3f} ms")
print(f"Torch orig: {1000 * (t4 - t3):.3f} ms")

# Standard deviation of the elementwise difference, in units of machine epsilon.
print(f"Torch tanh vs new: {(y1 - y0).float().std().cpu().item() / eps:.3f}")
print(f"Torch tanh vs fast: {(y2 - y0).float().std().cpu().item() / eps:.3f}")
print(f"New vs fast: {(y2 - y1).float().std().cpu().item() / eps:.3f}")
print(f"Torch tanh vs torch orig: {(y3 - y0).float().std().cpu().item() / eps:.3f}")

With output:

Torch tanh: 4.921 ms
New: 43.253 ms
Fast: 50.269 ms
Torch orig: 4.989 ms
Torch tanh vs new: 0.042
Torch tanh vs fast: 0.147
New vs fast: 0.147
Torch tanh vs torch orig: 971.960

I.e., the PyTorch tanh version matches the new and fast GeLU implementations to within a fraction of machine epsilon while being 8.8x/10.2x faster, but, as expected for an approximation, differs from the exact (erf-based) version.

With dtype=torch.float16:

Torch tanh: 3.342 ms
New: 22.667 ms
Fast: 26.104 ms
Torch orig: 3.395 ms
Torch tanh vs new: 0.244
Torch tanh vs fast: 0.243
New vs fast: 0.143
Torch tanh vs torch orig: 0.216

I.e., it's 6.8x/7.8x faster, and the choice of implementation doesn't matter for accuracy because float16 rounding errors dominate.

On CPU (float32), with size 2**28 (268M elements):

Torch tanh: 182.575 ms
New: 1683.934 ms
Fast: 1925.547 ms
Torch orig: 141.410 ms
Torch tanh vs new: 0.043
Torch tanh vs fast: 0.144
New vs fast: 0.144
Torch tanh vs torch orig: 971.852

I.e., same accuracy and a similar speedup (9.2x/10.5x faster).

Your contribution

Opened a draft PR (#21345)
