mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
using new device-agnostic api instead of old api like torch.cpu or torch.cuda (#134448)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134448 Approved by: https://github.com/guangyey, https://github.com/shink, https://github.com/albanD
This commit is contained in:
@ -45,9 +45,9 @@ class TestAutocastCPU(TestCase):
|
||||
if add_kwargs is None:
|
||||
add_kwargs = {}
|
||||
|
||||
self.assertFalse(torch.is_autocast_cpu_enabled())
|
||||
with torch.cpu.amp.autocast(dtype=amp_dtype):
|
||||
self.assertTrue(torch.is_autocast_cpu_enabled())
|
||||
self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
|
||||
with torch.amp.autocast(device_type="cpu", dtype=amp_dtype):
|
||||
self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
|
||||
out_type = out_type if out_type is not None else run_as_type
|
||||
output = output_method = None
|
||||
|
||||
@ -94,8 +94,8 @@ class TestAutocastCPU(TestCase):
|
||||
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
|
||||
# as the C++-side autocasting, and should be bitwise accurate.
|
||||
output_to_compare = output if output is not None else output_method
|
||||
with torch.cpu.amp.autocast(enabled=False):
|
||||
self.assertFalse(torch.is_autocast_cpu_enabled())
|
||||
with torch.amp.autocast(device_type="cpu", enabled=False):
|
||||
self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
|
||||
|
||||
if module is not None and hasattr(module, op):
|
||||
control = getattr(module, op)(
|
||||
@ -108,8 +108,8 @@ class TestAutocastCPU(TestCase):
|
||||
self.assertTrue(type(output_to_compare) == type(control))
|
||||
comparison = compare(output_to_compare, control)
|
||||
self.assertTrue(comparison, f"torch.{op} result did not match control")
|
||||
self.assertTrue(torch.is_autocast_cpu_enabled())
|
||||
self.assertFalse(torch.is_autocast_cpu_enabled())
|
||||
self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
|
||||
self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
|
||||
|
||||
def args_maybe_kwargs(self, op_with_args):
|
||||
if len(op_with_args) == 2:
|
||||
@ -237,7 +237,7 @@ class TestAutocastCPU(TestCase):
|
||||
m(x, (hx, cx))
|
||||
|
||||
# Should be able to run the below case with autocast
|
||||
with torch.cpu.amp.autocast():
|
||||
with torch.amp.autocast(device_type="cpu"):
|
||||
m(x, (hx, cx))
|
||||
|
||||
def test_autocast_disabled_with_fp32_dtype(self):
|
||||
@ -249,7 +249,7 @@ class TestAutocastCPU(TestCase):
|
||||
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
|
||||
with torch.amp.autocast(device_type="cpu"):
|
||||
generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
|
||||
with torch.cpu.amp.autocast():
|
||||
with torch.amp.autocast(device_type="cpu"):
|
||||
cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
|
||||
self.assertEqual(generic_autocast_output, cpu_autocast_output)
|
||||
|
||||
@ -346,8 +346,8 @@ class TestAutocastGPU(TestCase):
|
||||
|
||||
class TestTorchAutocast(TestCase):
|
||||
def test_autocast_fast_dtype(self):
|
||||
gpu_fast_dtype = torch.get_autocast_gpu_dtype()
|
||||
cpu_fast_dtype = torch.get_autocast_cpu_dtype()
|
||||
gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda")
|
||||
cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu")
|
||||
self.assertEqual(gpu_fast_dtype, torch.half)
|
||||
self.assertEqual(cpu_fast_dtype, torch.bfloat16)
|
||||
|
||||
|
Reference in New Issue
Block a user