mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146109 Approved by: https://github.com/desertfire
62 lines
1.9 KiB
Python
# Owner(s): ["module: inductor"]
import torch
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu
class TestTritonSyntacticallyValid(TestCase):
    """Checks that inductor emits syntactically valid Triton code for
    compiled graphs containing a Python-scalar ``math.sqrt``."""

    @requires_gpu()
    def test_triton_sqrt(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/142328
        # (presumably the math.sqrt on a host-side scalar is what previously
        # produced invalid Triton source — see the issue for details).
        import math

        import torch.nn as nn

        def newtonschulz5(grad, steps: int, eps=1e-7):
            """Quintic Newton-Schulz iteration over a 2-D matrix in reduced
            precision; returns a matrix of the same shape."""
            assert len(grad.shape) == 2
            coef_a, coef_b, coef_c = 3.4445, -4.7750, 2.0315
            # Prefer bf16 only when genuinely supported (no emulation),
            # otherwise fall back to fp16.
            compute_dtype = (
                torch.bfloat16
                if torch.cuda.is_bf16_supported(including_emulation=False)
                else torch.float16
            )
            mat = grad.to(compute_dtype)
            # Normalize so the top singular value is <= 1.
            mat = mat / (mat.norm() + eps)
            is_tall = grad.size(0) > grad.size(1)
            if is_tall:
                mat = mat.T
            for _ in range(steps):
                gram = mat @ mat.T
                poly = coef_b * gram + coef_c * gram @ gram
                mat = coef_a * mat + poly @ mat
            if is_tall:
                mat = mat.T
            return mat

        @torch.compile(backend="inductor")
        def scaled_newton_schulz(grad, steps: int):
            """Apply newtonschulz5 to a 2-D view of *grad*, restore its shape
            and dtype, then rescale by sqrt(max(1, rows / row_numel))."""
            orig_shape = grad.shape
            orig_dtype = grad.dtype
            flat = grad.reshape(orig_shape[0], -1)
            ortho = newtonschulz5(flat, steps)
            out = ortho.reshape(orig_shape).type(orig_dtype)
            return out * math.sqrt(max(1, orig_shape[0] / out[0].numel()))

        model = nn.Sequential(
            nn.Linear(16, 16, bias=False),
            nn.Linear(16, 32, bias=False),
        ).to(device=torch.device(GPU_TYPE))

        # Produce real gradients for both layers.
        loss = model(torch.randn(4, 16, device=torch.device(GPU_TYPE))).sum()
        loss.backward()

        # Compile and run for both a square (16x16) and a non-square (32x16)
        # gradient; the test passes if codegen/compilation does not error.
        scaled_newton_schulz(model[0].weight.grad, 6)
        scaled_newton_schulz(model[1].weight.grad, 6)
|
|
# Standard PyTorch test-suite entry point.
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # The suite only exercises GPU codegen, so skip running entirely when no
    # supported GPU is available.
    if HAS_GPU:
        run_tests()