mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack (#158459)
Update triton commit hash to `11ec6354315768a85da41032535e3b7b99c5f706`, which is the new release/3.4.x branch in triton-lang/triton. Also, update HAS_WARP_SPEC handling: In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler. This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158459 Approved by: https://github.com/atalman
This commit is contained in:
committed by
PyTorch MergeBot
parent
d5af0eca8d
commit
7892f5a007
@ -1 +1 @@
|
||||
ae848267bebc65c6181e8cc5e64a6357d2679260
|
||||
11ec6354315768a85da41032535e3b7b99c5f706
|
||||
|
@ -2,7 +2,6 @@
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
@ -13,9 +12,8 @@ from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaK
|
||||
from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton
|
||||
from torch._inductor.runtime.triton_helpers import libdevice
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm
|
||||
from torch.testing._internal.common_utils import skipIfRocm
|
||||
from torch.testing._internal.triton_utils import requires_cuda
|
||||
from torch.torch_version import TorchVersion
|
||||
|
||||
|
||||
@requires_cuda
|
||||
@ -141,37 +139,6 @@ class TestStaticCudaLauncher(TestCase):
|
||||
launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
# TODO: floats don't work properly, triton seems to think they're all tl.float32
|
||||
# despite type annotations.
|
||||
# There's also not really a good way for me to make a float16 in python...
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(IS_FBCODE, "Not working in fbcode")
|
||||
def test_floats(self):
|
||||
@triton.jit
|
||||
def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64):
|
||||
x = tl.load(arg0)
|
||||
y = arg1 + arg2 + arg3
|
||||
tl.store(arg0, x + y)
|
||||
|
||||
arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
|
||||
|
||||
args = (arg0, 1.0, 1.0, 1.0)
|
||||
|
||||
compiled_kernel = floats[1,](*args)
|
||||
launcher = self._make_launcher(compiled_kernel)
|
||||
if TorchVersion(triton.__version__) >= TorchVersion("3.4.0"):
|
||||
self.assertEqual(launcher.arg_tys, "Offd")
|
||||
else:
|
||||
self.assertEqual(launcher.arg_tys, "Offf")
|
||||
# TODO this line fails on Triton 3.4.0 (https://github.com/triton-lang/triton/issues/6176)
|
||||
# Add the check back when this is fixed in Triton
|
||||
# self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda"))
|
||||
new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
|
||||
device_interface = get_interface_for_device("cuda")
|
||||
stream = device_interface.get_raw_stream(device_interface.current_device())
|
||||
launcher.run(1, 1, 1, stream, new_arg0, 1.0, 1.0, 1.0)
|
||||
self.assertEqual(new_arg0, arg0)
|
||||
|
||||
@skipIfRocm
|
||||
def test_basic_1arg(self):
|
||||
@triton.jit
|
||||
|
@ -69,7 +69,18 @@ if triton is not None:
|
||||
def _log2(x: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
HAS_WARP_SPEC = hasattr(tl, "async_task")
|
||||
def _triton_config_has(param_name: str) -> bool:
|
||||
if not hasattr(triton, "Config"):
|
||||
return False
|
||||
if not hasattr(triton.Config, "__init__"):
|
||||
return False
|
||||
return param_name in inspect.signature(triton.Config.__init__).parameters
|
||||
|
||||
HAS_WARP_SPEC = (
|
||||
hasattr(tl, "async_task")
|
||||
and _triton_config_has("num_consumer_groups")
|
||||
and _triton_config_has("num_buffers_warp_spec")
|
||||
)
|
||||
|
||||
try:
|
||||
from triton import knobs
|
||||
|
Reference in New Issue
Block a user