Port 4 dynamo test files for the intel XPU (#160953)
# Description

Fixes #114850: we port the Dynamo tests to Intel GPU. We enable Intel GPU support with the following methods, keeping the original code style as far as possible:

# Changes

1. Get the device type from the accelerator API (a sketch of the pattern follows the commit metadata below).
2. Replace the `requires_cuda_and_triton` decorators with `requires_gpu`.
3. Bring `HAS_XPU_AND_TRITON` into scope.
4. Replace several `torch.cuda` calls with their `torch.accelerator` wrapper equivalents.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160953
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/jansel

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Commit 13304401df (parent 8e48d1ba25), committed by PyTorch MergeBot.
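Change 1 is the backbone of the port; below is a minimal runnable sketch of the pattern, assuming a PyTorch build that ships `torch.accelerator`. The `device_type` expression is copied verbatim from the hunks that follow; the `randn` usage line is illustrative, not from the diff.

```python
import torch

# torch.accelerator.current_accelerator(True) returns the current accelerator
# device only if it is actually available at runtime, else None -- so a
# CPU-only build or a machine without a GPU falls back to "cpu".
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

# Tests then allocate on the detected device instead of hard-coding "cuda";
# on an Intel GPU machine device_type is "xpu", on NVIDIA it is "cuda".
x = torch.randn(10, 10, device=device_type)
print(x.device)
```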
```diff
@@ -386,7 +386,7 @@ class TestCustomBackendAPI(torch._dynamo.test_case.TestCase):
         self.assertTrue(backend_run)


-devices = ["cpu", "cuda", "hpu"]
+devices = ["cpu", "cuda", "hpu", "xpu"]
 instantiate_device_type_tests(TestOptimizations, globals(), only_for=devices)

 if __name__ == "__main__":
```
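For context on the hunk above: `instantiate_device_type_tests` stamps out one concrete test class per backend listed in `only_for`, so appending `"xpu"` is all it takes to generate the Intel GPU variants. A hedged sketch with a hypothetical stand-in class (`TestOptimizationsDemo` is not the real `TestOptimizations`):

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


class TestOptimizationsDemo(TestCase):
    def test_add(self, device):
        # The harness injects `device` per backend: "cpu", "cuda:0", "xpu:0", ...
        a = torch.ones(2, device=device)
        self.assertEqual((a + a).sum().item(), 4.0)


# Appending "xpu" (the change in the hunk above) makes the harness emit
# TestOptimizationsDemoXPU next to the CPU/CUDA/HPU variants; each generated
# class is skipped automatically when its backend is unavailable.
devices = ["cpu", "cuda", "hpu", "xpu"]
instantiate_device_type_tests(TestOptimizationsDemo, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()
```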
```diff
@@ -8,7 +8,12 @@ from torch._dynamo.callback import callback_handler, CallbackArgs, CallbackTrigger
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._guards import CompileId
 from torch.testing._internal.common_utils import TEST_WITH_ROCM
-from torch.testing._internal.triton_utils import requires_cuda_and_triton
+from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_gpu
+
+
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)


 class CallbackTests(TestCase):
@@ -61,7 +66,7 @@ class CallbackTests(TestCase):
     @unittest.skipIf(
         TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs"
     )
-    @requires_cuda_and_triton
+    @requires_gpu
     @torch._inductor.config.patch(force_disable_caches=True)
     def test_triggers(self) -> None:
         torch._dynamo.reset()
@@ -91,9 +96,9 @@ class CallbackTests(TestCase):
                 torch._dynamo.graph_break()
                 return self.fc2(temp)

-        model = TinyModel().to("cuda")
+        model = TinyModel().to(device_type)
         compiled_model = torch.compile(model, mode="max-autotune")
-        x = torch.randn(10, 10, device="cuda")
+        x = torch.randn(10, 10, device=device_type)

         loss = compiled_model(x).sum()
         loss.backward()
@@ -111,9 +116,13 @@ end=CallbackArgs(callback_trigger=<CallbackTrigger.LAZY_BACKWARD: 2>, compile_id
         )
         order.clear()
+
+        if not HAS_CUDA_AND_TRITON:
+            return
+
         compiled_model.zero_grad()
         loss = compiled_model(x).sum()
         loss.backward()

         self.assertExpectedInline(
             "\n".join(order),
             """\
```
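The early `return` in the last hunk above is a two-level gate: `@requires_gpu` lets the test run on any Triton-capable GPU (CUDA or XPU), while `HAS_CUDA_AND_TRITON` fences off the tail of the test whose expected callback log was recorded on CUDA. A hedged sketch of the same shape (`GatingDemo` and its body are hypothetical):

```python
import unittest

import torch
from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON, requires_gpu


class GatingDemo(unittest.TestCase):
    @requires_gpu
    def test_two_level_gating(self) -> None:
        # Part 1: device-agnostic -- runs on any Triton-capable GPU.
        acc = torch.accelerator.current_accelerator(True)
        self.assertIsNotNone(acc)

        # Part 2: the recorded callbacks in the original test were captured
        # on CUDA, so a non-CUDA run bails out before comparing against them.
        if not HAS_CUDA_AND_TRITON:
            return
        # ... CUDA-specific assertExpectedInline checks would go here ...


if __name__ == "__main__":
    unittest.main()
```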
```diff
@@ -40,11 +40,16 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
 )
+from torch.testing._internal.inductor_utils import HAS_GPU

 # Defines all the kernels for tests
 from torch.testing._internal.triton_utils import *  # noqa: F403


+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
 T = TypeVar("T")

 d = torch.ones(10, 10)
@@ -1150,10 +1155,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
         m = a.to(torch.float16)
         return b.type(m.type())

-    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    @unittest.skipIf(not HAS_GPU, "requires gpu")
     @make_test
     def test_tensor_type2(a, b):
-        m = a.to("cuda")
+        m = a.to(device_type)
         return m + b.type(m.type())

     @make_test
@@ -4040,7 +4045,7 @@ class GraphModule(torch.nn.Module):
         def f1():
             mod1 = torch.get_device_module()
             mod2 = torch.get_device_module("cpu")
-            mod3 = torch.get_device_module(torch.device("cuda"))
+            mod3 = torch.get_device_module(torch.device(device_type))
             return mod1, mod2, mod3

         self.assertEqual(f1(), torch.compile(f1, backend="eager", fullgraph=True)())
@@ -4075,6 +4080,7 @@ class GraphModule(torch.nn.Module):
         new_device = (
             "cpu" if torch._C._get_accelerator() == torch.device("cuda") else "cuda"
         )
+
         old_get_device_module = torch.get_device_module

         def new_get_device_module(device=None):
@@ -4721,10 +4727,12 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
         opt_fn(x, ys, zs[:1])

     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
-    def test_cuda_current_device(self):
+    def test_gpu_current_device(self):
        def fn(x):
             y = torch.empty(
-                (2, 3), dtype=torch.float32, device=torch.cuda.current_device()
+                (2, 3),
+                dtype=torch.float32,
+                device=torch.accelerator.current_device_index(),
             )
             y.copy_(x)
             return torch.sin(y + y.device.index)
@@ -4732,11 +4740,11 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
         counter = torch._dynamo.testing.CompileCounter()
         opt_fn = torch.compile(backend=counter, fullgraph=True)(fn)

-        with torch.cuda.device(0):
+        with torch.accelerator.device_index(0):
             x = torch.randn(2, 3)
             self.assertEqual(opt_fn(x), fn(x))
             self.assertEqual(counter.frame_count, 1)
-            with torch.cuda.device(1):
+            with torch.accelerator.device_index(1):
                 self.assertEqual(opt_fn(x), fn(x))
                 self.assertEqual(counter.frame_count, 2)

```
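The two hunks above implement change 4: `torch.cuda`-specific calls are swapped for their device-generic `torch.accelerator` wrappers. A small sketch of the mapping, assuming an accelerator is present (the `torch.empty` usage mirrors the test above):

```python
import torch

# Replacements used in the hunks above (torch.cuda call -> generic wrapper):
#   torch.cuda.current_device() -> torch.accelerator.current_device_index()
#   torch.cuda.device(i)        -> torch.accelerator.device_index(i)
if torch.accelerator.is_available():
    with torch.accelerator.device_index(0):
        y = torch.empty(
            (2, 3),
            dtype=torch.float32,
            device=torch.accelerator.current_device_index(),
        )
        print(y.device)  # cuda:0 on NVIDIA, xpu:0 on Intel GPU
```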
```diff
@@ -13293,7 +13293,7 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         self.assertEqual(out, opt_out)

     @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU")
-    def test_cuda_set_device(self, device):
+    def test_gpu_set_device(self, device):
         def fn():
             a = torch.ones(2, device=device)
             torch.get_device_module(device).set_device(1)
```
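The renamed test reaches the backend-specific `set_device` through `torch.get_device_module`, which resolves a device to its backend module without naming `cuda` or `xpu` literally. A brief hedged sketch (the `print` output comment is illustrative):

```python
import torch

# torch.get_device_module maps a device, device string, or None to the
# matching backend module: torch.cuda for "cuda", torch.xpu for "xpu",
# and the current accelerator's module when called with no argument.
cpu_mod = torch.get_device_module("cpu")
print(cpu_mod)  # <module 'torch.cpu' ...>

if torch.accelerator.is_available():
    acc = torch.accelerator.current_accelerator()
    backend = torch.get_device_module(acc)  # e.g. torch.cuda or torch.xpu
    backend.set_device(0)  # backend wrapper call, no "cuda"/"xpu" literal needed
```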