mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable XPU for test_autograd_function.py (#160309)
# Description Fixes #114850, we will port dynamo tests to Intel GPU We could enable Intel GPU with following methods and try the best to keep the original code styles: # Changes 1. Get device type from get_devtype() method. 2. Replace the requires_cuda_and_triton with requires_gpu. 3. Add HAS_XPU_AND_TRITON into the scope. # Notify Pull Request resolved: https://github.com/pytorch/pytorch/pull/160309 Approved by: https://github.com/guangyey, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
8eee08d227
commit
6ea8376f84
@ -8,13 +8,14 @@ import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
import torch._dynamo.utils
|
||||
from torch.testing._internal.triton_utils import (
|
||||
HAS_CUDA_AND_TRITON,
|
||||
requires_cuda_and_triton,
|
||||
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
|
||||
|
||||
|
||||
device_type = (
|
||||
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
|
||||
)
|
||||
|
||||
|
||||
if HAS_CUDA_AND_TRITON:
|
||||
if HAS_GPU:
|
||||
import triton
|
||||
|
||||
from torch.testing._internal.triton_utils import add_kernel
|
||||
@ -507,13 +508,13 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
class MyMM(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@torch.amp.custom_fwd(device_type="cuda")
|
||||
@torch.amp.custom_fwd(device_type=device_type)
|
||||
def forward(ctx, a, b):
|
||||
ctx.save_for_backward(a, b)
|
||||
return a.mm(b)
|
||||
|
||||
@staticmethod
|
||||
@torch.amp.custom_bwd(device_type="cuda")
|
||||
@torch.amp.custom_bwd(device_type=device_type)
|
||||
def backward(ctx, grad):
|
||||
a, b = ctx.saved_tensors
|
||||
return grad.mm(b.t()), a.t().mm(grad)
|
||||
@ -1476,7 +1477,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, 1)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@requires_gpu
|
||||
def test_triton_kernel_basic(self):
|
||||
class Add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ -1500,14 +1501,14 @@ class GraphModule(torch.nn.Module):
|
||||
z = Add.apply(x, y)
|
||||
return z
|
||||
|
||||
x = torch.randn(10, device="cuda", requires_grad=True)
|
||||
y = torch.randn(10, device="cuda", requires_grad=True)
|
||||
x = torch.randn(10, device=device_type, requires_grad=True)
|
||||
y = torch.randn(10, device=device_type, requires_grad=True)
|
||||
z = f(x, y)
|
||||
loss = z.sum()
|
||||
loss.backward()
|
||||
self.assertEqual(x + y, z)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@requires_gpu
|
||||
def test_triton_kernel_multiple_out(self):
|
||||
class Add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ -1535,8 +1536,8 @@ class GraphModule(torch.nn.Module):
|
||||
z = Add.apply(x, y)
|
||||
return z
|
||||
|
||||
x = torch.randn(10, device="cuda", requires_grad=True)
|
||||
y = torch.randn(10, device="cuda", requires_grad=True)
|
||||
x = torch.randn(10, device=device_type, requires_grad=True)
|
||||
y = torch.randn(10, device=device_type, requires_grad=True)
|
||||
z, _ = f(x, y)
|
||||
loss = z.sum()
|
||||
loss.backward()
|
||||
|
Reference in New Issue
Block a user