diff --git a/test/package/test_dependency_api.py b/test/package/test_dependency_api.py index b8350ddf8824..eb1c48c427ba 100644 --- a/test/package/test_dependency_api.py +++ b/test/package/test_dependency_api.py @@ -247,6 +247,8 @@ class TestDependencyAPI(PackageTestCase): * Module did not match against any action pattern. Extern, mock, or intern it. package_a package_a.subpackage + + Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from! """ ), ) @@ -294,6 +296,8 @@ class TestDependencyAPI(PackageTestCase): * Module is a C extension module. torch.package supports Python modules only. foo bar + + Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from! """ ), ) @@ -313,6 +317,8 @@ class TestDependencyAPI(PackageTestCase): * Dependency resolution failed. foo Context: attempted relative import beyond top-level package + + Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from! """ ), ) diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 347641e46431..f83a79efced6 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -131,7 +131,7 @@ class PackagingError(Exception): them to you at once. """ - def __init__(self, dependency_graph: DiGraph): + def __init__(self, dependency_graph: DiGraph, debug=False): # Group errors by reason. broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list) for module_name, attrs in dependency_graph.nodes.items(): @@ -154,7 +154,30 @@ class PackagingError(Exception): error_context = dependency_graph.nodes[module_name].get("error_context") if error_context is not None: message.write(f" Context: {error_context}\n") - + if module_name in _DISALLOWED_MODULES: + message.write( + ( + " Note: While we usually use modules in the python standard library " + f"from the local environment, `{module_name}` has a lot of system " + "level access and therefore can pose a security risk. We heavily " + f"recommend removing `{module_name}` from your packaged code. However, if that " + "is not possible, add it to the extern list by calling " + f'PackageExporter.extern("`{module_name}`")\n' + ) + ) + if debug: + module_path = dependency_graph.first_path(module_name) + message.write( + f" A path to {module_name}: {' -> '.join(module_path)}" + ) + if not debug: + message.write("\n") + message.write( + ( + "Set debug=True when invoking PackageExporter for a visualization of where " + "broken modules are coming from!\n" + ) + ) # Save the dependency graph so that tooling can get at it. self.dependency_graph = dependency_graph super().__init__(message.getvalue()) @@ -195,6 +218,7 @@ class PackageExporter: self, f: Union[str, Path, BinaryIO], importer: Union[Importer, Sequence[Importer]] = sys_importer, + debug: bool = False, ): """ Create an exporter. @@ -204,9 +228,10 @@ class PackageExporter: or a binary I/O object. importer: If a single Importer is passed, use that to search for modules. If a sequence of importers are passed, an ``OrderedImporter`` will be constructed out of them. + debug: If set to True, add path of broken modules to PackagingErrors. """ torch._C._log_api_usage_once("torch.package.PackageExporter") - + self.debug = debug if isinstance(f, (Path, str)): f = str(f) self.buffer: Optional[BinaryIO] = None @@ -979,7 +1004,7 @@ class PackageExporter: # 1. Check the graph for any errors inserted during dependency analysis. for module_name, attrs in self.dependency_graph.nodes.items(): if "error" in attrs: - raise PackagingError(self.dependency_graph) + raise PackagingError(self.dependency_graph, debug=self.debug) # 2. Check that all patterns for which allow_empty=False have been matched at least once. for pattern, pattern_info in self.patterns.items():