# Owner(s): ["module: inductor"]
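# Tests for Inductor's memory planning pass (config.memory_planning=True),
# which packs intermediate buffers into a single pooled allocation. The three
# tests below exercise the Python wrapper, cpp_wrapper, and AOTInductor paths.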

import sys
import unittest

from torch.testing._internal.common_utils import (
    IS_CI,
    IS_WINDOWS,
    skipIfRocm,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_memory_planning yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")  # noqa: F821

import torch
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim


@requires_gpu()
@config.patch(memory_planning=True)
class TestMemoryPlanning(TestCase):
    device = GPU_TYPE

    def _generate(self, *, device):
        """
        Generate a simple test case that has multiple simultaneously-live intermediate tensors.
        """

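        # The early matmul results feed later matmuls, so t0 and t1 are still
        # live when the next intermediates are produced; this overlap is what
        # gives the memory planner several buffers to pack into one pool.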
        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                t0 = x.matmul(y)
                t1 = x.matmul(z)
                t0 = x.transpose(0, 1).matmul(t1)
                t1 = x.matmul(t0)
                return t0.sum() + t1.sum()

        x = torch.randn((3, 2), device=device)
        y = torch.randn((2, 4), device=device)
        z = torch.randn((2, 3), device=device)
        return (Foo(), (x, y, z))

    def test_python_wrapper(self):
        f, args = self._generate(device=GPU_TYPE)
        compiled = torch.compile(f, dynamic=True)
        result, code = run_and_get_cpp_code(compiled, *args)

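        # With memory_planning=True, the Python wrapper should emit a single
        # backing allocation ("pool1") and carve the intermediates out of it
        # via alloc_from_pool, instead of one empty_strided call per buffer.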
        FileCheck().check(
            "pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
        ).check_next(
            "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
        ).check(
            "buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))

    def test_cpp_wrapper(self):
        f, args = self._generate(device=GPU_TYPE)
        compiled = torch.compile(f, dynamic=True)
        with config.patch({"cpp_wrapper": True}):
            result, code = run_and_get_cpp_code(compiled, *args)

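        # In cpp_wrapper mode the same pooling appears through the AOTI C shim:
        # sub-buffers come from aoti_torch__alloc_from_pool and are wrapped in
        # RAIIAtenTensorHandle so their lifetime is scoped to the wrapper.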
        FileCheck().check(
            "aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_4, int_array_5, &tmp_tensor_handle_1)"
        ).check_next("auto buf0 = RAIIAtenTensorHandle(tmp_tensor_handle_1);").check(
            "auto buf1 = RAIIAtenTensorHandle(tmp_tensor_handle_2);"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))

    @skipIfRocm(msg="test_aot_inductor doesn't work on ROCm")
    @skipIfXpu(msg="aoti doesn't work on XPU")
    def test_aoti(self):
        try:
            from .test_aot_inductor import AOTIRunnerUtil
        except ImportError:
            from test_aot_inductor import (  # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library
                AOTIRunnerUtil,
            )

        f, args = self._generate(device=GPU_TYPE)
        dim0_x = Dim("dim0_x", min=1, max=2048)
        dynamic_shapes = ({0: dim0_x}, None, None)
        result, code = run_and_get_cpp_code(
            lambda: AOTIRunnerUtil.run(GPU_TYPE, f, args, dynamic_shapes=dynamic_shapes)
        )

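        # The AOTI-compiled C++ should size pool1 symbolically in terms of the
        # dynamic dim s0, allocate it once with aoti_torch_empty_strided, and
        # then hand out sub-buffers via aoti_torch__alloc_from_pool.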
        FileCheck().check(
            "int64_t int_array_2[] = {24L + align(12L*s0), };"
        ).check_next("int64_t int_array_3[] = {1L, };").check_next(
            "AtenTensorHandle pool1_handle;"
        ).check_next(
            "aoti_torch_empty_strided(1, int_array_2, int_array_3,"
        ).check_next(
            "RAIIAtenTensorHandle pool1(pool1_handle);"
        ).check_next(
            "int64_t int_array_4[] = {s0, 3L};"
        ).check_next(
            "int64_t int_array_5[] = {3L, 1L};"
        ).check_next(
            "AtenTensorHandle tmp_tensor_handle_1;"
        ).check_next(
            "aoti_torch__alloc_from_pool(pool1, 0"
        ).run(
            code
        )
        self.assertTrue(same(f(*args), result))


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()