Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00).
Commit: Use `WritableTempFile` on Windows; reference: https://github.com/pytorch/pytorch/pull/159342.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159738.
Approved by: https://github.com/angelayi, https://github.com/Skylion007.
File stats: 440 lines, 17 KiB, Python.
# Owner(s): ["module: functorch"]
|
|
import json
|
|
import zipfile
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._functorch
|
|
import torch._inductor
|
|
import torch._inductor.decomposition
|
|
from torch._higher_order_ops.torchbind import CallTorchBind, enable_torchbind_tracing
|
|
from torch._inductor import aot_compile, ir
|
|
from torch._inductor.codecache import WritableTempFile
|
|
from torch._inductor.package import package_aoti
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu
|
|
from torch.testing._internal.torchbind_impls import (
|
|
_empty_tensor_queue,
|
|
init_torchbind_implementations,
|
|
)
|
|
|
|
|
|
class TestTorchbind(TestCase):
    """Tests for compiling models that hold or receive TorchBind objects
    (custom C++ classes bound via ``torch.classes``) through Inductor.

    Covers eager ``torch.compile``, ``aot_compile``/AOTInductor packaging,
    serialized extern-kernel metadata, and ``CallTorchBind`` HOP schemas.
    All tests rely on the ``_TorchScriptTesting`` classes and ops registered
    in ``setUp``.
    """

    def setUp(self):
        super().setUp()
        # Registers the _TorchScriptTesting custom classes, ops, and their
        # fake implementations that every test below depends on.
        init_torchbind_implementations()

    def get_dummy_exported_model(self):
        """
        Returns the ExportedProgram, example inputs, and result from calling the
        eager model with those inputs
        """

        class M(torch.nn.Module):
            def forward(self, x):
                return x + 1

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)

        # No TorchBind objects here, so plain (non-torchbind) export suffices.
        ep = torch.export.export(m, inputs, strict=False)

        return ep, inputs, orig_res, m

    def get_exported_model(self):
        """
        Returns the ExportedProgram, example inputs, and result from calling the
        eager model with those inputs
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # TorchBind object stored as a module attribute; exercised via
                # custom ops and a direct method call in forward().
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
                self.b = torch.randn(2, 3)

            def forward(self, x):
                x = x + self.b
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                c = self.attr.add_tensor(x)
                return x + b + c

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            ep = torch.export.export(m, inputs, strict=False)

        return ep, inputs, orig_res, m

    def test_torchbind_inductor(self):
        """Compile the exported torchbind model with Inductor and compare
        against the eager result."""
        ep, inputs, orig_res, _ = self.get_exported_model()
        compiled = torch._inductor.compile(ep.module(), inputs)

        new_res = compiled(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res))

    def test_torchbind_compile_symint(self):
        """torch.compile a model whose custom op returns a tensor with a
        data-dependent (symint) shape."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)
                return a

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)
        new_res = torch.compile(m, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res))

    def test_torchbind_compile(self):
        """torch.compile the eager module (not the exported program) directly."""
        _, inputs, orig_res, mod = self.get_exported_model()
        new_res = torch.compile(mod, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res))

    def test_torchbind_get_buf_bytes(self):
        """ir.TorchBindObject.get_buf_bytes() should report the total byte
        size of tensors held inside the script object."""
        # _Foo holds no tensors -> 0 bytes.
        a = torch.classes._TorchScriptTesting._Foo(10, 20)
        buffer = ir.TorchBindObject(name="a", value=a)
        size = buffer.get_buf_bytes()
        self.assertEqual(size, 0)

        # _ContainsTensor holds one 2x3 float32 tensor -> 2*3*4 bytes.
        t = torch.randn(2, 3)
        b = torch.classes._TorchScriptTesting._ContainsTensor(t)
        buffer = ir.TorchBindObject(name="b", value=b)
        size = buffer.get_buf_bytes()
        self.assertEqual(size, 2 * 3 * 4)

        # An empty queue holds nothing yet.
        q = _empty_tensor_queue()
        buffer = ir.TorchBindObject(name="q", value=q)
        size = buffer.get_buf_bytes()
        self.assertEqual(size, 0)

        # The size is computed from the live object, so pushing a tensor
        # after constructing the IR node is reflected in the result.
        q.push(torch.ones(2, 3))
        size = buffer.get_buf_bytes()
        self.assertEqual(size, 2 * 3 * 4)

    def test_torchbind_hop_schema(self):
        """CallTorchBind.schema() for a method with one int arg and int return."""
        foo = torch.classes._TorchScriptTesting._Foo(10, 20)
        foo_ir = ir.TorchBindObject(name="foo", value=foo)
        schema = CallTorchBind.schema(foo_ir, "add")
        self.assertEqual(
            str(schema),
            "call_torchbind(__torch__.torch.classes._TorchScriptTesting._Foo _0, str method, int _1) -> int _0",
        )

    def test_torchbind_config_not_generated(self):
        # custom_objs_config.json should not be generated when its empty
        ep, inputs, _, _ = self.get_dummy_exported_model()
        aoti_files = aot_compile(
            ep.module(), inputs, options={"aot_inductor.package": True}
        )
        for file in aoti_files:
            self.assertTrue(not file.endswith("/custom_objs_config.json"))

    def test_torchbind_hop_schema_no_input(self):
        """CallTorchBind.schema() for a method taking no args beyond self."""
        q = _empty_tensor_queue()
        q_ir = ir.TorchBindObject(name="q", value=q)
        schema = CallTorchBind.schema(q_ir, "pop")
        self.assertEqual(
            str(schema),
            "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method) -> Tensor _0",
        )

    def test_torchbind_hop_schema_no_output(self):
        """CallTorchBind.schema() for a method that returns None."""
        q = _empty_tensor_queue()
        q_ir = ir.TorchBindObject(name="q", value=q)
        schema = CallTorchBind.schema(q_ir, "push")
        self.assertEqual(
            str(schema),
            "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method, Tensor _1) -> NoneType _0",
        )

    def test_torchbind_aot_compile(self):
        """AOT-compile a torchbind model and verify the emitted artifacts:
        the custom-object config, the serialized object payload, the
        extern-kernel-nodes JSON, and their presence in the packaged .pt2."""
        ep, inputs, _, _ = self.get_exported_model()
        aoti_files = aot_compile(
            ep.module(), inputs, options={"aot_inductor.package": True}
        )

        # Locate the three artifacts of interest among the compiled outputs.
        custom_objs_config = None
        custom_obj_0 = None
        extern_json = None
        for file in aoti_files:
            if file.endswith("/custom_objs_config.json"):
                custom_objs_config = file
            elif file.endswith("/custom_obj_0"):
                custom_obj_0 = file
            elif file.endswith(".json") and "metadata" not in file:
                # The remaining json (excluding metadata) is the
                # extern-kernel-nodes file.
                extern_json = file

        self.assertIsNotNone(custom_objs_config)
        self.assertIsNotNone(custom_obj_0)
        self.assertIsNotNone(extern_json)

        # The config maps the internal constant name to the serialized blob.
        with open(custom_objs_config) as file:
            data = json.load(file)
            self.assertEqual(data, {"_torchbind_obj0": "custom_obj_0"})

        # Golden check of the serialized extern kernel nodes: the two custom
        # ops plus the call_torchbind HOP, each referencing _torchbind_obj0.
        with open(extern_json) as file:
            data = json.load(file)
            self.assertEqual(
                data,
                {
                    "nodes": [
                        {
                            "name": "buf3",
                            "node": {
                                "target": "_TorchScriptTesting::takes_foo_tuple_return",
                                "inputs": [
                                    {
                                        "name": "foo",
                                        "arg": {
                                            "as_custom_obj": {
                                                "name": "_torchbind_obj0",
                                                "class_fqn": "__torch__.torch.classes._TorchScriptTesting._Foo",
                                            }
                                        },
                                        "kind": 1,
                                    },
                                    {
                                        "name": "x",
                                        "arg": {"as_tensor": {"name": "buf2"}},
                                        "kind": 1,
                                    },
                                ],
                                "outputs": [
                                    {"as_tensor": {"name": "buf4"}},
                                    {"as_tensor": {"name": "buf5"}},
                                ],
                                "metadata": {},
                                "is_hop_single_tensor_return": None,
                            },
                        },
                        {
                            "name": "buf7",
                            "node": {
                                "target": "_TorchScriptTesting::takes_foo",
                                "inputs": [
                                    {
                                        "name": "foo",
                                        "arg": {
                                            "as_custom_obj": {
                                                "name": "_torchbind_obj0",
                                                "class_fqn": "__torch__.torch.classes._TorchScriptTesting._Foo",
                                            }
                                        },
                                        "kind": 1,
                                    },
                                    {
                                        "name": "x",
                                        "arg": {"as_tensor": {"name": "buf6"}},
                                        "kind": 1,
                                    },
                                ],
                                "outputs": [{"as_tensor": {"name": "buf8"}}],
                                "metadata": {},
                                "is_hop_single_tensor_return": None,
                            },
                        },
                        {
                            "name": "buf9",
                            "node": {
                                "target": "call_torchbind",
                                "inputs": [
                                    {
                                        "name": "_0",
                                        "arg": {
                                            "as_custom_obj": {
                                                "name": "_torchbind_obj0",
                                                "class_fqn": "__torch__.torch.classes._TorchScriptTesting._Foo",
                                            }
                                        },
                                        "kind": 1,
                                    },
                                    {
                                        "name": "method",
                                        "arg": {"as_string": "add_tensor"},
                                        "kind": 1,
                                    },
                                    {
                                        "name": "_1",
                                        "arg": {"as_tensor": {"name": "buf2"}},
                                        "kind": 1,
                                    },
                                ],
                                "outputs": [{"as_tensor": {"name": "buf10"}}],
                                "metadata": {},
                                "is_hop_single_tensor_return": None,
                            },
                        },
                    ]
                },
            )

        # Test that the files are packaged
        # (WritableTempFile is needed for Windows-compatible temp files.)
        with WritableTempFile(suffix=".pt2") as f:
            package_path = package_aoti(f.name, aoti_files)

        with zipfile.ZipFile(package_path, "r") as zip_ref:
            all_files = zip_ref.namelist()
            base_folder = all_files[0].split("/")[0]
            tmp_path_model = Path(base_folder) / "data" / "aotinductor" / "model"
            tmp_path_constants = Path(base_folder) / "data" / "constants"

            self.assertTrue(
                str(tmp_path_model / "custom_objs_config.json") in all_files
            )
            self.assertTrue(str(tmp_path_constants / "custom_obj_0") in all_files)

    def test_torchbind_aoti(self):
        """End-to-end: compile+package to .pt2, load it back, and run."""
        ep, inputs, orig_res, _ = self.get_exported_model()
        pt2_path = torch._inductor.aoti_compile_and_package(ep)
        optimized = torch._inductor.aoti_load_package(pt2_path)
        result = optimized(*inputs)
        self.assertEqual(result, orig_res)

    @torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True)
    def test_torchbind_aot_compile_constant_folding(self):
        """Same end-to-end flow with runtime constant folding enabled."""
        ep, inputs, orig_res, _ = self.get_exported_model()
        pt2_path = torch._inductor.aoti_compile_and_package(ep)
        optimized = torch._inductor.aoti_load_package(pt2_path)
        result = optimized(*inputs)
        self.assertEqual(result, orig_res)

    def test_torchbind_list_return_aot_compile(self):
        """AOT-compile a model whose custom op returns a list of tensors."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            ep = torch.export.export(m, inputs, strict=False)

        pt2_path = torch._inductor.aoti_compile_and_package(ep)
        optimized = torch._inductor.aoti_load_package(pt2_path)
        result = optimized(*inputs)
        self.assertEqual(result, orig_res)

    def test_torchbind_queue(self):
        """AOT-compile a model whose forward mutates a TensorQueue script
        object (push/pop), verifying stateful torchbind method calls."""

        class Foo(torch.nn.Module):
            def __init__(self, tq) -> None:
                super().__init__()
                self.tq = tq

            def forward(self, x):
                self.tq.push(x.cos())
                self.tq.push(x.sin())
                # TODO: int return type in fallback kernel not support yet
                x_cos = self.tq.pop()  # + self.tq.size()
                x_sin = self.tq.pop()  # - self.tq.size()
                return x_sin, x_cos

        inputs = (torch.randn(3, 2),)

        q = _empty_tensor_queue()
        m = Foo(q)
        orig_res = m(*inputs)

        # Fresh queue for the compiled module so the eager run above doesn't
        # affect the traced/compiled state.
        q2 = _empty_tensor_queue()
        m2 = Foo(q2)

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            ep = torch.export.export(m2, inputs, strict=False)

        pt2_path = torch._inductor.aoti_compile_and_package(ep)
        optimized = torch._inductor.aoti_load_package(pt2_path)
        result = optimized(*inputs)
        self.assertEqual(result, orig_res)

    @requires_gpu()
    @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
    @torch._inductor.config.patch("graph_partition", True)
    def test_torchbind_compile_gpu_op_symint_graph_partition(self):
        """torch.compile with graph partitioning: a CPU custom op with a
        data-dependent output shape followed by a GPU op."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)
                a_cuda = a.to(device=GPU_TYPE)
                return a_cuda + 1

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)
        new_res = torch.compile(m, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res))

    def test_torchbind_input_aot_compile(self):
        """TorchBind objects as graph *inputs* are rejected by AOTInductor."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(x, y)
                return a

        m = M()
        inputs = (torch.classes._TorchScriptTesting._Foo(10, 20), torch.ones(2, 3))

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            ep = torch.export.export(m, inputs, strict=False)

        from torch._dynamo.exc import UserError

        with self.assertRaisesRegex(
            UserError,
            expected_regex="TorchBind object inputs are not supported in AOTInductor",
        ):
            aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})

    def test_aoti_torchbind_name_collision(self):
        """A user attribute literally named ``_torchbind_obj0`` must not
        collide with Inductor's internally generated constant names."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Deliberately matches Inductor's auto-generated naming scheme.
                self._torchbind_obj0 = torch.classes._TorchScriptTesting._Foo(2, 3)

            def forward(self, x):
                a = self._torchbind_obj0.add_tensor(x)
                # A second, graph-local torchbind object forces name generation.
                torchbind = torch.classes._TorchScriptTesting._Foo(4, 5)
                b = torchbind.add_tensor(x)
                return a + b

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)

        with enable_torchbind_tracing():
            ep = torch.export.export(m, inputs, strict=False)

        pt2_path = torch._inductor.aoti_compile_and_package(ep)
        optimized = torch._inductor.aoti_load_package(pt2_path)
        result = optimized(*inputs)
        self.assertEqual(result, orig_res)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run through Inductor's test harness (handles device setup, flags, etc.).
    run_tests()
|