* Added a cpp loader, AOTIModelPackageLoader, which can load the .pt2, build the .so, and create a runner. The Python-facing API is that users can directly call the `run` function, whereas in cpp users can directly access the `runner_` if they are more familiar with that. I couldn't figure out how to bind the `get_runner()` function to Python...
* Added a new config, `aot_inductor.package_cpp_only`, which will **not** package the .so. This means that whenever the package is loaded, the .so needs to be built first. This is turned off by default so that new environments do not need to rebuild the .so. `package_cpp_only` is a feature which torchchat intends to use to provide flexibility to users.
* Added a new config, `aot_inductor.metadata`, which stores user-provided metadata, serialized into the .pt2 as a JSON file. It also stores the device used when exporting ("cuda" or "cpu"), so that at load time we can use that data to determine which AOTIModelContainerRunner to use. The metadata can be accessed through `loader.get_metadata()`. TODO: move this metadata to the top-level `package_aoti` function so that we can remove the metadata as a config.
* Separated out `package_aoti` as a standalone function instead of having it called automatically in Inductor. This prepares for the case where users compile multiple models and want to bundle them into one package. The specific use case is torchchat, where we want to package the separately exported encoder and decoder layers. An example of how to use this is in `test_multiple_methods`, and a short sketch follows below.
* `load_package` will load a single model from the package, given the model name.
* The loader doesn't support Windows for now; I think some more special-casing is needed to make the build commands work on Windows.

Differential Revision: [D62329906](https://our.internmc.facebook.com/intern/diff/D62329906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374
Approved by: https://github.com/desertfire, https://github.com/malfet
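For orientation, the end-to-end packaging flow looks roughly like this. This is a minimal sketch distilled from `test_multiple_methods` in the test file below; the tiny module `M` and the entry name `"model"` are illustrative only:

```python
import tempfile

import torch
from torch._inductor.package import load_package, package_aoti


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


example_inputs = (torch.randn(4),)
ep = torch.export.export(M(), example_inputs)

# Compile to AOTInductor artifacts without packaging them yet.
aoti_files = torch._inductor.aot_compile(
    ep.module(), example_inputs, options={"aot_inductor.package": True}
)

with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
    # Bundle one or more compiled models into a single .pt2, keyed by name.
    package_path = package_aoti(f.name, {"model": aoti_files})
    loaded = load_package(package_path, "model")
    out = loaded(*example_inputs)
```

Metadata supplied through the `aot_inductor.metadata` config can be read back from the loaded model with `get_metadata()`, as exercised in `test_metadata`.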
# Owner(s): ["module: inductor"]
import copy
import sys
import tempfile
import unittest

from parameterized import parameterized_class

import torch
from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
from torch._inductor.test_case import TestCase
from torch.export import Dim
from torch.testing._internal.common_utils import IS_FBCODE
from torch.testing._internal.triton_utils import HAS_CUDA


def compile(
    model,
    args,
    kwargs=None,
    *,
    dynamic_shapes=None,
    package_path=None,
    inductor_configs=None,
) -> AOTICompiledModel:
    # Export the model, AOTI-compile and package it into a .pt2, then load it back.
    ep = torch.export.export(
        model,
        args,
        kwargs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )
    package_path = torch._inductor.aoti_compile_and_package(
        ep, args, kwargs, package_path=package_path, inductor_configs=inductor_configs
    )  # type: ignore[arg-type]
    loaded = load_package(package_path)
    return loaded


@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
@unittest.skipIf(IS_FBCODE, "This is for OSS only")
@parameterized_class(
    [
        {"device": "cpu", "package_cpp_only": False},
        {"device": "cpu", "package_cpp_only": True},
    ]
    + (
        [
            {"device": "cuda", "package_cpp_only": False},
            {"device": "cuda", "package_cpp_only": True},
        ]
        if sys.platform != "darwin"
        else []
    ),
    class_name_func=lambda cls, _, params: f"{cls.__name__}{'Cpp' if params['package_cpp_only'] else ''}_{params['device']}",
)
class TestAOTInductorPackage(TestCase):
    def check_model(
        self: TestCase,
        model,
        example_inputs,
        inductor_configs=None,
        dynamic_shapes=None,
        disable_constraint_solver=False,
        atol=None,
        rtol=None,
    ) -> AOTICompiledModel:
        with torch.no_grad():
            torch.manual_seed(0)
            model = model.to(self.device)
            # Run the eager model on deep copies to get a reference output.
            ref_model = copy.deepcopy(model)
            ref_inputs = copy.deepcopy(example_inputs)
            expected = ref_model(*ref_inputs)

            inductor_configs = inductor_configs or {}
            inductor_configs["aot_inductor.package_cpp_only"] = self.package_cpp_only

            torch.manual_seed(0)
            with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
                compiled_model = compile(
                    model,
                    example_inputs,
                    dynamic_shapes=dynamic_shapes,
                    inductor_configs=inductor_configs,
                    package_path=f.name,
                )

                actual = compiled_model(*example_inputs)

            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
            return compiled_model

    def test_add(self):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_linear(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model(), example_inputs)

    def test_metadata(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        metadata = {"dummy": "moo"}
        compiled_model = self.check_model(
            Model(),
            example_inputs,
            inductor_configs={"aot_inductor.metadata": metadata},
        )

        # User-provided metadata should round-trip through the .pt2 package.
        loaded_metadata = compiled_model.get_metadata()  # type: ignore[attr-defined]

        self.assertEqual(loaded_metadata.get("dummy"), "moo")

    def test_multiple_methods(self):
        options = {
            "aot_inductor.package": True,
            "aot_inductor.package_cpp_only": self.package_cpp_only,
        }

        class Model1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.cat([a, b], dim=0)

        b = torch.randn(3, 4, device=self.device)
        dim0_a = Dim("dim0_a", min=1, max=10)
        dim0_b = Dim("dim0_b", min=1, max=20)
        dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}}
        example_inputs1 = (
            torch.randn(2, 4, device=self.device),
            torch.randn(3, 4, device=self.device),
        )
        ep1 = torch.export.export(
            Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
        )
        aoti_files1 = torch._inductor.aot_compile(
            ep1.module(), example_inputs1, options=options
        )

        class Model2(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.device = device

            def forward(self, x):
                t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
                t = torch.sqrt(t * 3)
                return x * t

        example_inputs2 = (torch.randn(5, 5, device=self.device),)
        ep2 = torch.export.export(Model2(self.device), example_inputs2)
        aoti_files2 = torch._inductor.aot_compile(
            ep2.module(), example_inputs2, options=options
        )

        # Bundle both compiled models into a single .pt2 and load each back by name.
        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            package_path = package_aoti(
                f.name, {"model1": aoti_files1, "model2": aoti_files2}
            )
            loaded1 = load_package(package_path, "model1")
            loaded2 = load_package(package_path, "model2")

            self.assertEqual(loaded1(*example_inputs1), ep1.module()(*example_inputs1))
            self.assertEqual(loaded2(*example_inputs2), ep2.module()(*example_inputs2))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # cpp_extension N/A in fbcode
    if HAS_CUDA or sys.platform == "darwin":
        run_tests(needs="filelock")