Using device-agnostic autocast API (#136613)
- Use torch.autocast(device_type="cuda") instead of torch.cuda.amp.autocast()
- Use torch.autocast(device_type="cpu") instead of torch.cpu.amp.autocast()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136613
Approved by: https://github.com/shink, https://github.com/cyyever, https://github.com/kwen2501
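For reference, a minimal migration sketch of the pattern this PR applies; the `run_matmul` helper, tensor shapes, and dtype choices below are illustrative only and not taken from the PR:

    import torch

    def run_matmul(x, y):
        # Old, device-specific spellings being replaced:
        #   with torch.cuda.amp.autocast(): ...
        #   with torch.cpu.amp.autocast():  ...
        # Device-agnostic replacement: one API, the device is an argument.
        device_type = "cuda" if x.is_cuda else "cpu"
        dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
        with torch.autocast(device_type=device_type, dtype=dtype):
            return torch.mm(x, y)

    # The decorator form migrates the same way:
    @torch.autocast(device_type="cpu", dtype=torch.bfloat16)
    def mm_bf16(x, y):
        return torch.mm(x, y)

Note that torch.autocast takes device_type as its first positional parameter, so arguments that used to be passed positionally to the per-device wrappers (e.g. torch.cuda.amp.autocast(False)) have to become keywords (enabled=False), as the test_autocast_arguments_binding hunk below shows.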
@@ -688,7 +688,7 @@ class TestDataParallel(TestCase):
     def __init__(self) -> None:
         super().__init__(8, 8)
 
-    @torch.cuda.amp.autocast()
+    @torch.autocast(device_type="cuda")
     def forward(self, input):
         return super().forward(input)

@@ -1214,7 +1214,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
     def gn(*args):
         return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
 
-    with torch.cuda.amp.autocast():
+    with torch.autocast(device_type="cuda"):
         x = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
         y = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
         z = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)

@@ -583,7 +583,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
         a_float32 = torch.rand((8, 8), device="cuda")
         b_float32 = torch.rand((8, 8), device="cuda")
 
-        with torch.cuda.amp.autocast(dtype=torch.float64):
+        with torch.autocast(device_type="cuda", dtype=torch.float64):
             c_float64 = torch.mm(a_float32, b_float32)
             return c_float64

@@ -603,7 +603,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
 
     def test_is_autocast_cpu_enabled(self):
         def fn(a_float32, b_float32):
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                 c_float16 = torch.mm(a_float32, b_float32)
                 if torch.is_autocast_cpu_enabled():
                     c_float16 = c_float16 + 1

@@ -887,12 +887,12 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_autocast_arguments_binding(self):
         def f1(x):
-            with torch.cuda.amp.autocast(False):
+            with torch.autocast(device_type="cuda", enabled=False):
                 x = torch.sin(x + 1)
             return x
 
         def f2(x):
-            with torch.cpu.amp.autocast(False):
+            with torch.autocast(device_type="cpu", enabled=False):
                 x = torch.cos(x + 1)
             return x

@@ -916,14 +916,14 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
             return new_fwd
 
         def autocast_func_cuda(orig_func):
-            @torch.cuda.amp.autocast(dtype=torch.float16)
+            @torch.autocast(device_type="cuda", dtype=torch.float16)
             def new_fwd(*args, **kwargs):
                 return orig_func(*args, **kwargs)
 
             return new_fwd
 
         def autocast_func_cpu(orig_func):
-            @torch.cpu.amp.autocast(dtype=torch.float16)
+            @torch.autocast(device_type="cpu", dtype=torch.float16)
             def new_fwd(*args, **kwargs):
                 return orig_func(*args, **kwargs)

@@ -5089,7 +5089,7 @@ class TestCudaAutocast(TestAutocast):
 
         dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
         for dtype in dtypes:
-            with torch.cuda.amp.autocast(dtype=dtype):
+            with torch.autocast(device_type="cuda", dtype=dtype):
                 output = mymm(x, y)
                 self.assertTrue(output.dtype is dtype)
                 loss = output.sum()

@@ -533,11 +533,11 @@ class TestAutocast(JitTestCase):
                 return torch.mm(x, y)
 
         def t_cuda_amp_autocast(x, y):
-            with torch.cuda.amp.autocast():
+            with torch.autocast(device_type="cuda"):
                 return torch.mm(x, y)
 
         def t_cpu_amp_autocast(x, y):
-            with torch.cpu.amp.autocast():
+            with torch.autocast(device_type="cpu"):
                 return torch.mm(x, y)
 
         x = torch.randn(5, 5, device="cuda", dtype=torch.float32)

@@ -658,7 +658,7 @@ class TestAutocast(JitTestCase):
             impl: Iface
 
             def forward(self, x, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     a = torch.mm(x, y)
                     b = self.impl.forward(a, x)
                     return b

@@ -671,7 +671,7 @@ class TestAutocast(JitTestCase):
         y = torch.rand([2, 2])
 
         # make sure this doesn't throw an error
-        with torch.cuda.amp.autocast():
+        with torch.autocast(device_type="cuda"):
             ans = scripted_thing1.forward(x, y)
         self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

@@ -683,7 +683,7 @@ class TestAutocast(JitTestCase):
     def test_jit_freeze_autocast_basic(self):
         class TestModule(torch.nn.Module):
             def forward(self, x, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     return torch.mm(x, y)
 
         x = torch.rand((3, 4), dtype=torch.float).cuda()

@@ -710,7 +710,7 @@ class TestAutocast(JitTestCase):
                 self.x = torch.rand((3, 4), dtype=torch.float).cuda()
 
             def forward(self, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     return torch.mm(self.x, y)
 
         y = torch.rand((4, 5), dtype=torch.float).cuda()

@@ -729,7 +729,7 @@ class TestAutocast(JitTestCase):
     @unittest.skipIf(TEST_CUDA, "CPU-only test")
     def test_jit_autocast_softmax_cpu(self):
         def fn(x):
-            with torch.cpu.amp.autocast():
+            with torch.autocast(device_type="cpu"):
                 return torch.nn.functional.softmax(x, dim=0)
 
         fn_s = torch.jit.script(fn)

@@ -742,7 +742,7 @@ class TestAutocast(JitTestCase):
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_jit_autocast_softmax_gpu(self):
         def fn(x):
-            with torch.cuda.amp.autocast():
+            with torch.autocast(device_type="cuda"):
                 return torch.nn.functional.softmax(x, dim=0)
 
         fn_s = torch.jit.script(fn)

@@ -759,7 +759,7 @@ class TestAutocast(JitTestCase):
 
         inp = torch.rand([10, 10], dtype=torch.float)
         foo._set_ignore_amp(True)
-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             foo(inp)
             foo(inp)

@@ -797,7 +797,7 @@ class TestJitTraceAutocast(JitTestCase):
     def test_generate_autocast_jit_trace_model(self):
         def test_generate_autocast_jit_trace_model(model, x):
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x)
             traced_model = torch.jit.freeze(traced_model)
         for i in range(self.models.__len__()):

@@ -806,12 +806,12 @@ class TestJitTraceAutocast(JitTestCase):
     def test_nchw_autocast_jit_trace_model(self):
         def test_nchw_autocast_jit_trace_model(model, x):
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x)
             traced_model = torch.jit.freeze(traced_model)
             with torch.no_grad():
                 y = traced_model(x.clone())
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast(device_type="cpu"), torch.no_grad():
                 y2 = model(x.clone())
             torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
         for i in range(self.models.__len__()):

@@ -821,12 +821,12 @@ class TestJitTraceAutocast(JitTestCase):
         def test_nhwc_autocast_jit_trace_model(model, x):
             model = model.to(memory_format=torch.channels_last)
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
             traced_model = torch.jit.freeze(traced_model)
             with torch.no_grad():
                 y = traced_model(x.clone().to(memory_format=torch.channels_last))
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast(device_type="cpu"), torch.no_grad():
                 y2 = model(x.clone().to(memory_format=torch.channels_last))
             torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
         for i in range(self.models.__len__()):

@@ -845,7 +845,7 @@ class TestJitTraceAutocast(JitTestCase):
         # To avoid the fusion group from TE, we will disable the fuser here.
         for jit_freeze_or_not in [False, True]:
             test_model = TestModel().eval()
-            with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                 a = torch.rand(24, 128, 128)
                 b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                 c = test_model(a, b)

@@ -869,10 +869,10 @@ class TestJitTraceAutocast(JitTestCase):
         fn_s = torch.jit.script(fn)
 
         x = torch.rand((4, 4)) - 0.5
-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             self.assertEqual(fn_s(x), fn(x))
 
-        with torch.cpu.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cpu", enabled=True):
             self.assertEqual(fn_s(x), fn(x))
 
         self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))

@@ -888,10 +888,10 @@ class TestJitTraceAutocast(JitTestCase):
         fn_s = torch.jit.script(fn)
 
         x = torch.rand((4, 4)) - 0.5
-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             self.assertEqual(fn_s(x), fn(x))
 
-        with torch.cuda.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cuda", enabled=True):
             self.assertEqual(fn_s(x), fn(x))
 
         self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))

@@ -904,7 +904,7 @@ class TestJitTraceAutocast(JitTestCase):
                 y = True
             else:
                 y = False
-            with torch.cuda.amp.autocast(enabled=True):
+            with torch.autocast(device_type="cuda", enabled=True):
                 z = x.relu()
             return y, z

@@ -926,10 +926,10 @@ class TestJitTraceAutocast(JitTestCase):
         def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
             b1 = torch.is_autocast_cpu_enabled()
             v1 = torch.mm(x, y)
-            with torch.cpu.amp.autocast(enabled=True):
+            with torch.autocast(device_type="cpu", enabled=True):
                 b2 = torch.is_autocast_cpu_enabled()
                 v2 = torch.mm(x, y)
-                with torch.cpu.amp.autocast(enabled=False):
+                with torch.autocast(device_type="cpu", enabled=False):
                     b3 = torch.is_autocast_cpu_enabled()
                     v3 = torch.mm(x, y)
             return (v1, b1, v2, b2, v3, b3)

@@ -946,11 +946,11 @@ class TestJitTraceAutocast(JitTestCase):
 
         fn_s = torch.jit.script(fn)
 
-        with torch.cpu.amp.autocast(enabled=False):
+        with torch.autocast(device_type="cpu", enabled=False):
             check_fn_results(fn(x, y))
             check_fn_results(fn_s(x, y))
 
-        with torch.cpu.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cpu", enabled=True):
             check_fn_results(fn(x, y))
             check_fn_results(fn_s(x, y))

@@ -68,7 +68,7 @@ class JitLlgaTestCase(JitTestCase):
         with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
             if dtype == torch.bfloat16:
                 # We rely upon eager-mode AMP support for BF16
-                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
+                with torch.autocast(device_type="cpu", cache_enabled=False, dtype=torch.bfloat16):
                     traced = torch.jit.trace(m, x)
                     if isinstance(m, torch.nn.Module):
                         traced = torch.jit.freeze(traced)

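As a hypothetical sanity check (not part of the PR), the device-agnostic spelling should drive the same autocast state and cast policy that these tests exercise:

    import torch

    x = torch.randn(4, 4)
    y = torch.randn(4, 4)

    # CPU autocast via the device-agnostic API; torch.mm follows the
    # lower-precision cast policy, so the result comes back in bfloat16.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        assert torch.is_autocast_cpu_enabled()
        out = torch.mm(x, y)

    assert out.dtype is torch.bfloat16
    assert not torch.is_autocast_cpu_enabled()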