Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Using device-agnostic autocast api (#136613)
- using torch.autocast(device_type="cuda") instead of torch.cuda.amp.autocast()
- using torch.autocast(device_type="cpu") instead of torch.cpu.amp.autocast()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136613
Approved by: https://github.com/shink, https://github.com/cyyever, https://github.com/kwen2501
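For anyone applying the same migration to their own code, here is a minimal sketch of the pattern this PR follows (illustrative only; the tensors and the print are not part of the diff):

import torch

x = torch.randn(4, 4)
y = torch.randn(4, 4)

# Before: device-specific context managers, now deprecated.
#   with torch.cuda.amp.autocast(): ...
#   with torch.cpu.amp.autocast(): ...

# After: a single device-agnostic entry point; the device is passed as device_type.
with torch.autocast(device_type="cpu"):
    z = torch.mm(x, y)  # autocast-eligible ops run in the device's autocast dtype

print(z.dtype)  # torch.bfloat16, the default autocast dtype on CPU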
@@ -533,11 +533,11 @@ class TestAutocast(JitTestCase):
             return torch.mm(x, y)
 
         def t_cuda_amp_autocast(x, y):
-            with torch.cuda.amp.autocast():
+            with torch.autocast(device_type="cuda"):
                 return torch.mm(x, y)
 
         def t_cpu_amp_autocast(x, y):
-            with torch.cpu.amp.autocast():
+            with torch.autocast(device_type="cpu"):
                 return torch.mm(x, y)
 
         x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
@@ -658,7 +658,7 @@ class TestAutocast(JitTestCase):
             impl: Iface
 
             def forward(self, x, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     a = torch.mm(x, y)
                     b = self.impl.forward(a, x)
                     return b
@@ -671,7 +671,7 @@ class TestAutocast(JitTestCase):
         y = torch.rand([2, 2])
 
         # make sure this doesn't throw an error
-        with torch.cuda.amp.autocast():
+        with torch.autocast(device_type="cuda"):
             ans = scripted_thing1.forward(x, y)
         self.assertEqual(torch.mm(torch.mm(x, y), x), ans)
 
@@ -683,7 +683,7 @@ class TestAutocast(JitTestCase):
     def test_jit_freeze_autocast_basic(self):
         class TestModule(torch.nn.Module):
             def forward(self, x, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     return torch.mm(x, y)
 
         x = torch.rand((3, 4), dtype=torch.float).cuda()
@@ -710,7 +710,7 @@ class TestAutocast(JitTestCase):
                 self.x = torch.rand((3, 4), dtype=torch.float).cuda()
 
             def forward(self, y):
-                with torch.cuda.amp.autocast():
+                with torch.autocast(device_type="cuda"):
                     return torch.mm(self.x, y)
 
         y = torch.rand((4, 5), dtype=torch.float).cuda()
@@ -729,7 +729,7 @@ class TestAutocast(JitTestCase):
     @unittest.skipIf(TEST_CUDA, "CPU-only test")
     def test_jit_autocast_softmax_cpu(self):
         def fn(x):
-            with torch.cpu.amp.autocast():
+            with torch.autocast(device_type="cpu"):
                 return torch.nn.functional.softmax(x, dim=0)
 
         fn_s = torch.jit.script(fn)
@@ -742,7 +742,7 @@ class TestAutocast(JitTestCase):
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_jit_autocast_softmax_gpu(self):
         def fn(x):
-            with torch.cuda.amp.autocast():
+            with torch.autocast(device_type="cuda"):
                 return torch.nn.functional.softmax(x, dim=0)
 
         fn_s = torch.jit.script(fn)
@@ -759,7 +759,7 @@ class TestAutocast(JitTestCase):
 
         inp = torch.rand([10, 10], dtype=torch.float)
         foo._set_ignore_amp(True)
-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             foo(inp)
             foo(inp)
 
@@ -797,7 +797,7 @@ class TestJitTraceAutocast(JitTestCase):
     def test_generate_autocast_jit_trace_model(self):
         def test_generate_autocast_jit_trace_model(model, x):
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x)
             traced_model = torch.jit.freeze(traced_model)
         for i in range(self.models.__len__()):
@@ -806,12 +806,12 @@ class TestJitTraceAutocast(JitTestCase):
     def test_nchw_autocast_jit_trace_model(self):
         def test_nchw_autocast_jit_trace_model(model, x):
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x)
             traced_model = torch.jit.freeze(traced_model)
             with torch.no_grad():
                 y = traced_model(x.clone())
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast(device_type="cpu"), torch.no_grad():
                 y2 = model(x.clone())
             torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
         for i in range(self.models.__len__()):
@@ -821,12 +821,12 @@ class TestJitTraceAutocast(JitTestCase):
         def test_nhwc_autocast_jit_trace_model(model, x):
             model = model.to(memory_format=torch.channels_last)
             model.eval()
-            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
                 traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
             traced_model = torch.jit.freeze(traced_model)
             with torch.no_grad():
                 y = traced_model(x.clone().to(memory_format=torch.channels_last))
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast(device_type="cpu"), torch.no_grad():
                 y2 = model(x.clone().to(memory_format=torch.channels_last))
             torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
         for i in range(self.models.__len__()):
@@ -845,7 +845,7 @@ class TestJitTraceAutocast(JitTestCase):
         # To avoid the fusion group from TE, we will disable the fuser here.
         for jit_freeze_or_not in [False, True]:
             test_model = TestModel().eval()
-            with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
+            with torch.autocast(device_type="cpu", cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                 a = torch.rand(24, 128, 128)
                 b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                 c = test_model(a, b)
@@ -869,10 +869,10 @@ class TestJitTraceAutocast(JitTestCase):
         fn_s = torch.jit.script(fn)
 
         x = torch.rand((4, 4)) - 0.5
-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             self.assertEqual(fn_s(x), fn(x))
 
-        with torch.cpu.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cpu", enabled=True):
             self.assertEqual(fn_s(x), fn(x))
 
         self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))
@@ -888,10 +888,10 @@ class TestJitTraceAutocast(JitTestCase):
         fn_s = torch.jit.script(fn)
 
         x = torch.rand((4, 4)) - 0.5
-        with torch.cpu.amp.autocast():
+        with torch.autocast(device_type="cpu"):
             self.assertEqual(fn_s(x), fn(x))
 
-        with torch.cuda.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cuda", enabled=True):
             self.assertEqual(fn_s(x), fn(x))
 
         self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))
@@ -904,7 +904,7 @@ class TestJitTraceAutocast(JitTestCase):
                 y = True
             else:
                 y = False
-            with torch.cuda.amp.autocast(enabled=True):
+            with torch.autocast(device_type="cuda", enabled=True):
                 z = x.relu()
             return y, z
 
@@ -926,10 +926,10 @@ class TestJitTraceAutocast(JitTestCase):
         def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
             b1 = torch.is_autocast_cpu_enabled()
             v1 = torch.mm(x, y)
-            with torch.cpu.amp.autocast(enabled=True):
+            with torch.autocast(device_type="cpu", enabled=True):
                 b2 = torch.is_autocast_cpu_enabled()
                 v2 = torch.mm(x, y)
-                with torch.cpu.amp.autocast(enabled=False):
+                with torch.autocast(device_type="cpu", enabled=False):
                     b3 = torch.is_autocast_cpu_enabled()
                     v3 = torch.mm(x, y)
             return (v1, b1, v2, b2, v3, b3)
@@ -946,11 +946,11 @@ class TestJitTraceAutocast(JitTestCase):
 
         fn_s = torch.jit.script(fn)
 
-        with torch.cpu.amp.autocast(enabled=False):
+        with torch.autocast(device_type="cpu", enabled=False):
             check_fn_results(fn(x, y))
             check_fn_results(fn_s(x, y))
 
-        with torch.cpu.amp.autocast(enabled=True):
+        with torch.autocast(device_type="cpu", enabled=True):
             check_fn_results(fn(x, y))
             check_fn_results(fn_s(x, y))
 
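Taken together, the hunks above replace only the entry point; keyword arguments such as enabled, dtype, and cache_enabled carry over unchanged. A minimal sketch of the correspondence (illustrative, not part of the diff):

import torch

# torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=False) becomes:
with torch.autocast(device_type="cpu", enabled=True, dtype=torch.bfloat16, cache_enabled=False):
    pass

# torch.cuda.amp.autocast(enabled=True, dtype=torch.float16) becomes:
# (on a machine without CUDA, entering this region warns and disables itself rather than erroring)
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.float16):
    pass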