Support triton.language.dtype with torch.compile (#121690)
Putting this PR as an RFC since I have resorted to some horrible hacks in order to make this work.

```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```

This means that we need to "rewrite" them for fx graph and inductor execution. This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 22bb24986d
Commit: 79ee6bbde3
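For context, the sketch below (not part of the diff) reproduces the repr mismatch quoted above and the kind of string rewriting this PR uses to work around it. It assumes Triton is installed; `dtype_to_evaluatable` is a hypothetical local helper that mirrors the `dtype_to_string` function added in the last hunk below, and the exact `repr` output may differ across Triton versions.

```python
# Sketch only: show why repr(tl.float32) cannot be pasted into generated code,
# and the name rewriting used to work around it. Assumes an installed Triton.
import triton.language as tl


def dtype_to_evaluatable(dtype: tl.dtype) -> str:
    # Hypothetical mirror of the dtype_to_string helper added below: map
    # Triton's short dtype names ("fp32", "bf16", ...) to attribute names
    # that actually exist on triton.language ("float32", "bfloat16", ...).
    name = dtype.name
    if name.startswith("fp"):
        name = "float" + name[2:]
    elif name.startswith("bf"):
        name = "bfloat" + name[2:]
    return "triton.language." + name


print(repr(tl.float32))                  # -> "triton.language.fp32" on the version quoted above
print(dtype_to_evaluatable(tl.float32))  # -> "triton.language.float32", which does evaluate
```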
```diff
@@ -1086,6 +1086,51 @@ def forward(self, x_1, output_1):
         with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
             f(x, x)
 
+    @requires_cuda
+    @skipIfRocm
+    @common_utils.parametrize("dynamic", [False, True])
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_triton_kernel_triton_dtype(self, dynamic, backend):
+        @triton.jit
+        def add_kernel_with_dtype(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            dtype: "tl.constexpr",
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask).to(dtype)
+            y = tl.load(in_ptr1 + offsets, mask=mask).to(dtype)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def f(x, y, dtype_torch, dtype_triton):
+            output = torch.zeros_like(x).to(dtype=dtype_torch)
+            n_elements = output.numel()
+            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+            add_kernel_with_dtype[grid](
+                x, y, output, dtype_triton, n_elements, BLOCK_SIZE=4
+            )
+            return output
+
+        x = torch.randn(4, device="cuda")
+        y = torch.randn(4, device="cuda")
+        args_list = (
+            [x, y, torch.float32, tl.float32],
+            [x, y, torch.bfloat16, tl.bfloat16],
+        )
+        for args in args_list:
+            eager_out = f(*args)
+            compiled_out = torch.compile(
+                f, fullgraph=True, backend=backend, dynamic=dynamic
+            )(*args)
+            self.assertEqual(compiled_out, eager_out)
+
 
 def make_mutation_test(fn):
     @requires_cuda
```
```diff
@@ -1222,6 +1222,13 @@ class CheckFunctionManager:
         # guard_fn.__globals__ becomes equal to builder.scope. This causes
         # guard_fn to hold a referece to f_locals sitting in builder.scope["L"]
         globals_for_guard_fn = {"G": builder.scope["G"]}
+        from torch.utils._triton import has_triton_package
+
+        if has_triton_package():
+            import triton
+
+            globals_for_guard_fn[triton.__name__] = triton
+
         try:
             exec(pycode, globals_for_guard_fn, out)
         except SyntaxError as ex:
```
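The guard change above exists because generated guard source can now reference a Triton dtype by its (patched) repr, and that expression only resolves if `triton` is reachable from the guard function's globals. The snippet below is a hypothetical illustration of that failure mode, not actual dynamo output; it assumes Triton is installed.

```python
# Hypothetical guard source referencing a Triton dtype constant. Without
# "triton" in the globals passed to eval, the name lookup fails at call time
# with a NameError; with it, the guard evaluates normally.
import triton

guard_src = "lambda L: L['dtype_triton'] == triton.language.float32"

try:
    eval(guard_src, {})({"dtype_triton": triton.language.float32})
except NameError as e:
    print("without triton in globals:", e)

guard_fn = eval(guard_src, {"triton": triton})
print("with triton in globals:", guard_fn({"dtype_triton": triton.language.float32}))  # True
```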
```diff
@@ -98,7 +98,7 @@ from torch._utils_internal import log_compilation_event
 
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.utils._pytree import tree_map_only
+from torch.utils._triton import has_triton, has_triton_package
 
 counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
 optimus_scuba_log: Dict[str, Any] = {}
@@ -977,6 +977,10 @@ common_constant_types = {
     torch.memory_format,
     torch.layout,
 }
+if has_triton_package():
+    import triton
+
+    common_constant_types.add(triton.language.dtype)
 
 
 def is_safe_constant(v):
@@ -2193,8 +2197,6 @@ def is_compile_supported(device_type):
     if device_type == "cpu":
         pass
     elif device_type == "cuda" and compile_supported:
-        from torch.utils._triton import has_triton
-
         compile_supported = has_triton()
     else:
         compile_supported = False
```
```diff
@@ -1013,6 +1013,10 @@ class WrapperCodeGen(CodeGen):
         self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
 
     def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
+        from torch.utils._triton import patch_triton_dtype_repr
+
+        patch_triton_dtype_repr()
+
         original_name = kernel.__name__
 
         from .common import KernelArgType, SizeArg, TensorArg
@@ -1299,6 +1303,11 @@ class WrapperCodeGen(CodeGen):
         raise NotImplementedError()
 
     def val_to_arg_str(self, s):
+        from torch.utils._triton import dtype_to_string, has_triton
+
+        if has_triton():
+            import triton
+
         if isinstance(s, SymTypes):
             return pexpr(sympy.expand(repr(s)))
         elif isinstance(s, sympy.Expr):
@@ -1317,6 +1326,8 @@ class WrapperCodeGen(CodeGen):
             return _get_qualified_name(s)
         elif isinstance(s, (ir.Buffer, ReinterpretView)):
             return s.codegen_reference()
+        elif has_triton() and isinstance(s, triton.language.dtype):  # type: ignore[possibly-undefined]
+            return dtype_to_string(s)
         else:
             return repr(s)
 
```
```diff
@@ -380,11 +380,19 @@ class CodeGen:
     def _gen_python_code(
         self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False,
     ) -> PythonCode:
+        from torch.utils._triton import has_triton
+
         free_vars: List[str] = []
         body: List[str] = []
         globals_: Dict[str, Any] = {}
         wrapped_fns: Dict[str, None] = {}
 
+        if has_triton():
+            import triton
+            globals_[triton.__name__] = triton
+            from torch.utils._triton import patch_triton_dtype_repr
+            patch_triton_dtype_repr()
+
         # Wrap string in list to pass by reference
         maybe_return_annotation : List[str] = ['']
 
```
```diff
@@ -245,6 +245,11 @@ class TracerBase:
 
         Can be override to support more trace-specific types.
         """
+        from torch.utils._triton import has_triton
+
+        if has_triton():
+            import triton
+
         if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
             return a.__fx_create_arg__(self)
         # aggregates
@@ -280,6 +285,8 @@ class TracerBase:
 
         elif isinstance(a, torch._ops.OpOverload):
             return a
+        elif has_triton() and isinstance(a, triton.language.dtype):
+            return a
 
         if isinstance(a, Proxy):
             # base case: we unwrap the Proxy object
```
```diff
@@ -1,8 +1,6 @@
 import functools
 import hashlib
 
-from torch._dynamo.device_interface import get_interface_for_device
-
 
 @functools.lru_cache(None)
 def has_triton_package() -> bool:
@@ -16,6 +14,8 @@ def has_triton_package() -> bool:
 
 @functools.lru_cache(None)
 def has_triton() -> bool:
+    from torch._dynamo.device_interface import get_interface_for_device
+
     def cuda_extra_check(device_interface):
         return device_interface.Worker.get_device_properties().major >= 7
 
@@ -59,3 +59,23 @@ def triton_hash_with_backend():
     backend = triton_backend()
     key = f"{triton_key()}-{backend.hash()}"
    return hashlib.sha256(key.encode("utf-8")).hexdigest()
+
+
+def dtype_to_string(dtype):
+    if dtype.name.startswith("fp"):
+        suffix = "float" + dtype.name[2:]
+    elif dtype.name.startswith("bf"):
+        suffix = "bfloat" + dtype.name[2:]
+    else:
+        suffix = dtype.name
+    return "triton.language." + suffix
+
+
+def patch_triton_dtype_repr():
+    import triton
+
+    # Hack to get triton dtype repr to produce an evaluatable expression
+    # triton.language.float32 emits triton.language.fp32 which does not
+    # exist
+    # REMOVE when https://github.com/openai/triton/pull/3342 lands
+    triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)
```
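As a quick sanity check of the helpers added in the last hunk, the sketch below round-trips a dtype through the patched repr. It assumes Triton is installed and a torch build that already contains this PR (so `patch_triton_dtype_repr` is importable from `torch.utils._triton`).

```python
# After patch_triton_dtype_repr(), repr() of a Triton dtype is an expression
# that evaluates back to the same dtype object -- which is what fx codegen and
# the inductor wrapper rely on when embedding dtypes in generated source.
import triton
import triton.language as tl

from torch.utils._triton import patch_triton_dtype_repr

patch_triton_dtype_repr()

s = repr(tl.bfloat16)
print(s)                                           # triton.language.bfloat16
print(eval(s, {"triton": triton}) is tl.bfloat16)  # True
```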