Support triton.language.dtype with torch.compile (#121690)
Putting this PR up as an RFC since I have resorted to some horrible hacks in order to make this work.

```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```

This means we need to "rewrite" these dtypes for fx graph and inductor execution.

This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
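For concreteness, the user-facing pattern this PR enables looks roughly like the following, condensed from the test added in this diff (it assumes a CUDA device and an installed `triton`):

```python
# Condensed from the new test below: a tl.dtype (tl.bfloat16) flows through a
# torch.compile'd wrapper and into a user-defined Triton kernel as a constexpr.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel_with_dtype(
    in_ptr0, in_ptr1, out_ptr, dtype: "tl.constexpr", n_elements, BLOCK_SIZE: "tl.constexpr"
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype)
    y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def f(x, y, dtype_torch, dtype_triton):
    output = torch.zeros_like(x).to(dtype=dtype_torch)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel_with_dtype[grid](x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4)
    return output


x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
# Passing tl.bfloat16 as a kernel argument is what previously broke tracing.
out = torch.compile(f, fullgraph=True)(x, y, torch.bfloat16, tl.bfloat16)
```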
Commit: 79ee6bbde3
Parent: 22bb24986d
Committed by: PyTorch MergeBot
@@ -1086,6 +1086,51 @@ def forward(self, x_1, output_1):
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            f(x, x)

    @requires_cuda
    @skipIfRocm
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_triton_kernel_triton_dtype(self, dynamic, backend):
        @triton.jit
        def add_kernel_with_dtype(
            in_ptr0,
            in_ptr1,
            out_ptr,
            dtype: "tl.constexpr",
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype)
            y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def f(x, y, dtype_torch, dtype_triton):
            output = torch.zeros_like(x).to(dtype=dtype_torch)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_with_dtype[grid](
                x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4
            )
            return output

        x = torch.randn(4, device="cuda")
        y = torch.randn(4, device="cuda")
        args_list = (
            [x, y, torch.float32, tl.float32],
            [x, y, torch.bfloat16, tl.bfloat16],
        )
        for args in args_list:
            eager_out = f(*args)
            compiled_out = torch.compile(
                f, fullgraph=True, backend=backend, dynamic=dynamic
            )(*args)
            self.assertEqual(compiled_out, eager_out)


def make_mutation_test(fn):
    @requires_cuda
@@ -1222,6 +1222,13 @@ class CheckFunctionManager:
        # guard_fn.__globals__ becomes equal to builder.scope. This causes
        # guard_fn to hold a referece to f_locals sitting in builder.scope["L"]
        globals_for_guard_fn = {"G": builder.scope["G"]}
        from torch.utils._triton import has_triton_package

        if has_triton_package():
            import triton

            globals_for_guard_fn[triton.__name__] = triton

        try:
            exec(pycode, globals_for_guard_fn, out)
        except SyntaxError as ex:
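The reason `triton` is injected here: once `triton.language.dtype` values are treated as constants, the generated guard source can mention them by their (patched) repr, and executing that source needs `triton` bound in its globals. A hedged sketch of the idea, not the exact guard text Dynamo emits:

```python
# Hypothetical guard-like expression; the real generated guard code differs,
# but anything that spells out "triton.language.bfloat16" needs `triton` in
# the globals passed to exec()/eval(), which is what this hunk provides.
import triton
import triton.language as tl

guard_src = "L['dtype_triton'] == triton.language.bfloat16"
scope = {"triton": triton, "L": {"dtype_triton": tl.bfloat16}}
assert eval(guard_src, scope)
```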
@@ -98,7 +98,7 @@ from torch._utils_internal import log_compilation_event

from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._pytree import tree_map_only

from torch.utils._triton import has_triton, has_triton_package

counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
optimus_scuba_log: Dict[str, Any] = {}
@@ -977,6 +977,10 @@ common_constant_types = {
    torch.memory_format,
    torch.layout,
}
if has_triton_package():
    import triton

    common_constant_types.add(triton.language.dtype)


def is_safe_constant(v):
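A quick way to see the effect of registering the Triton dtype class here (illustrative; assumes `triton` is installed):

```python
import triton.language as tl
from torch._dynamo.utils import is_safe_constant

# With triton.language.dtype in common_constant_types, Dynamo's constant check
# accepts tl dtype objects, so they can be baked into traced code and guarded on.
assert is_safe_constant(tl.float32)
```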
@@ -2193,8 +2197,6 @@ def is_compile_supported(device_type):
    if device_type == "cpu":
        pass
    elif device_type == "cuda" and compile_supported:
        from torch.utils._triton import has_triton

        compile_supported = has_triton()
    else:
        compile_supported = False
@@ -1013,6 +1013,10 @@ class WrapperCodeGen(CodeGen):
        self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")

    def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
        from torch.utils._triton import patch_triton_dtype_repr

        patch_triton_dtype_repr()

        original_name = kernel.__name__

        from .common import KernelArgType, SizeArg, TensorArg
@@ -1299,6 +1303,11 @@ class WrapperCodeGen(CodeGen):
            raise NotImplementedError()

    def val_to_arg_str(self, s):
        from torch.utils._triton import dtype_to_string, has_triton

        if has_triton():
            import triton

        if isinstance(s, SymTypes):
            return pexpr(sympy.expand(repr(s)))
        elif isinstance(s, sympy.Expr):
@@ -1317,6 +1326,8 @@ class WrapperCodeGen(CodeGen):
            return _get_qualified_name(s)
        elif isinstance(s, (ir.Buffer, ReinterpretView)):
            return s.codegen_reference()
        elif has_triton() and isinstance(s, triton.language.dtype):  # type: ignore[possibly-undefined]
            return dtype_to_string(s)
        else:
            return repr(s)
@@ -380,11 +380,19 @@ class CodeGen:
    def _gen_python_code(
        self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False,
    ) -> PythonCode:
        from torch.utils._triton import has_triton

        free_vars: List[str] = []
        body: List[str] = []
        globals_: Dict[str, Any] = {}
        wrapped_fns: Dict[str, None] = {}

        if has_triton():
            import triton
            globals_[triton.__name__] = triton
            from torch.utils._triton import patch_triton_dtype_repr
            patch_triton_dtype_repr()

        # Wrap string in list to pass by reference
        maybe_return_annotation : List[str] = ['']
@@ -245,6 +245,11 @@ class TracerBase:

        Can be override to support more trace-specific types.
        """
        from torch.utils._triton import has_triton

        if has_triton():
            import triton

        if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
            return a.__fx_create_arg__(self)
        # aggregates
@@ -280,6 +285,8 @@ class TracerBase:

        elif isinstance(a, torch._ops.OpOverload):
            return a
        elif has_triton() and isinstance(a, triton.language.dtype):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
@@ -1,8 +1,6 @@
import functools
import hashlib

from torch._dynamo.device_interface import get_interface_for_device


@functools.lru_cache(None)
def has_triton_package() -> bool:
@@ -16,6 +14,8 @@ def has_triton_package() -> bool:

@functools.lru_cache(None)
def has_triton() -> bool:
    from torch._dynamo.device_interface import get_interface_for_device

    def cuda_extra_check(device_interface):
        return device_interface.Worker.get_device_properties().major >= 7
@@ -59,3 +59,23 @@ def triton_hash_with_backend():
    backend = triton_backend()
    key = f"{triton_key()}-{backend.hash()}"
    return hashlib.sha256(key.encode("utf-8")).hexdigest()


def dtype_to_string(dtype):
    if dtype.name.startswith("fp"):
        suffix = "float" + dtype.name[2:]
    elif dtype.name.startswith("bf"):
        suffix = "bfloat" + dtype.name[2:]
    else:
        suffix = dtype.name
    return "triton.language." + suffix


def patch_triton_dtype_repr():
    import triton

    # Hack to get triton dtype repr to produce an evaluatable expression
    # triton.language.float32 emits triton.language.fp32 which does not
    # exist
    # REMOVE when https://github.com/openai/triton/pull/3342 lands
    triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)
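A small usage sketch of the two helpers added above (the printed values follow the Pdb session in the PR description; assumes `triton` is installed):

```python
import triton
import triton.language as tl
from torch.utils._triton import patch_triton_dtype_repr

print(repr(tl.bfloat16))   # 'triton.language.bf16' -- not an evaluatable name
patch_triton_dtype_repr()
print(repr(tl.bfloat16))   # 'triton.language.bfloat16'

# After patching, the repr round-trips through eval given `triton` in scope,
# which is what the generated fx/inductor code relies on.
assert eval(repr(tl.bfloat16), {"triton": triton}) == tl.bfloat16
```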