Support triton.language.dtype with torch.compile (#121690)

I'm posting this PR as an RFC since I have resorted to some horrible hacks in order to make this work. The core problem is how Triton dtypes stringify:
```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```
Neither spelling round-trips to the canonical `triton.language.float32` attribute, which means we need to "rewrite" these dtypes for the FX graph and for Inductor execution.
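For illustration, here is a minimal sketch of the kind of rewrite that is needed. This is not the helper the PR actually adds; the function name `dtype_to_source` and the mapping table are assumptions made for the example. The goal is to turn a `tl.dtype` into a string that is a valid Python expression when emitted into generated code.
```
# Minimal sketch, NOT the PR's actual helper: map a triton.language.dtype to a
# string that evaluates back to the same dtype when emitted in generated code.
import triton.language as tl

# Assumed short-name -> attribute-name table (illustrative, not exhaustive).
_SHORT_TO_ATTR = {
    "fp16": "float16",
    "bf16": "bfloat16",
    "fp32": "float32",
    "fp64": "float64",
}

def dtype_to_source(dtype: tl.dtype) -> str:
    short = str(dtype)                       # e.g. 'fp32'
    attr = _SHORT_TO_ATTR.get(short, short)  # fall back to the short name
    return f"triton.language.{attr}"         # e.g. 'triton.language.float32'

assert dtype_to_source(tl.float32) == "triton.language.float32"
```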

This PR allows Mamba2 to work with `torch.compile`.
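As a rough usage sketch of the pattern this enables (the kernel and wrapper below are illustrative, not taken from Mamba2 or from this PR's tests): a user-defined Triton kernel can now receive a `tl.dtype` argument from inside a `torch.compile`'d function.
```
# Hedged sketch; kernel and function names are made up for illustration and a
# CUDA device is assumed to be available.
import torch
import triton
import triton.language as tl

@triton.jit
def cast_copy_kernel(x_ptr, out_ptr, n_elements, DTYPE: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x.to(DTYPE), mask=mask)

def cast_copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x, dtype=torch.float32)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # tl.float32 is passed straight through as a kernel argument; previously
    # tracing this call failed because the dtype could not be put in the graph.
    cast_copy_kernel[grid](x, out, n, DTYPE=tl.float32, BLOCK=1024)
    return out

compiled = torch.compile(cast_copy, fullgraph=True)
y = compiled(torch.randn(4096, device="cuda", dtype=torch.float16))
```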

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
Author: Oguz Ulgen
Date: 2024-03-12 13:02:22 -07:00
Committed by: PyTorch MergeBot
Parent: 22bb24986d
Commit: 79ee6bbde3
7 changed files with 105 additions and 5 deletions

```
@@ -245,6 +245,11 @@ class TracerBase:
 
         Can be override to support more trace-specific types.
         """
+        from torch.utils._triton import has_triton
+
+        if has_triton():
+            import triton
+
         if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
             return a.__fx_create_arg__(self)
         # aggregates
@@ -280,6 +285,8 @@ class TracerBase:
         elif isinstance(a, torch._ops.OpOverload):
             return a
+        elif has_triton() and isinstance(a, triton.language.dtype):
+            return a
 
         if isinstance(a, Proxy):
             # base case: we unwrap the Proxy object
```