[dynamo][easy] Support torch.accelerator.current_accelerator (#165734)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165734
Approved by: https://github.com/Skylion007
This commit is contained in:
Animesh Jain
2025-10-17 09:46:53 -07:00
committed by PyTorch MergeBot
parent 86ebce1766
commit c18ddfc572
2 changed files with 9 additions and 0 deletions

View File

@ -8101,6 +8101,14 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
res = gm(x, y) res = gm(x, y)
self.assertEqual(res, ref) self.assertEqual(res, ref)
def test_current_accelerator(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
torch.accelerator.current_accelerator()
return x + 1
self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1)
instantiate_parametrized_tests(ReproTests) instantiate_parametrized_tests(ReproTests)

View File

@ -146,6 +146,7 @@ REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
constant_fold_functions_need_guards = [ constant_fold_functions_need_guards = [
torch.accelerator.current_device_index, torch.accelerator.current_device_index,
torch.accelerator.current_accelerator,
torch.cuda.current_device, torch.cuda.current_device,
torch.cuda.is_initialized, torch.cuda.is_initialized,
torch.xpu.current_device, torch.xpu.current_device,