Support triton.language.dtype with torch.compile (#121690)

I'm putting this PR up as an RFC since I have resorted to some horrible hacks in order to make this work. The core of the problem is how Triton dtypes print:
```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```
Since `triton.language.fp32` does not actually exist as an attribute (the real attribute is `triton.language.float32`), neither `str()` nor `repr()` yields an evaluatable expression. This means that we need to "rewrite" these dtypes when they are embedded in FX graph code and in Inductor-generated code.
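
To make the failure mode concrete, here is a minimal sketch (assuming Triton is installed) of why a codegen path that falls back to `repr()`, such as the `val_to_arg_str` fallback touched below, cannot reproduce a Triton dtype:

```python
import triton.language as tl

print(str(tl.float32))   # fp32
print(repr(tl.float32))  # triton.language.fp32

# Generated code that embeds this constant via repr() would reference
# `triton.language.fp32`, which is not a real attribute (the attribute is
# `triton.language.float32`), so executing that code raises an AttributeError.
print(hasattr(tl, "fp32"))     # False
print(hasattr(tl, "float32"))  # True
```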

This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
Author: Oguz Ulgen
Date: 2024-03-12 13:02:22 -07:00
Committed by: PyTorch MergeBot
Parent: 22bb24986d
Commit: 79ee6bbde3
7 changed files with 105 additions and 5 deletions

```diff
@@ -1086,6 +1086,51 @@ def forward(self, x_1, output_1):
         with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
             f(x, x)
 
+    @requires_cuda
+    @skipIfRocm
+    @common_utils.parametrize("dynamic", [False, True])
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_triton_kernel_triton_dtype(self, dynamic, backend):
+        @triton.jit
+        def add_kernel_with_dtype(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            dtype: "tl.constexpr",
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype)
+            y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def f(x, y, dtype_torch, dtype_triton):
+            output = torch.zeros_like(x).to(dtype=dtype_torch)
+            n_elements = output.numel()
+            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+            add_kernel_with_dtype[grid](
+                x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4
+            )
+            return output
+
+        x = torch.randn(4, device="cuda")
+        y = torch.randn(4, device="cuda")
+        args_list = (
+            [x, y, torch.float32, tl.float32],
+            [x, y, torch.bfloat16, tl.bfloat16],
+        )
+        for args in args_list:
+            eager_out = f(*args)
+            compiled_out = torch.compile(
+                f, fullgraph=True, backend=backend, dynamic=dynamic
+            )(*args)
+            self.assertEqual(compiled_out, eager_out)
+
 
 def make_mutation_test(fn):
     @requires_cuda
```

```diff
@@ -1222,6 +1222,13 @@ class CheckFunctionManager:
         # guard_fn.__globals__ becomes equal to builder.scope. This causes
         # guard_fn to hold a referece to f_locals sitting in builder.scope["L"]
         globals_for_guard_fn = {"G": builder.scope["G"]}
+
+        from torch.utils._triton import has_triton_package
+
+        if has_triton_package():
+            import triton
+
+            globals_for_guard_fn[triton.__name__] = triton
         try:
             exec(pycode, globals_for_guard_fn, out)
         except SyntaxError as ex:
```

```diff
@@ -98,7 +98,7 @@ from torch._utils_internal import log_compilation_event
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.utils._pytree import tree_map_only
+from torch.utils._triton import has_triton, has_triton_package
 
 counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
 optimus_scuba_log: Dict[str, Any] = {}
@@ -977,6 +977,10 @@ common_constant_types = {
     torch.memory_format,
     torch.layout,
 }
 
+if has_triton_package():
+    import triton
+
+    common_constant_types.add(triton.language.dtype)
 
 
 def is_safe_constant(v):
@@ -2193,8 +2197,6 @@ def is_compile_supported(device_type):
     if device_type == "cpu":
         pass
     elif device_type == "cuda" and compile_supported:
-        from torch.utils._triton import has_triton
-
         compile_supported = has_triton()
     else:
         compile_supported = False
```

```diff
@@ -1013,6 +1013,10 @@ class WrapperCodeGen(CodeGen):
         self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
 
     def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
+        from torch.utils._triton import patch_triton_dtype_repr
+
+        patch_triton_dtype_repr()
+
         original_name = kernel.__name__
 
         from .common import KernelArgType, SizeArg, TensorArg
@@ -1299,6 +1303,11 @@
         raise NotImplementedError()
 
     def val_to_arg_str(self, s):
+        from torch.utils._triton import dtype_to_string, has_triton
+
+        if has_triton():
+            import triton
+
         if isinstance(s, SymTypes):
             return pexpr(sympy.expand(repr(s)))
         elif isinstance(s, sympy.Expr):
@@ -1317,6 +1326,8 @@
             return _get_qualified_name(s)
         elif isinstance(s, (ir.Buffer, ReinterpretView)):
             return s.codegen_reference()
+        elif has_triton() and isinstance(s, triton.language.dtype):  # type: ignore[possibly-undefined]
+            return dtype_to_string(s)
         else:
             return repr(s)
```

```diff
@@ -380,11 +380,19 @@ class CodeGen:
     def _gen_python_code(
         self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False,
     ) -> PythonCode:
+        from torch.utils._triton import has_triton
+
         free_vars: List[str] = []
         body: List[str] = []
         globals_: Dict[str, Any] = {}
         wrapped_fns: Dict[str, None] = {}
+        if has_triton():
+            import triton
+
+            globals_[triton.__name__] = triton
+            from torch.utils._triton import patch_triton_dtype_repr
+            patch_triton_dtype_repr()
 
         # Wrap string in list to pass by reference
         maybe_return_annotation : List[str] = ['']
```

```diff
@@ -245,6 +245,11 @@ class TracerBase:
         Can be override to support more trace-specific types.
         """
+        from torch.utils._triton import has_triton
+
+        if has_triton():
+            import triton
+
         if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
             return a.__fx_create_arg__(self)
         # aggregates
@@ -280,6 +285,8 @@
         elif isinstance(a, torch._ops.OpOverload):
             return a
+        elif has_triton() and isinstance(a, triton.language.dtype):
+            return a
 
         if isinstance(a, Proxy):
             # base case: we unwrap the Proxy object
```

```diff
@@ -1,8 +1,6 @@
 import functools
 import hashlib
 
-from torch._dynamo.device_interface import get_interface_for_device
-
 
 @functools.lru_cache(None)
 def has_triton_package() -> bool:
@@ -16,6 +14,8 @@ def has_triton_package() -> bool:
 @functools.lru_cache(None)
 def has_triton() -> bool:
+    from torch._dynamo.device_interface import get_interface_for_device
+
     def cuda_extra_check(device_interface):
         return device_interface.Worker.get_device_properties().major >= 7
@@ -59,3 +59,23 @@ def triton_hash_with_backend():
     backend = triton_backend()
     key = f"{triton_key()}-{backend.hash()}"
     return hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+
+def dtype_to_string(dtype):
+    if dtype.name.startswith("fp"):
+        suffix = "float" + dtype.name[2:]
+    elif dtype.name.startswith("bf"):
+        suffix = "bfloat" + dtype.name[2:]
+    else:
+        suffix = dtype.name
+    return "triton.language." + suffix
+
+
+def patch_triton_dtype_repr():
+    import triton
+
+    # Hack to get triton dtype repr to produce an evaluatable expression
+    # triton.language.float32 emits triton.language.fp32 which does not
+    # exist
+    # REMOVE when https://github.com/openai/triton/pull/3342 lands
+    triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)
```
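
As a quick sanity check of the new helpers, the following sketch (assuming Triton is installed) shows the strings they produce and the effect of the repr patch; the expected outputs follow directly from the code above:

```python
import triton  # imported so the patched repr evaluates below
import triton.language as tl

from torch.utils._triton import dtype_to_string, patch_triton_dtype_repr

print(dtype_to_string(tl.float32))   # triton.language.float32
print(dtype_to_string(tl.bfloat16))  # triton.language.bfloat16
print(dtype_to_string(tl.int32))     # triton.language.int32

patch_triton_dtype_repr()
# repr() is now an expression that evaluates back to the same dtype.
print(eval(repr(tl.float32)) == tl.float32)  # True
```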