mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo][easy] Support torch.accelerator.current_accelerator (#165734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165734 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
86ebce1766
commit
c18ddfc572
@ -8101,6 +8101,14 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
res = gm(x, y)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
def test_current_accelerator(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x):
|
||||
torch.accelerator.current_accelerator()
|
||||
return x + 1
|
||||
|
||||
self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ReproTests)
|
||||
|
||||
|
@ -146,6 +146,7 @@ REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
|
||||
|
||||
constant_fold_functions_need_guards = [
|
||||
torch.accelerator.current_device_index,
|
||||
torch.accelerator.current_accelerator,
|
||||
torch.cuda.current_device,
|
||||
torch.cuda.is_initialized,
|
||||
torch.xpu.current_device,
|
||||
|
Reference in New Issue
Block a user