mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Support triton.language.dtype
with torch.compile
(#121690)
Putting this PR as an RFC since I have resorted to some horrible hacks in order to make this work. ``` (Pdb) p triton.language.float32 triton.language.fp32 (Pdb) p str(triton.language.float32) 'fp32' (Pdb) p repr(triton.language.float32) 'triton.language.fp32' ``` This means that we need to "rewrite" them for fx graph and inductor execution. This PR allows Mamba2 to work with `torch.compile`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
22bb24986d
commit
79ee6bbde3
@ -245,6 +245,11 @@ class TracerBase:
|
||||
|
||||
Can be override to support more trace-specific types.
|
||||
"""
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
if has_triton():
|
||||
import triton
|
||||
|
||||
if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
|
||||
return a.__fx_create_arg__(self)
|
||||
# aggregates
|
||||
@ -280,6 +285,8 @@ class TracerBase:
|
||||
|
||||
elif isinstance(a, torch._ops.OpOverload):
|
||||
return a
|
||||
elif has_triton() and isinstance(a, triton.language.dtype):
|
||||
return a
|
||||
|
||||
if isinstance(a, Proxy):
|
||||
# base case: we unwrap the Proxy object
|
||||
|
Reference in New Issue
Block a user