Fix fx graph triton import bug (#122041)

Summary: Unless we register triton to be a special import, FX graph import mechanism imports it as `from fx-generated._0 import triton as triton` which is obviously broken.

Test Plan:
I could not figure out how to write a test for this but
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//tgif/lib/tests/gpu_tests:lowering_pass_test -- -r test_default_ait_lowering_multi_hardwares
```
now passes

Differential Revision: D54990782

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122041
Approved by: https://github.com/aakhundov
This commit is contained in:
Oguz Ulgen
2024-03-17 22:48:51 +00:00
committed by PyTorch MergeBot
parent 5030913d6a
commit e39aedfcc5
4 changed files with 17 additions and 1 deletions

View File

@ -111,6 +111,10 @@ def _is_from_torch(obj: Any) -> bool:
return False
def _is_from_triton(name) -> bool:
return name == "triton"
class _Namespace:
"""A context for associating names uniquely with objects.

View File

@ -16,7 +16,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
from ._compatibility import compatibility
from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode
from .graph import _custom_builtins, _is_from_torch, _is_from_triton, _PyTreeCodeGen, Graph, PythonCode
__all__ = [
"reduce_graph_module",
@ -107,6 +107,8 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
return _custom_builtins[name].import_str
if _is_from_torch(name):
return "import torch"
if _is_from_triton(name):
return "import triton"
module_name, attr_name = importer.get_name(obj)
return f"from {module_name} import {attr_name} as {name}"

View File

@ -480,6 +480,15 @@ class PackageExporter:
)
return
# Exporting triton is not always possible, work around it
if module_name == "triton":
self.dependency_graph.add_node(
module_name,
action=_ModuleProviderAction.SKIP,
provided=True,
)
return
if module_name == "_mock":
self.dependency_graph.add_node(
module_name,

View File

@ -41,6 +41,7 @@ IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
# FX GraphModule might depend on builtins module and users usually
# don't extern builtins. Here we import it here by default.
"builtins",
"triton",
]