From cd9ee49a6905dfa80939efb4d3a1556d92ecaa6b Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 10 Sep 2024 16:08:07 -0700 Subject: [PATCH] [aoti] Add cpp loader (#135374) * Added a cpp loader, AOTIModelPackageLoader, which can load the .pt2, build the .so, and create a runner. The python-facing API is that users can directly call the `run` function, whereas in cpp users can directly access the `runner_` if they are more familiar with that. I couldn't figure out how to bind the `get_runner()` function to python... * Added a new config, `aot_inductor.package_cpp_only` which will **not** package the so. This means that whenever the package is loaded, we will need to build the so. This is turned off by default so that new environments do not need to rebuild their so. The `package_cpp_only` is a feature which torchchat intends to use to provide flexibility to users. * Added a new config, `aot_inductor.metadata` which stores user-provided metadata, serialized to the pt2 as a json file. It also stores the device used when exporting, "cuda" or "cpu", so that during load time, we can use that data to determine which AOTIModelContainerRunner to use. The metadata can be accessed through `loader.get_metadata()`. TODO is to move this metadata to the toplevel `package_aoti` function so that we can remove the metadata as a config. * Separated out `package_aoti` as a standalone function, instead of it automatically being called in inductor. This is to prepare for the case where users will compile multiple models, and want to bundle it in one package. The specific use case is in torchchat, where we want to package the separately-exported encoder and decoder layers. An example of how to use this is in `test_multiple_methods`. * `load_package` will load a singular model, given the model name. * The loader doesn't support windows for now, I think I need to add some more casing to make the build commands work on windows? Differential Revision: [D62329906](https://our.internmc.facebook.com/intern/diff/D62329906) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374 Approved by: https://github.com/desertfire, https://github.com/malfet --- .ci/docker/requirements-ci.txt | 5 + .ci/pytorch/win-test.sh | 3 + .../requirements/pip-requirements-macOS.txt | 1 + build_variables.bzl | 2 + setup.py | 1 + test/inductor/test_aot_inductor_package.py | 209 ++++++--- torch/_C/_aoti.pyi | 3 + torch/_inductor/__init__.py | 42 ++ torch/_inductor/codecache.py | 166 +++---- torch/_inductor/compile_fx.py | 1 + torch/_inductor/config.py | 7 + torch/_inductor/package/__init__.py | 2 +- torch/_inductor/package/package.py | 140 +++--- torch/csrc/Module.cpp | 2 + .../aoti_package/model_package_loader.cpp | 411 ++++++++++++++++++ .../aoti_package/model_package_loader.h | 40 ++ torch/csrc/inductor/aoti_package/pybind.cpp | 24 + torch/csrc/inductor/aoti_package/pybind.h | 7 + .../aoti_runner/model_container_runner.h | 2 +- .../model_container_runner_cpu.cpp | 16 + .../model_container_runner_cuda.cpp | 13 + .../csrc/inductor/aoti_torch/shim_common.cpp | 39 +- 22 files changed, 890 insertions(+), 246 deletions(-) create mode 100644 torch/csrc/inductor/aoti_package/model_package_loader.cpp create mode 100644 torch/csrc/inductor/aoti_package/model_package_loader.h create mode 100644 torch/csrc/inductor/aoti_package/pybind.cpp create mode 100644 torch/csrc/inductor/aoti_package/pybind.h diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 33166eb15187..88ecc0d1a7ed 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -337,3 +337,8 @@ onnxscript==0.1.0.dev20240817 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: + +parameterized==0.8.1 +#Description: Parameterizes unittests, both the tests themselves and the entire testing class +#Pinned versions: +#test that import: diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index 69f0c88aa307..09b624183c7a 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -43,6 +43,9 @@ python -m pip install z3-solver==4.12.2.0 # Install tlparse for test\dynamo\test_structured_trace.py UTs. python -m pip install tlparse==0.3.25 +# Install parameterized +python -m pip install parameterized==0.8.1 + run_tests() { # Run nvidia-smi if available for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index caa831b88231..c72d1b568ca1 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -31,3 +31,4 @@ optree==0.12.1 # NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in # which the stringify metadata is wrong when escaping double quote protobuf==3.20.2 +parameterized==0.8.1 diff --git a/build_variables.bzl b/build_variables.bzl index 262202329556..8417c1f53a72 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -466,6 +466,7 @@ lazy_tensor_core_python_sources = [ ] inductor_core_resources = [ + "torch/csrc/inductor/aoti_package/model_package_loader.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp", "torch/csrc/inductor/aoti_torch/shim_common.cpp", @@ -841,6 +842,7 @@ libtorch_python_core_sources = [ "torch/csrc/fx/node.cpp", "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", + "torch/csrc/inductor/aoti_package/pybind.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/inductor/aoti_eager/kernel_holder.cpp", "torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp", diff --git a/setup.py b/setup.py index 2b0cfa99d71d..0531ae1c524c 100644 --- a/setup.py +++ b/setup.py @@ -1324,6 +1324,7 @@ def main(): "include/torch/csrc/distributed/autograd/rpc_messages/*.h", "include/torch/csrc/dynamo/*.h", "include/torch/csrc/inductor/*.h", + "include/torch/csrc/inductor/aoti_package/*.h", "include/torch/csrc/inductor/aoti_runner/*.h", "include/torch/csrc/inductor/aoti_runtime/*.h", "include/torch/csrc/inductor/aoti_torch/*.h", diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 0e20045cdbfc..490b0e032473 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -1,78 +1,95 @@ # Owner(s): ["module: inductor"] import copy import sys +import tempfile import unittest +from parameterized import parameterized_class + import torch -from torch._inductor import config -from torch._inductor.package import load_package +from torch._inductor.package import AOTICompiledModel, load_package, package_aoti from torch._inductor.test_case import TestCase -from torch.testing._internal import common_utils +from torch.export import Dim from torch.testing._internal.common_utils import IS_FBCODE from torch.testing._internal.triton_utils import HAS_CUDA -try: - try: - from .test_torchinductor import copy_tests - except ImportError: - from test_torchinductor import copy_tests -except (unittest.SkipTest, ImportError) as e: - if __name__ == "__main__": - sys.exit(0) - raise - - -def compile(model, example_inputs, dynamic_shapes, options, device): +def compile( + model, + args, + kwargs=None, + *, + dynamic_shapes=None, + package_path=None, + inductor_configs=None, +) -> AOTICompiledModel: ep = torch.export.export( model, - example_inputs, + args, + kwargs, dynamic_shapes=dynamic_shapes, strict=False, ) - gm = ep.module() - package_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type] - compiled_model = load_package(package_path, device) - return compiled_model + package_path = torch._inductor.aoti_compile_and_package( + ep, args, kwargs, package_path=package_path, inductor_configs=inductor_configs + ) # type: ignore[arg-type] + loaded = load_package(package_path) + return loaded -def check_model( - self: TestCase, - model, - example_inputs, - options=None, - dynamic_shapes=None, - disable_constraint_solver=False, - atol=None, - rtol=None, -): - with torch.no_grad(), config.patch( - { - "aot_inductor.package": True, - # TODO: "aot_inductor.force_mmap_weights": True, - } - ): - torch.manual_seed(0) - model = model.to(self.device) - ref_model = copy.deepcopy(model) - ref_inputs = copy.deepcopy(example_inputs) - expected = ref_model(*ref_inputs) +@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") +@unittest.skipIf(IS_FBCODE, "This is for OSS only") +@parameterized_class( + [ + {"device": "cpu", "package_cpp_only": False}, + {"device": "cpu", "package_cpp_only": True}, + ] + + ( + [ + {"device": "cuda", "package_cpp_only": False}, + {"device": "cuda", "package_cpp_only": True}, + ] + if sys.platform != "darwin" + else [] + ), + class_name_func=lambda cls, _, params: f"{cls.__name__}{'Cpp' if params['package_cpp_only'] else ''}_{params['device']}", +) +class TestAOTInductorPackage(TestCase): + def check_model( + self: TestCase, + model, + example_inputs, + inductor_configs=None, + dynamic_shapes=None, + disable_constraint_solver=False, + atol=None, + rtol=None, + ) -> AOTICompiledModel: + with torch.no_grad(): + torch.manual_seed(0) + model = model.to(self.device) + ref_model = copy.deepcopy(model) + ref_inputs = copy.deepcopy(example_inputs) + expected = ref_model(*ref_inputs) - torch.manual_seed(0) - compiled_model = compile( - model, - example_inputs, - dynamic_shapes, - options, - self.device, - ) + inductor_configs = inductor_configs or {} + inductor_configs["aot_inductor.package_cpp_only"] = self.package_cpp_only - actual = compiled_model(*example_inputs) + torch.manual_seed(0) + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + compiled_model = compile( + model, + example_inputs, + dynamic_shapes=dynamic_shapes, + inductor_configs=inductor_configs, + package_path=f.name, + ) - self.assertEqual(actual, expected, atol=atol, rtol=rtol) + actual = compiled_model(*example_inputs) + self.assertEqual(actual, expected, atol=atol, rtol=rtol) + return compiled_model -class AOTInductorTestsTemplate: def test_add(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -99,34 +116,84 @@ class AOTInductorTestsTemplate: ) self.check_model(Model(), example_inputs) + def test_metadata(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) -common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) + def forward(self, x, y): + return x + self.linear(y) + example_inputs = ( + torch.randn(10, 10, device=self.device), + torch.randn(10, 10, device=self.device), + ) + metadata = {"dummy": "moo"} + compiled_model = self.check_model( + Model(), + example_inputs, + inductor_configs={"aot_inductor.metadata": metadata}, + ) -@unittest.skipIf(sys.platform == "darwin" or IS_FBCODE, "No CUDA on MacOS") -class AOTInductorTestPackagedABICompatibleCuda(TestCase): - device = "cuda" - check_model = check_model + loaded_metadata = compiled_model.get_metadata() # type: ignore[attr-defined] + self.assertEqual(loaded_metadata.get("dummy"), "moo") -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestPackagedABICompatibleCuda, - "packaged_abi_compatible_cuda", -) + def test_multiple_methods(self): + options = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": self.package_cpp_only, + } + class Model1(torch.nn.Module): + def __init__(self) -> None: + super().__init__() -@unittest.skipIf(IS_FBCODE, "This is for OSS only") -class AOTInductorTestPackagedABICompatibleCpu(TestCase): - device = "cpu" - check_model = check_model + def forward(self, a, b): + return torch.cat([a, b], dim=0) + b = torch.randn(3, 4, device=self.device) + dim0_a = Dim("dim0_a", min=1, max=10) + dim0_b = Dim("dim0_b", min=1, max=20) + dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}} + example_inputs1 = ( + torch.randn(2, 4, device=self.device), + torch.randn(3, 4, device=self.device), + ) + ep1 = torch.export.export( + Model1(), example_inputs1, dynamic_shapes=dynamic_shapes + ) + aoti_files1 = torch._inductor.aot_compile( + ep1.module(), example_inputs1, options=options + ) + + class Model2(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + + def forward(self, x): + t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float) + t = torch.sqrt(t * 3) + return x * t + + example_inputs2 = (torch.randn(5, 5, device=self.device),) + ep2 = torch.export.export(Model2(self.device), example_inputs2) + aoti_files2 = torch._inductor.aot_compile( + ep2.module(), example_inputs2, options=options + ) + + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + package_path = package_aoti( + f.name, {"model1": aoti_files1, "model2": aoti_files2} + ) + loaded1 = load_package(package_path, "model1") + loaded2 = load_package(package_path, "model2") + + self.assertEqual(loaded1(*example_inputs1), ep1.module()(*example_inputs1)) + self.assertEqual(loaded2(*example_inputs2), ep2.module()(*example_inputs2)) -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestPackagedABICompatibleCpu, - "packaged_abi_compatible_cpu", -) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_C/_aoti.pyi b/torch/_C/_aoti.pyi index a5e782fe6212..4e9f5e7c8671 100644 --- a/torch/_C/_aoti.pyi +++ b/torch/_C/_aoti.pyi @@ -18,3 +18,6 @@ def alloc_tensor_by_stealing_from_void_ptr( class AOTIModelContainerRunnerCpu: ... class AOTIModelContainerRunnerCuda: ... + +# Defined in torch/csrc/inductor/aoti_package/pybind.cpp +class AOTIModelPackageLoader: ... diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index f95e7caaf71e..404869debf17 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -30,6 +30,48 @@ def compile( return compile_fx(gm, example_inputs, config_patches=options) +def aoti_compile_and_package( + exported_program, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + package_path: Optional[str] = None, + inductor_configs: Optional[Dict[str, Any]] = None, +) -> str: + """ + Compiles the exported program with AOTInductor, and packages it into a .pt2 + file specified by the input package_path. + """ + from torch._inductor.package import package_aoti + from torch.export import ExportedProgram + + if not isinstance(exported_program, ExportedProgram): + raise ValueError("Only ExportedProgram is supported") + + assert package_path is None or package_path.endswith(".pt2") + + inductor_configs = inductor_configs or {} + + if inductor_configs.get("aot_inductor.output_path"): + raise RuntimeError( + "Please pass in a package path to aot_inductor_compile() instead " + "of setting the aot_inductor.output_path config." + ) + inductor_configs["aot_inductor.package"] = True + + m = exported_program.module() + assert isinstance(m, torch.fx.GraphModule) + + aoti_files = aot_compile(m, args, kwargs, options=inductor_configs) # type: ignore[arg-type] + + if package_path is None: + package_path = aoti_files + ".pt2" + + res = package_aoti(package_path, aoti_files) + assert res == package_path + return package_path + + def aot_compile( gm: torch.fx.GraphModule, args: Tuple[Any], diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6d77ea959e5f..b7f3bbf1e9dc 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1719,10 +1719,23 @@ class AotCodeCompiler: # Currently, this only support serializing extern nodes in fbcode # Eventually, we should also have a serializer for OSS. if serialized_extern_kernel_nodes: - output_json = os.path.splitext(input_path)[0] + ".json" - with open(output_json, "w") as f: + extern_kernel_nodes_json = os.path.splitext(input_path)[0] + ".json" + with open(extern_kernel_nodes_json, "w") as f: f.write(serialized_extern_kernel_nodes) + metadata = config.aot_inductor.metadata + metadata["AOTI_DEVICE_KEY"] = device_type + + # Save user provided metadata + meta_json = os.path.splitext(input_path)[0] + "_metadata.json" + for k, v in config.aot_inductor.metadata.items(): + assert isinstance(k, str) and isinstance( + v, (str) + ), "Metadata must only contain strings" + + with open(meta_json, "w") as f: + f.write(json.dumps(config.aot_inductor.metadata)) + output_so = ( config.aot_inductor.output_path if specified_so_name @@ -1755,54 +1768,29 @@ class AotCodeCompiler: if config.aot_inductor.force_mmap_weights: use_mmap_weights = True - if config.aot_inductor.package: - ( - object_output_name, - object_output_dir, - ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - compile_only=True, - use_absolute_path=use_absolute_path, - use_mmap_weights=use_mmap_weights, - ) - object_builder = CppBuilder( - name=object_output_name, - sources=input_path, - output_dir=object_output_dir, - BuildOption=object_build_options, - ) - compile_cmd = object_builder.get_command_line() - output_o = object_builder.get_target_file_path() + ( + object_output_name, + object_output_dir, + ) = get_name_and_dir_from_output_file_path(input_path) + object_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + object_builder = CppBuilder( + name=object_output_name, + sources=input_path, + output_dir=object_output_dir, + BuildOption=object_build_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() - compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" - object_build_options.save_flags_to_file(compile_flags) - - else: - ( - object_output_name, - object_output_dir, - ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - compile_only=True, - use_absolute_path=use_absolute_path, - use_mmap_weights=use_mmap_weights, - ) - object_builder = CppBuilder( - name=object_output_name, - sources=input_path, - output_dir=object_output_dir, - BuildOption=object_build_options, - ) - compile_cmd = object_builder.get_command_line() - output_o = object_builder.get_target_file_path() - - log.debug("aot compilation command: %s", compile_cmd) + log.debug("aot compilation command: %s", compile_cmd) + if not config.aot_inductor.package_cpp_only: if fbcode_aot_cpu_re: output_o = os.path.splitext(input_path)[0] + ".o" compile_file(input_path, output_o, compile_cmd.split()) @@ -1810,6 +1798,10 @@ class AotCodeCompiler: else: run_command_and_check(compile_cmd) + if config.aot_inductor.package: + compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" + object_build_options.save_flags_to_file(compile_flags) + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: def _pad_to_alignment(raw_bytes: bytes) -> bytes: padded_bytes = raw_bytes.ljust( @@ -1859,29 +1851,37 @@ class AotCodeCompiler: "darwin": _compile_consts_darwin, }[sys.platform](aot_constants) - if config.aot_inductor.package: - output_name, output_dir = get_name_and_dir_from_output_file_path( - output_so - ) - so_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - use_absolute_path=use_absolute_path, - ) - so_builder = CppBuilder( - name=output_name, - sources=[output_o, consts_o], - output_dir=output_dir, - BuildOption=so_build_options, - ) - link_cmd = so_builder.get_command_line() - output_so = so_builder.get_target_file_path() + output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) + so_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + log.debug("aot linkage command: %s", link_cmd) + + # Append cmds to the end of codegen-ed wrapper file + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + if config.aot_inductor.package: linker_flags = os.path.splitext(input_path)[0] + "_linker_flags.json" so_build_options.save_flags_to_file(linker_flags) - from torch._inductor.package import package_aoti + if config.aot_inductor.package_cpp_only: + # If we only want to package the cpp, then we need to save the + # weights separately into a bin, and we also need to prevent compiling the so if use_mmap_weights: weight_file = ( @@ -1891,28 +1891,7 @@ class AotCodeCompiler: f_weights.write(serialized_weights) f_weights.write(struct.pack("q", magic_number)) - archive_path = package_aoti(os.path.split(input_path)[0]) - return archive_path else: - output_name, output_dir = get_name_and_dir_from_output_file_path( - output_so - ) - so_build_options = CppTorchDeviceOptions( - vec_isa=picked_vec_isa, - device_type=device_type, - aot_mode=graph.aot_mode, - use_absolute_path=use_absolute_path, - ) - so_builder = CppBuilder( - name=output_name, - sources=[output_o, consts_o], - output_dir=output_dir, - BuildOption=so_build_options, - ) - link_cmd = so_builder.get_command_line() - output_so = so_builder.get_target_file_path() - - log.debug("aot linkage command: %s", link_cmd) if fbcode_aot_cpu_re: output_so = ( config.aot_inductor.output_path @@ -1937,11 +1916,10 @@ class AotCodeCompiler: f_so.write(serialized_weights) f_so.write(struct.pack("q", magic_number)) - # Append cmds to the end of codegen-ed wrapper file - with open(input_path, "a") as f: - f.write("\n") - f.write(f"// Compile cmd\n// {compile_cmd}\n") - f.write(f"// Link cmd\n// {link_cmd}\n") + if config.aot_inductor.package: + # We want to return the directory that contains all the AOTI + # generated files, not just the so + return os.path.split(output_so)[0] return output_so diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 08ff9c281fb9..4c583a4b2a90 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1132,6 +1132,7 @@ def compile_fx_aot( if config_patches is None else {**config_patches, "cpp_wrapper": True} ) + if ( "aot_inductor.output_path" not in config_patches and not config.aot_inductor.output_path diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 0ced92c3271b..32144d4ef975 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1006,9 +1006,11 @@ class aot_inductor: ) # Serialized tree spec for flattening inputs + # TODO: Move this into metadata serialized_in_spec = "" # Serialized tree spec for flattening outputs + # TODO: Move this into metadata serialized_out_spec = "" # flag to decide whether to create a submodule for constant graph. @@ -1019,6 +1021,11 @@ class aot_inductor: force_mmap_weights: bool = False package: bool = False + package_cpp_only: bool = False + + # Dictionary of metadata users might want to save to pass to the runtime. + # TODO: Move this somewhere else, since it's no longer really a config + metadata: Dict[str, str] = {} class cuda: diff --git a/torch/_inductor/package/__init__.py b/torch/_inductor/package/__init__.py index c088562100cd..15587401b723 100644 --- a/torch/_inductor/package/__init__.py +++ b/torch/_inductor/package/__init__.py @@ -1 +1 @@ -from .package import load_package, package_aoti +from .package import AOTICompiledModel, load_package, package_aoti diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index d1304293e3cd..ca62b7172e66 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -1,24 +1,25 @@ -import glob import json +import logging import os import shlex import subprocess -import tempfile import zipfile from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Dict, List, Optional, Union import torch import torch._inductor import torch.utils._pytree as pytree -from torch._inductor import config, exc +from torch._inductor import exc from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder from torch.export._tree_utils import reorder_kwargs -from .build_package import build_package_contents from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION +log = logging.getLogger(__name__) + + class PT2ArchiveWriter: def __init__(self, archive_path: str) -> None: self.archive_path: str = archive_path @@ -154,84 +155,71 @@ def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str: return output_so -def package_aoti(aoti_output_dir: str) -> str: +def package_aoti(archive_file: str, aoti_files: Union[str, Dict[str, str]]) -> str: """ Saves the AOTInductor generated files to the PT2Archive format. + + Args: + archive_file: The file name to save the package to. + aoti_files: This can either be a singular path to a directory containing + the AOTInductor files, or a dictionary mapping the model name to the + path to its AOTInductor generated files. + """ + if isinstance(aoti_files, str): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + assert archive_file.endswith(".pt2") + + # Save using the PT2 packaging format + # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) + + with PT2ArchiveWriter(archive_file) as archive_writer: + for model_name, aoti_output_dir in aoti_files.items(): + log.debug( + "Packaging AOTInductor files from %s with model name, %s", + aoti_output_dir, + model_name, + ) + for root, dirs, files in os.walk(aoti_output_dir): + for file in files: + log.debug( + "Saving AOTI generated file %s to archive in %s%s/%s", + os.path.join(root, file), + AOTINDUCTOR_DIR, + model_name, + file, + ) + archive_writer.write_file( + f"{AOTINDUCTOR_DIR}{model_name}/{file}", + os.path.join(root, file), + ) + return archive_file + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 """ - # Add a makefile and python script - build_package_filename = "build_package.py" - with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f: - f.write(build_package_contents) + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader - with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f: - f.write(f"all:\n\tpython3 {build_package_filename}\n") - - if config.aot_inductor.output_path.endswith(".so"): - raise RuntimeError( - "Unable to save package as a .so. It should be a .pt2 format or a directory." - ) - elif config.aot_inductor.output_path.endswith(".pt2"): - # Save using the PT2 packaging format - # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) - archive_path = config.aot_inductor.output_path - - with PT2ArchiveWriter(archive_path) as archive_writer: - package_files = glob.glob(f"{aoti_output_dir}/*") - - for path in package_files: - filename = os.path.basename(path) - archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path) - - return archive_path - - else: - # Directly put the files into the directory, without any archiving - return aoti_output_dir - - -def load_package(path: str, device: str) -> Callable: # type: ignore[type-arg] - if path.endswith(".so"): - raise RuntimeError( - "Unable to load .so. It should be a .pt2 format or a directory." - ) - - elif path.endswith(".pt2"): - so_path = os.path.splitext(path)[0] - with PT2ArchiveReader(path) as archive_reader: - file_names = archive_reader.get_file_names() - - with tempfile.TemporaryDirectory() as tmp_dir: - archive_reader.extractall(tmp_dir) - file_names = archive_reader.get_file_names() - aoti_files = [ - file for file in file_names if file.startswith(AOTINDUCTOR_DIR) - ] - - so_path = compile_so(tmp_dir, aoti_files, so_path) - - else: - assert os.path.isdir(path), "Must specify a directory or a .pt2 file" - aoti_files = [ - os.path.join(root, file) - for root, dirs, files in os.walk(path) - for file in files - ] - so_path = compile_so(path, aoti_files, path) - - if device == "cpu": - runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] - elif device == "cuda" or device.startswith("cuda:"): - runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] - else: - raise RuntimeError("Unsupported device " + device) - - def optimized(*args, **kwargs): # type: ignore[no-untyped-def] - call_spec = runner.get_call_spec() # type: ignore[attr-defined] + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() # type: ignore[attr-defined] in_spec = pytree.treespec_loads(call_spec[0]) out_spec = pytree.treespec_loads(call_spec[1]) flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] - flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + flat_outputs = self.loader.run(flat_inputs) # type: ignore[attr-defined] return pytree.tree_unflatten(flat_outputs, out_spec) - return optimized + def get_metadata(self) -> Dict[str, str]: + return self.loader.get_metadata() # type: ignore[attr-defined] + + +def load_package(path: str, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg] + if not path.endswith(".pt2"): + raise RuntimeError("Unable to load package. Path must be a .pt2 file.") + + loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg] + return AOTICompiledModel(loader) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3dff5fb65dde..19433d62985f 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -69,6 +69,7 @@ #include #include #include +#include #include #include #include @@ -1687,6 +1688,7 @@ PyObject* initModule() { torch::python::init_bindings(module); torch::lazy::initLazyBindings(module); torch::inductor::initAOTIRunnerBindings(module); + torch::inductor::initAOTIPackageBindings(module); #ifdef USE_ITT torch::profiler::initIttBindings(module); #endif diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp new file mode 100644 index 000000000000..1c09d9186d0a --- /dev/null +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -0,0 +1,411 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) + +#include +#include +#include + +#include +#include +#include +#include +#include + +// TODO: Investigate why this is necessary, but fixes build problems in FRL +#if __has_include("filesystem") +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + +#ifndef _WIN32 +#include +#endif + +// TODO: C++17 has the filesystem header, which may replace these +#ifdef _WIN32 +// On Windows, the POSIX implementations are considered deprecated. We simply +// map to the newer variant. +#include +#include +#include +#define access _access +#define F_OK 0 +#else +#include +#include +#endif + +namespace { +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} + +std::string create_temp_dir() { +#ifdef _WIN32 + throw std::runtime_error("Not implemented"); +#else + std::string temp_dir = "/tmp/XXXXXX"; + if (mkdtemp(temp_dir.data()) == nullptr) { + throw std::runtime_error( + std::string("Failed to create temporary directory: ") + + strerror(errno)); + } + return temp_dir; +#endif +} +} // namespace + +namespace torch::inductor { + +const nlohmann::json& AOTIModelPackageLoader::load_json_file( + std::string json_path) { + if (!file_exists(json_path)) { + throw std::runtime_error("File found: " + json_path); + } + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open()); + static nlohmann::json json_obj; + json_file >> json_obj; + + return json_obj; +} + +void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { + // Parse metadata json file (if it exists) into the metadata_ map + size_t lastindex = cpp_filename.find_last_of('.'); + std::string metadata_json_path = + cpp_filename.substr(0, lastindex) + "_metadata.json"; + + const nlohmann::json metadata_json_obj = load_json_file(metadata_json_path); + + for (auto& item : metadata_json_obj.items()) { + metadata_[item.key()] = item.value().get(); + } +} + +std::tuple AOTIModelPackageLoader:: + get_cpp_compile_command( + const std::string& filename, + const std::vector& sources, + const nlohmann::json& compile_options, + const std::string& output_dir = "") { + // Construct the cpp command + + std::string compiler = compile_options["compiler"].get(); + bool compile_only = compile_options["compile_only"].get(); + + std::string source_args = ""; + for (const std::string& source : sources) { + source_args += source + " "; + } + + std::string file_ext = compile_only ? ".o" : ".so"; + std::string target_file = output_dir + filename + file_ext; + + std::string cflags_args = ""; + for (auto& arg : compile_options["cflags"]) { + cflags_args += "-" + arg.get() + " "; + } + + std::string definitions_args = ""; + for (auto& arg : compile_options["definitions"]) { + definitions_args += "-D " + arg.get() + " "; + } + + std::string include_dirs_args = ""; + for (auto& arg : compile_options["include_dirs"]) { + include_dirs_args += "-I" + arg.get() + " "; + } + + std::string ldflags_args = ""; + for (auto& arg : compile_options["ldflags"]) { + ldflags_args += "-" + arg.get() + " "; + } + + std::string libraries_dirs_args = ""; + for (auto& arg : compile_options["libraries_dirs"]) { + libraries_dirs_args += "-L" + arg.get() + " "; + } + + std::string libraries_args = ""; + for (auto& arg : compile_options["libraries"]) { + libraries_args += "-l" + arg.get() + " "; + } + + std::string passthrough_parameters_args = ""; + for (auto& arg : compile_options["passthrough_args"]) { + passthrough_parameters_args += arg.get() + " "; + } + + std::string compile_only_arg = compile_only ? "-c" : ""; + + std::string cmd = fmt::format( + "{} {} {} {} {} {} {} {} {} {} -o {}", + compiler, + source_args, + definitions_args, + cflags_args, + include_dirs_args, + passthrough_parameters_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + compile_only_arg, + target_file); + + return std::make_tuple(cmd, target_file); +} + +bool AOTIModelPackageLoader::recursive_mkdir(const std::string& dir) { + // Creates directories recursively, copied from jit_utils.cpp + // Check if current dir exists + const char* p_dir = dir.c_str(); + const bool dir_exists = (access(p_dir, F_OK) == 0); + if (dir_exists) { + return true; + } + + // Try to create current directory +#ifdef _WIN32 + int ret = _mkdir(dir.c_str()); +#else + int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + // Success + if (ret == 0) { + return true; + } + + // Find folder separator and check if we are at the top + auto pos = dir.find_last_of("/\\"); + if (pos == std::string::npos) { + return false; + } + + // Try to create parent directory + if (!(recursive_mkdir(dir.substr(0, pos)))) { + return false; + } + + // Try to create complete path again +#ifdef _WIN32 + ret = _mkdir(dir.c_str()); +#else + ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + return ret == 0; +} + +std::string AOTIModelPackageLoader::compile_so( + const std::string& cpp_filename, + const std::string& consts_filename) { + // Compile the cpp file into a .so + + size_t lastindex = cpp_filename.find_last_of('.'); + std::string filename = cpp_filename.substr(0, lastindex); + + std::string compile_flags_path = filename + "_compile_flags.json"; + const nlohmann::json compile_flags = load_json_file(compile_flags_path); + + auto compile_result = + get_cpp_compile_command(filename, {cpp_filename}, compile_flags); + std::string compile_cmd = std::get<0>(compile_result); + std::string output_o = std::get<1>(compile_result); + + std::string linker_flags_path = + cpp_filename.substr(0, lastindex) + "_linker_flags.json"; + const nlohmann::json linker_flags = load_json_file(linker_flags_path); + + auto link_result = get_cpp_compile_command( + filename, {output_o, consts_filename}, linker_flags); + std::string link_cmd = std::get<0>(link_result); + std::string output_so = std::get<1>(link_result); + + // Run the commands to generate a .so file + int status = system(compile_cmd.c_str()); + if (status != 0) { + throw std::runtime_error("Failed to compile cpp file."); + } + status = system(link_cmd.c_str()); + if (status != 0) { + throw std::runtime_error("Failed to link files."); + } + + // Move the mmapped weights onto the .so + std::string serialized_weights_path = filename + "_serialized_weights.bin"; + if (file_exists(serialized_weights_path)) { + std::ifstream serialized_weights_file( + serialized_weights_path, std::ios::binary); + if (!serialized_weights_file.is_open()) { + throw std::runtime_error("Failed to open serialized weights file"); + } + std::vector serialized_weights( + (std::istreambuf_iterator(serialized_weights_file)), + std::istreambuf_iterator()); + serialized_weights_file.close(); + + std::ofstream output_so_file(output_so, std::ios::binary | std::ios::app); + if (!output_so_file.is_open()) { + throw std::runtime_error("Failed to open output .so file"); + } + // Page align the weights + std::streampos so_size = output_so_file.tellp(); + std::vector padding(16384 - so_size % 16384, ' '); + output_so_file.write( + padding.data(), static_cast(padding.size())); + output_so_file.write( + serialized_weights.data(), + static_cast(serialized_weights.size())); + output_so_file.close(); + } + + return output_so; +} + +AOTIModelPackageLoader::AOTIModelPackageLoader( + const std::string& model_package_path) + : AOTIModelPackageLoader(model_package_path, "model") {} + +AOTIModelPackageLoader::AOTIModelPackageLoader( + const std::string& model_package_path, + const std::string& model_name = "model") { + // Extract all files within the zipfile to a temporary directory + mz_zip_archive zip_archive; + memset(&zip_archive, 0, sizeof(zip_archive)); + + if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { + throw std::runtime_error( + std::string("Failed to initialize zip archive: ") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + std::string temp_dir = create_temp_dir(); + std::string so_filename = ""; + std::string cpp_filename = ""; + std::string consts_filename = ""; + std::string found_filenames = ""; // Saving for bookkeeping + + for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { + uint32_t filename_len = + mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); + if (filename_len == 0) { + throw std::runtime_error("Failed to read filename"); + } + char* filename = new char[filename_len + 1]; + if (!mz_zip_reader_get_filename(&zip_archive, i, filename, filename_len)) { + throw std::runtime_error("Failed to read filename"); + } + + std::string filename_str(filename); + found_filenames += filename_str; + found_filenames += " "; + + // Only compile files in the specified model directory + std::string model_directory = "data/aotinductor/" + model_name; + if (filename_str.length() >= model_directory.length() && + filename_str.substr(0, model_directory.length()) == model_directory) { + std::string output_path_str = temp_dir; + output_path_str += "/"; + output_path_str += filename_str; + + // Create the parent directory if it doesn't exist + size_t parent_path_idx = output_path_str.find_last_of("/\\"); + if (parent_path_idx == std::string::npos) { + throw std::runtime_error( + "Failed to find parent path in " + output_path_str); + } + std::string parent_path = output_path_str.substr(0, parent_path_idx); + if (!recursive_mkdir(parent_path.c_str())) { + throw std::runtime_error(fmt::format( + "Failed to create directory {}: {}", parent_path, strerror(errno))); + } + + // Extracts file to the temp directory + mz_zip_reader_extract_file_to_file( + &zip_archive, filename, output_path_str.c_str(), 0); + + // Save the file for bookkeeping + size_t extension_idx = output_path_str.find_last_of('.'); + if (extension_idx != std::string::npos) { + std::string filename_extension = output_path_str.substr(extension_idx); + if (filename_extension == ".cpp") { + cpp_filename = output_path_str; + } + if (filename_extension == ".o") { + consts_filename = output_path_str; + } + if (filename_extension == ".so") { + so_filename = output_path_str; + } + } + } + } + + // Close the zip archive as we have extracted all files to the temp directory + if (!mz_zip_reader_end(&zip_archive)) { + throw std::runtime_error( + std::string("Failed to close zip archive: {}") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + if (cpp_filename.empty() && so_filename.empty()) { + throw std::runtime_error( + "No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" + + found_filenames); + } + + // Compile the .so + std::string so_path = !so_filename.empty() + ? so_filename + : compile_so(cpp_filename, consts_filename); + + // Load metadata which can be queried by user + load_metadata(cpp_filename); + + // Construct the runner depending on the device information + std::string device = metadata_["AOTI_DEVICE_KEY"]; + + if (device.empty()) { + throw std::runtime_error("No device information found."); + } + + std::unordered_map + registered_aoti_runner = getAOTIModelRunnerRegistry(); + + if (registered_aoti_runner.find(device) == registered_aoti_runner.end()) { + throw std::runtime_error("Unsupported device found: " + device); + } + + runner_ = registered_aoti_runner[device](so_path, 1, device, ""); + + std::remove(temp_dir.c_str()); +} + +AOTIModelContainerRunner* AOTIModelPackageLoader::get_runner() { + return runner_.get(); +} + +std::vector AOTIModelPackageLoader::run( + std::vector& inputs) { + return runner_->run(inputs); +} + +std::unordered_map AOTIModelPackageLoader:: + get_metadata() { + return metadata_; +} + +std::vector AOTIModelPackageLoader::get_call_spec() { + return runner_->get_call_spec(); +} + +} // namespace torch::inductor +#endif diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.h b/torch/csrc/inductor/aoti_package/model_package_loader.h new file mode 100644 index 000000000000..03dc1c64018d --- /dev/null +++ b/torch/csrc/inductor/aoti_package/model_package_loader.h @@ -0,0 +1,40 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +#include + +namespace torch::inductor { +class TORCH_API AOTIModelPackageLoader { + public: + AOTIModelPackageLoader(const std::string& model_package_path); + AOTIModelPackageLoader( + const std::string& model_package_path, + const std::string& model_name); + + AOTIModelContainerRunner* get_runner(); + std::unordered_map get_metadata(); + std::vector run(std::vector& inputs); + std::vector get_call_spec(); + + private: + std::unique_ptr runner_; + std::unordered_map metadata_; + + void load_metadata(const std::string& cpp_filename); + std::string compile_so( + const std::string& cpp_filename, + const std::string& consts_filename); + const nlohmann::json& load_json_file(std::string json_path); + std::tuple get_cpp_compile_command( + const std::string& filename, + const std::vector& sources, + const nlohmann::json& compile_options, + const std::string& output_dir); + bool recursive_mkdir(const std::string& dir); +}; + +} // namespace torch::inductor +#endif diff --git a/torch/csrc/inductor/aoti_package/pybind.cpp b/torch/csrc/inductor/aoti_package/pybind.cpp new file mode 100644 index 000000000000..3d2154a75493 --- /dev/null +++ b/torch/csrc/inductor/aoti_package/pybind.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#ifdef USE_CUDA +#include +#endif + +#include +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject* module) { + auto rootModule = py::handle(module).cast(); + auto m = rootModule.def_submodule("_aoti"); + + py::class_(m, "AOTIModelPackageLoader") + .def(py::init()) + .def(py::init()) + .def("get_metadata", &AOTIModelPackageLoader::get_metadata) + .def("run", &AOTIModelPackageLoader::run) + .def("get_call_spec", &AOTIModelPackageLoader::get_call_spec); +} +} // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_package/pybind.h b/torch/csrc/inductor/aoti_package/pybind.h new file mode 100644 index 000000000000..1eb7818c00e9 --- /dev/null +++ b/torch/csrc/inductor/aoti_package/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject* module); + +} // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.h b/torch/csrc/inductor/aoti_runner/model_container_runner.h index 99669aec14ac..6e6339d3dd27 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.h +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.h @@ -82,7 +82,7 @@ class TORCH_API AOTIModelContainerRunner { std::unique_ptr proxy_executor_; }; -using CreateAOTIModelRunnerFunc = std::shared_ptr (*)( +using CreateAOTIModelRunnerFunc = std::unique_ptr (*)( const std::string& model_so_path, size_t num_models, const std::string& device_str, diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp index 25decff00c45..f40545d04c49 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp @@ -17,5 +17,21 @@ std::vector AOTIModelContainerRunnerCpu::run( return AOTIModelContainerRunner::run(inputs); } +namespace { +std::unique_ptr create_aoti_runner_cpu( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) { + if (device_str != "cpu") { + throw std::runtime_error("Incorrect device passed to aoti_runner_cpu"); + } + return std::make_unique( + model_so_path, num_models); +} +} // namespace + +RegisterAOTIModelRunner register_cpu_runner("cpu", &create_aoti_runner_cpu); + } // namespace torch::inductor #endif diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp index 596b436bb703..3ddad0885aa5 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp @@ -30,5 +30,18 @@ std::vector AOTIModelContainerRunnerCuda::run_with_cuda_stream( inputs, reinterpret_cast(cuda_stream.stream())); } +namespace { +std::unique_ptr create_aoti_runner_cuda( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) { + return std::make_unique( + model_so_path, num_models, device_str, cubin_dir); +} +} // namespace + +RegisterAOTIModelRunner register_cuda_runner("cuda", &create_aoti_runner_cuda); + } // namespace torch::inductor #endif diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 964a976775e6..f49bf23b9ce4 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -55,6 +55,39 @@ namespace fs = std::filesystem; namespace fs = std::experimental::filesystem; #endif +#ifndef _WIN32 +#include +#endif + +// HACK for failed builds in ARVR, where it cannot find these symbols within +// std::experimental::filesystem +namespace { +fs::path get_current_path() { +#if __has_include("filesystem") + return fs::current_path(); +#else + throw std::runtime_error("Not implemented"); +#endif +} + +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} + +bool create_directories(const std::string& path) { +#if __has_include("filesystem") + return fs::create_directories(path); +#else + throw std::runtime_error("Not implemented"); +#endif +} +} // namespace + using namespace torch::aot_inductor; namespace { @@ -1004,14 +1037,14 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( at::Tensor* t = tensor_handle_to_tensor_pointer(self); #ifndef C10_MOBILE // Save tensor to tmp .pt file for tensors and can be torch.load'ed later - std::string cwd = fs::current_path().string(); + std::string cwd = get_current_path().string(); std::string tmp_folder = cwd + "/tmp/aoti_torch/"; - if (!fs::exists(tmp_folder)) { + if (!file_exists(tmp_folder)) { std::cout << "aoti_torch_save_tensor_handle: Path does not exist, creating it..." << tmp_folder << std::endl; - if (!fs::create_directories(tmp_folder)) { + if (!create_directories(tmp_folder)) { std::cout << "aoti_torch_save_tensor_handle: Error creating directory: " << tmp_folder << std::endl; return;