Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Implement Tanh Gelu Approximation (#61439)
Summary:
1. Implements https://github.com/pytorch/pytorch/issues/39853
2. Adds an `approximate` string flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enables Tanh Gelu in NvFuser

```python
def gelu(x, approximate: str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - https://github.com/pytorch/xla/pull/3039

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61439
Reviewed By: VitalyFedyunin
Differential Revision: D33894937
Pulled By: jbschlosser
fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Committed by: PyTorch MergeBot
Parent: 64faa043f7
Commit: 4f8b986e28
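As a usage sketch, the new flag described in the summary is reachable from Python through `torch.nn.functional.gelu` (the matching `torch.nn.GELU` module option is exercised in the C++ tests below); this assumes a PyTorch build that includes this change:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, 100)

# Exact Gelu: x * Phi(x), evaluated via erf.
y_exact = F.gelu(x, approximate='none')

# Tanh approximation shown in the summary above.
y_tanh = F.gelu(x, approximate='tanh')

# The two modes agree closely on this range.
print((y_exact - y_tanh).abs().max())
```

The module form takes the same keyword, e.g. `torch.nn.GELU(approximate='tanh')`.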
```diff
@@ -2860,13 +2860,23 @@ TEST_F(ModulesTest, GLU) {
 }
 
 TEST_F(ModulesTest, GELU) {
-  GELU model;
+  GELU model(GELUOptions().approximate("none"));
   const auto x = torch::linspace(-3.0, 3.0, 100);
   const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
   const auto y = model(x);
   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
 }
 
+TEST_F(ModulesTest, TanhGELU) {
+  GELU model(GELUOptions().approximate("tanh"));
+  const auto x = torch::linspace(-3.0, 3.0, 100);
+  const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
+  const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
+  const auto y = model(x);
+  ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST_F(ModulesTest, Mish) {
   Mish model;
   auto x = torch::randn(100) * 10;
```
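Since the summary also adds double backward support for Gelu, a minimal sketch of how that can be verified numerically with `torch.autograd.gradgradcheck`, assuming the same build:

```python
import torch
import torch.nn.functional as F

# Double-precision input keeps the numerical Jacobian checks stable.
x = torch.randn(8, dtype=torch.double, requires_grad=True)

# gradgradcheck differentiates twice, exercising the new
# double backward path for both Gelu variants.
for mode in ('none', 'tanh'):
    assert torch.autograd.gradgradcheck(
        lambda t, m=mode: F.gelu(t, approximate=m), (x,))
```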