mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix fx graph triton import bug (#122041)
Summary: Unless we register triton to be a special import, FX graph import mechanism imports it as `from fx-generated._0 import triton as triton` which is obviously broken. Test Plan: I could not figure out how to write a test for this but ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//tgif/lib/tests/gpu_tests:lowering_pass_test -- -r test_default_ait_lowering_multi_hardwares ``` now passes Differential Revision: D54990782 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122041 Approved by: https://github.com/aakhundov
This commit is contained in:
committed by
PyTorch MergeBot
parent
5030913d6a
commit
e39aedfcc5
@ -111,6 +111,10 @@ def _is_from_torch(obj: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_from_triton(name) -> bool:
|
||||
return name == "triton"
|
||||
|
||||
|
||||
class _Namespace:
|
||||
"""A context for associating names uniquely with objects.
|
||||
|
||||
|
@ -16,7 +16,7 @@ from torch.nn.modules.module import _addindent
|
||||
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
|
||||
|
||||
from ._compatibility import compatibility
|
||||
from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode
|
||||
from .graph import _custom_builtins, _is_from_torch, _is_from_triton, _PyTreeCodeGen, Graph, PythonCode
|
||||
|
||||
__all__ = [
|
||||
"reduce_graph_module",
|
||||
@ -107,6 +107,8 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
|
||||
return _custom_builtins[name].import_str
|
||||
if _is_from_torch(name):
|
||||
return "import torch"
|
||||
if _is_from_triton(name):
|
||||
return "import triton"
|
||||
module_name, attr_name = importer.get_name(obj)
|
||||
return f"from {module_name} import {attr_name} as {name}"
|
||||
|
||||
|
@ -480,6 +480,15 @@ class PackageExporter:
|
||||
)
|
||||
return
|
||||
|
||||
# Exporting triton is not always possible, work around it
|
||||
if module_name == "triton":
|
||||
self.dependency_graph.add_node(
|
||||
module_name,
|
||||
action=_ModuleProviderAction.SKIP,
|
||||
provided=True,
|
||||
)
|
||||
return
|
||||
|
||||
if module_name == "_mock":
|
||||
self.dependency_graph.add_node(
|
||||
module_name,
|
||||
|
@ -41,6 +41,7 @@ IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
|
||||
# FX GraphModule might depend on builtins module and users usually
|
||||
# don't extern builtins. Here we import it here by default.
|
||||
"builtins",
|
||||
"triton",
|
||||
]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user