mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][easy] Support torch.accelerator.current_accelerator (#165734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165734 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
86ebce1766
commit
c18ddfc572
@ -8101,6 +8101,14 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
|||||||
res = gm(x, y)
|
res = gm(x, y)
|
||||||
self.assertEqual(res, ref)
|
self.assertEqual(res, ref)
|
||||||
|
|
||||||
|
def test_current_accelerator(self):
|
||||||
|
@torch.compile(backend="eager", fullgraph=True)
|
||||||
|
def fn(x):
|
||||||
|
torch.accelerator.current_accelerator()
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1)
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(ReproTests)
|
instantiate_parametrized_tests(ReproTests)
|
||||||
|
|
||||||
|
@ -146,6 +146,7 @@ REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
|
|||||||
|
|
||||||
constant_fold_functions_need_guards = [
|
constant_fold_functions_need_guards = [
|
||||||
torch.accelerator.current_device_index,
|
torch.accelerator.current_device_index,
|
||||||
|
torch.accelerator.current_accelerator,
|
||||||
torch.cuda.current_device,
|
torch.cuda.current_device,
|
||||||
torch.cuda.is_initialized,
|
torch.cuda.is_initialized,
|
||||||
torch.xpu.current_device,
|
torch.xpu.current_device,
|
||||||
|
Reference in New Issue
Block a user