Using device-agnostic autocast api (#136613)

- Use torch.autocast(device_type="cuda") instead of torch.cuda.amp.autocast()
- Use torch.autocast(device_type="cpu") instead of torch.cpu.amp.autocast() (a usage sketch follows below)
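
For illustration only (not part of the original commit message): a minimal sketch of the device-agnostic form, assuming a recent PyTorch build; the device_type string selects the backend, so the legacy per-device context managers are no longer needed.

    import torch

    # Pick whichever device is present; the autocast entry point is the same either way.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(8, 8, device=device_type)
    y = torch.randn(8, 8, device=device_type)

    # Replaces torch.cuda.amp.autocast() / torch.cpu.amp.autocast().
    # float16 is the usual autocast dtype on CUDA, bfloat16 on CPU.
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.autocast(device_type=device_type, dtype=amp_dtype):
        z = torch.mm(x, y)  # matmul runs in the reduced-precision dtype under autocast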

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136613
Approved by: https://github.com/shink, https://github.com/cyyever, https://github.com/kwen2501
Author: FFFrog
Date: 2024-09-25 15:33:08 +08:00
Committed by: PyTorch MergeBot
Parent: ad6c70b656
Commit: e14b58ffbd
6 changed files with 34 additions and 34 deletions

View File

@@ -688,7 +688,7 @@ class TestDataParallel(TestCase):
             def __init__(self) -> None:
                 super().__init__(8, 8)

-            @torch.cuda.amp.autocast()
+            @torch.autocast(device_type="cuda")
             def forward(self, input):
                 return super().forward(input)

View File

@@ -1214,7 +1214,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         def gn(*args):
             return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

-        with torch.cuda.amp.autocast():
+        with torch.autocast(device_type="cuda"):
             x = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
             y = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
             z = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)

View File

@@ -583,7 +583,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
             a_float32 = torch.rand((8, 8), device="cuda")
             b_float32 = torch.rand((8, 8), device="cuda")

-            with torch.cuda.amp.autocast(dtype=torch.float64):
+            with torch.autocast(device_type="cuda", dtype=torch.float64):
                 c_float64 = torch.mm(a_float32, b_float32)
                 return c_float64
@@ -603,7 +603,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
     def test_is_autocast_cpu_enabled(self):
         def fn(a_float32, b_float32):
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                 c_float16 = torch.mm(a_float32, b_float32)
                 if torch.is_autocast_cpu_enabled():
                     c_float16 = c_float16 + 1
@@ -887,12 +887,12 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_autocast_arguments_binding(self):
         def f1(x):
-            with torch.cuda.amp.autocast(False):
+            with torch.autocast(device_type="cuda", enabled=False):
                 x = torch.sin(x + 1)
             return x

         def f2(x):
-            with torch.cpu.amp.autocast(False):
+            with torch.autocast(device_type="cpu", enabled=False):
                 x = torch.cos(x + 1)
             return x
@@ -916,14 +916,14 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
             return new_fwd

         def autocast_func_cuda(orig_func):
-            @torch.cuda.amp.autocast(dtype=torch.float16)
+            @torch.autocast(device_type="cuda", dtype=torch.float16)
             def new_fwd(*args, **kwargs):
                 return orig_func(*args, **kwargs)

             return new_fwd

         def autocast_func_cpu(orig_func):
-            @torch.cpu.amp.autocast(dtype=torch.float16)
+            @torch.autocast(device_type="cpu", dtype=torch.float16)
             def new_fwd(*args, **kwargs):
                 return orig_func(*args, **kwargs)

View File

@@ -5089,7 +5089,7 @@ class TestCudaAutocast(TestAutocast):
         dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
         for dtype in dtypes:
-            with torch.cuda.amp.autocast(dtype=dtype):
+            with torch.autocast(device_type="cuda", dtype=dtype):
                 output = mymm(x, y)
                 self.assertTrue(output.dtype is dtype)
                 loss = output.sum()

View File

@@ -533,11 +533,11 @@ class TestAutocast(JitTestCase):
             return torch.mm(x, y)

         def t_cuda_amp_autocast(x, y):
-            with torch.cuda.amp.autocast():
+            with torch.autocast(device_type="cuda"):
                 return torch.mm(x, y)

         def t_cpu_amp_autocast(x, y):
-            with torch.cpu.amp.autocast():
+            with torch.autocast(device_type="cpu"):
                 return torch.mm(x, y)

         x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
@@ -658,7 +658,7 @@ class TestAutocast(JitTestCase):
             impl: Iface

             def forward(self, x, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     a = torch.mm(x, y)
                     b = self.impl.forward(a, x)
                     return b
@@ -671,7 +671,7 @@ class TestAutocast(JitTestCase):
         y = torch.rand([2, 2])

         # make sure this doesn't throw an error
-        with torch.cuda.amp.autocast():
+        with torch.autocast(device_type="cuda"):
             ans = scripted_thing1.forward(x, y)
         self.assertEqual(torch.mm(torch.mm(x, y), x), ans)
@@ -683,7 +683,7 @@ class TestAutocast(JitTestCase):
     def test_jit_freeze_autocast_basic(self):
         class TestModule(torch.nn.Module):
             def forward(self, x, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     return torch.mm(x, y)

         x = torch.rand((3, 4), dtype=torch.float).cuda()
@@ -710,7 +710,7 @@ class TestAutocast(JitTestCase):
                 self.x = torch.rand((3, 4), dtype=torch.float).cuda()

             def forward(self, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     return torch.mm(self.x, y)

         y = torch.rand((4, 5), dtype=torch.float).cuda()
@@ -729,7 +729,7 @@ class TestAutocast(JitTestCase):
     @unittest.skipIf(TEST_CUDA, "CPU-only test")
     def test_jit_autocast_softmax_cpu(self):
         def fn(x):
-            with torch.cpu.amp.autocast():
+            with torch.autocast(device_type="cpu"):
                 return torch.nn.functional.softmax(x, dim=0)

         fn_s = torch.jit.script(fn)
@@ -742,7 +742,7 @@ class TestAutocast(JitTestCase):
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_jit_autocast_softmax_gpu(self):
         def fn(x):
-            with torch.cuda.amp.autocast():
+            with torch.autocast(device_type="cuda"):
                 return torch.nn.functional.softmax(x, dim=0)

         fn_s = torch.jit.script(fn)
@@ -759,7 +759,7 @@ class TestAutocast(JitTestCase):
         inp = torch.rand([10, 10], dtype=torch.float)
         foo._set_ignore_amp(True)
-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             foo(inp)
             foo(inp)
@@ -797,7 +797,7 @@ class TestJitTraceAutocast(JitTestCase):
     def test_generate_autocast_jit_trace_model(self):
         def test_generate_autocast_jit_trace_model(model, x):
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x)
             traced_model = torch.jit.freeze(traced_model)
         for i in range(self.models.__len__()):
@@ -806,12 +806,12 @@ class TestJitTraceAutocast(JitTestCase):
     def test_nchw_autocast_jit_trace_model(self):
         def test_nchw_autocast_jit_trace_model(model, x):
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x)
             traced_model = torch.jit.freeze(traced_model)
             with torch.no_grad():
                 y = traced_model(x.clone())
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast(device_type="cpu"), torch.no_grad():
                 y2 = model(x.clone())
             torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
         for i in range(self.models.__len__()):
@@ -821,12 +821,12 @@ class TestJitTraceAutocast(JitTestCase):
         def test_nhwc_autocast_jit_trace_model(model, x):
             model = model.to(memory_format=torch.channels_last)
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
             traced_model = torch.jit.freeze(traced_model)
             with torch.no_grad():
                 y = traced_model(x.clone().to(memory_format=torch.channels_last))
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast(device_type="cpu"), torch.no_grad():
                 y2 = model(x.clone().to(memory_format=torch.channels_last))
             torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
         for i in range(self.models.__len__()):
@@ -845,7 +845,7 @@ class TestJitTraceAutocast(JitTestCase):
         # To avoid the fusion group from TE, we will disable the fuser here.
         for jit_freeze_or_not in [False, True]:
             test_model = TestModel().eval()
-            with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                 a = torch.rand(24, 128, 128)
                 b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                 c = test_model(a, b)
@@ -869,10 +869,10 @@ class TestJitTraceAutocast(JitTestCase):
         fn_s = torch.jit.script(fn)
         x = torch.rand((4, 4)) - 0.5

-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             self.assertEqual(fn_s(x), fn(x))
-        with torch.cpu.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cpu", enabled=True):
             self.assertEqual(fn_s(x), fn(x))
         self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))
@@ -888,10 +888,10 @@ class TestJitTraceAutocast(JitTestCase):
         fn_s = torch.jit.script(fn)
         x = torch.rand((4, 4)) - 0.5

-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             self.assertEqual(fn_s(x), fn(x))
-        with torch.cuda.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cuda", enabled=True):
             self.assertEqual(fn_s(x), fn(x))
         self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))
@@ -904,7 +904,7 @@ class TestJitTraceAutocast(JitTestCase):
                 y = True
             else:
                 y = False
-            with torch.cuda.amp.autocast(enabled=True):
+            with torch.autocast(device_type="cuda", enabled=True):
                 z = x.relu()
             return y, z
@@ -926,10 +926,10 @@ class TestJitTraceAutocast(JitTestCase):
         def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
             b1 = torch.is_autocast_cpu_enabled()
             v1 = torch.mm(x, y)
-            with torch.cpu.amp.autocast(enabled=True):
+            with torch.autocast(device_type="cpu", enabled=True):
                 b2 = torch.is_autocast_cpu_enabled()
                 v2 = torch.mm(x, y)
-                with torch.cpu.amp.autocast(enabled=False):
+                with torch.autocast(device_type="cpu", enabled=False):
                     b3 = torch.is_autocast_cpu_enabled()
                     v3 = torch.mm(x, y)
             return (v1, b1, v2, b2, v3, b3)
@@ -946,11 +946,11 @@ class TestJitTraceAutocast(JitTestCase):
         fn_s = torch.jit.script(fn)

-        with torch.cpu.amp.autocast(enabled=False):
+        with torch.autocast(device_type="cpu", enabled=False):
             check_fn_results(fn(x, y))
             check_fn_results(fn_s(x, y))

-        with torch.cpu.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cpu", enabled=True):
             check_fn_results(fn(x, y))
             check_fn_results(fn_s(x, y))

View File

@@ -68,7 +68,7 @@ class JitLlgaTestCase(JitTestCase):
         with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
             if dtype == torch.bfloat16:
                 # We rely upon eager-mode AMP support for BF16
-                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
+                with torch.autocast(device_type="cpu", cache_enabled=False, dtype=torch.bfloat16):
                     traced = torch.jit.trace(m, x)
                     if isinstance(m, torch.nn.Module):
                         traced = torch.jit.freeze(traced)