Implement Tanh Gelu Approximation (#61439)

Summary:
1. Implements https://github.com/pytorch/pytorch/issues/39853
2. Adds an `approximate` string flag to Gelu (usage sketch below)
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enables Tanh Gelu in NvFuser

```
import torch

def gelu(x, approximate: str = 'none'):
    if approximate == 'tanh':
        # sqrt(2 / pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        # exact Gelu: x * Phi(x), where Phi is the standard normal CDF
        return x * torch.special.ndtr(x)
```
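
For context, a minimal usage sketch of the new flag, assuming the Python surface introduced by this PR (an `approximate` keyword on `torch.nn.functional.gelu` and `torch.nn.GELU`); the `gradgradcheck` call exercises the new double backward path:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, 100)

# Exact Gelu (the default) vs. the tanh approximation.
y_exact = F.gelu(x)                       # same as approximate='none'
y_tanh = F.gelu(x, approximate='tanh')

# The module form mirrors the functional flag.
m = torch.nn.GELU(approximate='tanh')
assert torch.allclose(m(x), y_tanh)

# Double backward support: second-order gradients should pass gradgradcheck.
z = torch.randn(8, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradgradcheck(lambda t: F.gelu(t, approximate='tanh'), (z,))
```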

Linked XLA PR: https://github.com/pytorch/xla/pull/3039

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Author: Ryan Spring
Date: 2022-02-13 19:32:11 -08:00
Committed by: PyTorch MergeBot
parent 64faa043f7
commit 4f8b986e28
51 changed files with 825 additions and 270 deletions


@@ -2860,13 +2860,23 @@ TEST_F(ModulesTest, GLU) {
 }
 
 TEST_F(ModulesTest, GELU) {
-  GELU model;
+  GELU model(GELUOptions().approximate("none"));
   const auto x = torch::linspace(-3.0, 3.0, 100);
   const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
   const auto y = model(x);
   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
 }
 
+TEST_F(ModulesTest, TanhGELU) {
+  GELU model(GELUOptions().approximate("tanh"));
+  const auto x = torch::linspace(-3.0, 3.0, 100);
+  const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
+  const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
+  const auto y = model(x);
+  ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST_F(ModulesTest, Mish) {
   Mish model;
   auto x = torch::randn(100) * 10;