# Owner(s): ["module: cuda graphs"]

import functools
import unittest

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import TEST_CUDA_GRAPH


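# Compose several decorators into a single decorator; they are applied as if
# stacked in the listed order (the first entry ends up outermost).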
def composed(*decs):
    def deco(f):
        for dec in reversed(decs):
            f = dec(f)
        return f

    return deco


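# Decorator that clears the dynamo counters before running the test and then
# asserts whether AOTAutograd compilation succeeded ("ok") or failed
# ("not_ok"), depending on the `ok` flag.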
def assert_aot_autograd_counter(ok=True):
    def deco(f):
        @functools.wraps(f)
        def wrap(self, *args, **kwargs):
            torch._dynamo.utils.counters.clear()
            r = f(self, *args, **kwargs)
            c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"]
            c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"]
            if ok:
                self.assertGreater(c_ok, 0)
                self.assertEqual(c_not_ok, 0)
            else:
                self.assertEqual(c_ok, 0)
                self.assertGreater(c_not_ok, 0)
            return r

        return wrap

    return deco


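# Standard set of patches for these tests: enable correctness verification and
# automatic dynamic shapes, and check the AOTAutograd counters afterwards.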
def patch_all(ok=True):
    return composed(
        torch._dynamo.config.patch(
            verify_correctness=True, automatic_dynamic_shapes=True
        ),
        assert_aot_autograd_counter(ok),
    )


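# Number of iterations each compiled test function loops over, presumably so
# repeated calls exercise graph replay rather than just the first capture.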
N_ITERS = 5


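# Tests for torch.compile with the "cudagraphs" backend; most tests also check
# the AOTAutograd success/failure counters via patch_all().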
@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
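    # Simple add/multiply model where everything stays on the GPU.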
    @patch_all()
    def test_basic(self):
        def model(x, y):
            return (x + y) * y

        @torch.compile(backend="cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

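    # Model that copies an intermediate to the CPU (device-to-host) in the
    # middle of the compiled region.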
    @patch_all()
    def test_dtoh(self):
        def model(x, y):
            a = x + y
            b = a.cpu() * 3
            return b

        @torch.compile(backend="cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

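    # One input is a 0-dim CPU tensor; per the test name this targets the
    # host-to-device transfer path.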
    @patch_all()
    def test_htod(self):
        def model(x, y):
            a = x + y
            return a * 3

        @torch.compile(backend="cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn((), device="cpu")
        fn(x, y)

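    # Model that mutates one of its inputs in place; the test checks that the
    # mutation is visible to the caller on every iteration.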
    def test_mutate_input(self):
        def model(x, y):
            y.add_(3)
            return x * y

        @torch.compile(backend="cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    y_orig = y.clone()
                    loss = model(x, y).sum()
                    self.assertTrue(same(y, y_orig + 3))
                    loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

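    # Model that creates and mutates a constant CPU tensor inside the compiled
    # region; the loss must come out as 3.0 on every iteration.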
    @patch_all()
    def test_mutate_constant(self):
        def model(x, y):
            c = torch.tensor(1)
            c.add_(2)
            return x * y * 0 + c

        @torch.compile(backend="cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(x, y).sum()
                    self.assertTrue(same(loss, torch.tensor(3.0, device="cuda")))
                    loss.backward()

        x = torch.randn(1, device="cuda", requires_grad=True)
        y = torch.randn(1, device="cuda")
        fn(x, y)

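    # Model that allocates a new tensor with a factory function inside the
    # compiled region and mutates it before use.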
    @patch_all()
    def test_factory(self):
        def model(y):
            x = torch.zeros(3, device="cuda:0")
            x.add_(3)
            return x * y

        @torch.compile(backend="cudagraphs")
        def fn(y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(y).sum()
                    loss.backward()

        y = torch.randn(3, device="cuda:0", requires_grad=True)
        fn(y)

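    # Model that resizes a tensor (metadata mutation) inside the compiled
    # region; see the linked issue for a more involved example.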
    @patch_all()
    def test_mutated_metadata(self):
        # more tortured example at
        # https://github.com/pytorch/pytorch/issues/81385
        def model(x):
            x = x.clone()
            x.resize_(20)
            x.fill_(2)
            return x

        @torch.compile(backend="cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))

        x = torch.empty(0, device="cuda:0")
        fn(x)

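    # Model that writes into a zero-length view; the fill of the empty slice is
    # effectively dead and must not corrupt the parent tensor.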
    @patch_all()
    def test_dead_fill(self):
        def model(x):
            x = x.clone()
            y = x[0:0]
            x.fill_(2)
            y.fill_(3)
            return x, y

        @torch.compile(backend="cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx, ry = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
                    self.assertTrue(same(ry, torch.empty(0, device="cuda:0")))

        x = torch.empty(20, device="cuda:0")
        fn(x)


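# When run directly, exit early (or skip) if CUDA graphs are not usable on
# this machine, i.e. TEST_CUDA_GRAPH is False.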
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_CUDA_GRAPH:
        if __name__ == "__main__":
            import sys

            sys.exit(0)
        raise unittest.SkipTest("cuda graph test is skipped")

    run_tests()