Revert "Add Triton CPU as an Inductor backend (#133408)"

This reverts commit 31c0467594c7c41c8e8ff1828bf01fa31fc4454f.

Reverted https://github.com/pytorch/pytorch/pull/133408 on behalf of https://github.com/int3 due to internal tests failing ([comment](https://github.com/pytorch/pytorch/pull/133408#issuecomment-2379692517))
PyTorch MergeBot
2024-09-27 16:54:27 +00:00
parent 17f396b0b4
commit 36428f91e9
34 changed files with 258 additions and 455 deletions

View File

@ -379,7 +379,6 @@ case "$image" in
GCC_VERSION=11
CONDA_CMAKE=yes
HALIDE=yes
TRITON=yes
;;
pytorch-linux-focal-linter)
# TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627.

View File

@ -30,7 +30,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
log = logging.getLogger(__name__)
@ -48,7 +48,7 @@ orig_F_scaled_dot_product_attention = F.scaled_dot_product_attention
class TestFullyShardCompileCompute(FSDPTest):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_disable_compiling_hooks(self):
self.run_subtests(
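Note: the same guard swap repeats throughout the test files below, so it is worth decoding once. Reading the paired imports (an inference from the diff, not stated in it): with the Triton CPU backend available, has_triton() can be true on a host with no GPU, so #133408 re-gated GPU-only tests on HAS_GPU; this revert restores the original has_triton() guard. A minimal sketch of the two conditions, assuming both helpers are importable:

    from torch.testing._internal.inductor_utils import HAS_GPU  # guard removed by this revert
    from torch.utils._triton import has_triton                  # guard restored by this revert

    skip_without_gpu = not HAS_GPU          # strict: requires an actual GPU
    skip_without_triton = not has_triton()  # pre-#133408 gating, restored here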
@ -529,14 +529,14 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
return model_init_fn, input_creation_fn
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_aot_eager(self):
self._test_traceable_fsdp(
*self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
self._test_traceable_fsdp(
*self._create_simple_mlp_factory_fns(),
@ -545,7 +545,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_inductor(self):
self._test_traceable_fsdp(
*self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True
@ -613,7 +613,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
return model_init_fn, input_creation_fn
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_aot_eager(self):
for fullgraph in [True, False]:
self._test_traceable_fsdp(
@ -623,7 +623,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
for fullgraph in [True, False]:
self._test_traceable_fsdp(
@ -633,7 +633,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
for fullgraph in [True]:
with self._reinplace_all_gather_with_optional_checks(
@ -729,7 +729,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
file_check.run(bwd_code)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
@ -806,7 +806,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
return contextlib.nullcontext()
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_transformer_backend_aot_eager(self):
for fullgraph, all_requires_grad in itertools.product(
[True, False], [True, False]
@ -823,7 +823,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout has worse accuracy after decomp, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_aot_eager_decomp_partition(self):
@ -840,7 +840,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_inductor_fullgraph_True(self):
@ -943,7 +943,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
file_check.run(bwd_code)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_inductor_fullgraph_False(self):

View File

@ -18,7 +18,7 @@ from torch.testing._internal.common_fsdp import (
TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
if not dist.is_available():
@ -38,7 +38,7 @@ class TestCompile(FSDPTest):
def world_size(self) -> int:
return torch.cuda.device_count()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_compile(self):
self.run_subtests(

View File

@ -33,7 +33,7 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
from torch.utils.checkpoint import checkpoint
@ -216,21 +216,21 @@ class ReplicateTest(MultiProcessInductorTestCase):
]
self._test_compile(use_gpu=False, no_sync=True)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
@torch._inductor.config.patch(reorder_for_locality=False)
def test_compile_gpu(self):
self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
@torch._inductor.config.patch(reorder_for_locality=False)
def test_compile_gpu_ac(self):
self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_bf16(self):
@ -244,7 +244,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_fp16(self):
@ -261,7 +261,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_backward_only(self):
@ -385,7 +385,7 @@ class DDP_TP_Test(InductorTestCase):
def tearDown(self):
dist.destroy_process_group()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skipIfRocm
def test_ddp_tp(self):
ref_model = Net()

View File

@ -46,7 +46,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
from torch.utils.checkpoint import checkpoint
@ -439,7 +439,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride()
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent(self):
# Partial -> Shard on an unbalanced tensor results in:
# - A contiguous DTensor
@ -515,7 +515,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
out_test = opt_mod(dt)
self.assertEqual(out_ref, out_test)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_dtensor_different_gradient_placement(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@ -647,7 +647,7 @@ def forward(self, primals_1):
return (sin_1, primals_1, wait_tensor)""",
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_dtensor_partial_placement_graph_output(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@ -665,7 +665,7 @@ def forward(self, primals_1):
out_dt = torch.matmul(tmp_dt, y_dt)
out_dt.sum().backward()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(1)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)

View File

@ -43,7 +43,7 @@ from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
if not dist.is_available():
@ -218,7 +218,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
raise ValueError(f"Invalid string: {sharding_strategy_str}")
return sharding_strategy
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_fsdp_compile(self):
self.run_subtests(

View File

@ -37,7 +37,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
def _make_post_grad_fx(f, *inps):
@ -78,7 +78,7 @@ class MicroPipelineTPTest(TestCase):
def tearDown(self):
dist.destroy_process_group()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_find_all_gather_patterns(self):
group = dist.group.WORLD
@ -129,7 +129,7 @@ class MicroPipelineTPTest(TestCase):
torch.ops.aten.view.dtype,
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_find_reduce_scatter_patterns(self):
group = dist.group.WORLD
@ -168,7 +168,7 @@ class MicroPipelineTPTest(TestCase):
self.assertEqual(reduce_scatters[1].reduce_op, "avg")
self.assertEqual(reduce_scatters[1].scatter_dim, 1)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_get_unexposed_collectives(self):
group = dist.group.WORLD
@ -193,7 +193,7 @@ class MicroPipelineTPTest(TestCase):
["all_gather_into_tensor", "reduce_scatter_tensor"],
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("gather_dim", [0, 1, 2])
@fresh_inductor_cache()
@ -231,7 +231,7 @@ class MicroPipelineTPTest(TestCase):
self.assertNotIn("all_gather_into_tensor", code)
@runOnRocmArch(MI300_ARCH)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("gather_dim", [0, 1, 2])
@fresh_inductor_cache()
@ -299,7 +299,7 @@ class MicroPipelineTPTest(TestCase):
self.assertIn("fused_all_gather_scaled_matmul", code)
self.assertNotIn("all_gather_into_tensor", code)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("scatter_dim", [0, 1, 2])
@fresh_inductor_cache()
@ -328,7 +328,7 @@ class MicroPipelineTPTest(TestCase):
self.assertNotIn("reduce_scatter_tensor", code)
@runOnRocmArch(MI300_ARCH)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("scatter_dim", [0, 1, 2])
@fresh_inductor_cache()
@ -381,7 +381,7 @@ class MicroPipelineTPTest(TestCase):
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
self.assertNotIn("reduce_scatter_tensor", code)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@parametrize("shard_dim", [0, 1])
@fresh_inductor_cache()
def test_dtensor_seq_par(self, shard_dim: int):

View File

@ -28,7 +28,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
def load_test_module(name):
@ -218,7 +218,7 @@ class TestWithNCCL(MultiProcessTestCase):
assert output.eq(expect).all()
assert output.completed
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
# https://github.com/pytorch/pytorch/issues/126338
def test_inductor_dtypeview_memory_leak(self):
@ -434,7 +434,7 @@ class TestWithNCCL(MultiProcessTestCase):
torch.ops._c10d_functional.wait_tensor(tensor)
self.assertTrue(wait_called)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@fresh_inductor_cache()
def test_threading(self):
@ -498,7 +498,7 @@ class CompileTest(TestCase):
def tearDown(self):
dist.destroy_process_group()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_reduce_single(self):
def func(arg: torch.Tensor) -> torch.Tensor:
@ -535,7 +535,7 @@ class CompileTest(TestCase):
out = AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_reduce_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
@ -581,7 +581,7 @@ class CompileTest(TestCase):
out = AOTIRunnerUtil.run("cuda", func, (args,))
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_inplace_op_on_view(self):
def func(arg: torch.Tensor) -> torch.Tensor:
@ -608,7 +608,7 @@ class CompileTest(TestCase):
.run(code)
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_reuse_buffer_after_inplace_collective(self):
def func(arg: torch.Tensor) -> torch.Tensor:
@ -643,7 +643,7 @@ class CompileTest(TestCase):
)
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_gather_into_tensor_single(self):
def func(arg: torch.Tensor) -> torch.Tensor:
@ -670,7 +670,7 @@ class CompileTest(TestCase):
out = AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_gather_into_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
@ -704,7 +704,7 @@ class CompileTest(TestCase):
out = AOTIRunnerUtil.run("cuda", func, (args,))
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_reduce_scatter_tensor_single(self):
def func(arg: torch.Tensor) -> torch.Tensor:
@ -730,7 +730,7 @@ class CompileTest(TestCase):
out = AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_reduce_scatter_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
@ -766,7 +766,7 @@ class CompileTest(TestCase):
AOTIRunnerUtil.run("cuda", func, (args,))
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_to_all_single(self):
def _tolist_with_constrain_as_size(tensor):
@ -814,7 +814,7 @@ class CompileTest(TestCase):
.run(code)
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_broadcast(self):
def func(arg: torch.Tensor) -> torch.Tensor:
@ -850,7 +850,7 @@ class CompileTest(TestCase):
out = AOTIRunnerUtil.run("cuda", func, (arg,))
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_ranks_and_tag(self):
def func(arg: torch.Tensor) -> torch.Tensor:

View File

@ -28,7 +28,7 @@ from torch.testing._internal.common_distributed import (
DynamoDistributedMultiProcTestCase,
requires_nccl,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
def get_snode_runtime_for_reorder_compute_test(snode):
@ -92,7 +92,7 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
# works around issue with skipif<2 and workers with unpredictable #s gpu
return 2
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@ -131,7 +131,7 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
correct = func(inputs)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@ -178,7 +178,7 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
correct = func(inputs)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@ -231,7 +231,7 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@ -283,7 +283,7 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@ -340,7 +340,7 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(

View File

@ -46,7 +46,7 @@ from torch.testing._internal.common_distributed import (
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import requires_cuda
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
def reset_rng_state():
@ -325,7 +325,7 @@ def run_hf_bert_ddp(self, model, inputs, backend):
class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(config, "optimize_ddp", True)
@patch.object(torch._inductor.config, "fallback_random", True)
def test_hf_bert_ddp_inductor(self):
@ -528,7 +528,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(optimize_ddp=True, enable_compiler_collectives=True)
@patch.object(torch._inductor.config, "fallback_random", True)
def test_hf_bert_ddp_inductor(self):
@ -536,7 +536,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(optimize_ddp=True, enable_compiler_collectives=True)
@patch.object(torch._inductor.config, "fallback_random", True)
def test_hf_bert_ddp_inductor_static_graph(self):
@ -561,7 +561,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
self._test_hf_bert_aot_eager(static_graph=True)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(optimize_ddp=False, enable_compiler_collectives=True)
def test_ddp_activation_checkpointing(self):
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@ -676,7 +676,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_inductor(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
# Test with basic FSDP wrapping (outer wrap around whole model)
@ -701,7 +701,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_activation_checkpointing(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model, inputs = get_toy_model_for_activation_checkpointing(
@ -722,7 +722,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
)
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
# TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
@patch.object(torch._inductor.config.triton, "cudagraphs", False)
@patch.object(torch._inductor.config, "fallback_random", True)
@ -767,7 +767,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
self.assertTrue(same(correct_results, opt_results))
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
# TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
@patch.object(torch._inductor.config.triton, "cudagraphs", False)
@patch.object(torch._inductor.config, "fallback_random", True)
@ -815,7 +815,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
)
self.assertTrue(same(correct_results, opt_results))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_tensor(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -860,7 +860,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_scalar(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -888,7 +888,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_speculation_divergence(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -921,7 +921,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_graph_break_empty_graph_still_collective(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -955,7 +955,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_dim_mismatch(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -984,7 +984,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_missing_source(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -1006,7 +1006,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_scalar_missing_source(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -1028,7 +1028,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_type_mismatch(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
@ -1062,7 +1062,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", False)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
def test_asymmetric_compilation(self):
@ -1113,7 +1113,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", True)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
@patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
@ -1203,7 +1203,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_inductor(self):
from torch.nn.parallel import DistributedDataParallel as DDP
@ -1299,7 +1299,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
@patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor(self):
assert config.optimize_ddp
"""
@ -1368,18 +1368,18 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor_layout_optimizations_training(self):
self._test_graph_split_inductor_layout_optimizations_impl(
contextlib.nullcontext
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor_layout_optimizations_inference(self):
self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad)
@patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor_transpose(self):
assert config.optimize_ddp
@ -1470,7 +1470,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_empty_graph_inductor(self):
def fn():
get_world_size = torch.distributed.distributed_c10d.get_world_size()

View File

@ -14,7 +14,7 @@ from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
if not dist.is_available():
@ -564,7 +564,7 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
expected = torch.cat(expected)
self.assertEqual(y, expected)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@requires_nccl()
@with_comms()
def test_tracing(self):
@ -574,7 +574,7 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
compiled_allreduce = torch.compile(allreduce, fullgraph=True)
compiled_allreduce(torch.randn(8, device=self.device), self.process_group)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_tracing_with_fakepg(self):
exit_if_lt_x_gpu(self.world_size)
@ -590,7 +590,7 @@ class TestCollectivesWithNCCL(MultiProcessTestCase):
)
allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@requires_nccl()
@with_comms()
def test_tracing_with_dce_code(self):

View File

@ -29,7 +29,7 @@ from torch.testing._internal.common_utils import (
parametrize,
requires_cuda,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._triton import has_triton
def _tolist_with_constrain_as_size(tensor):
@ -58,7 +58,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
# works around issue with skipif<2 and workers with unpredictable #s gpu
return 2
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_broadcast_inductor(self):
"""
@ -90,7 +90,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
compiled_out = compiled_func(*inputs)
self.assertTrue(same(eager_out, compiled_out))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_allreduce_inductor(self):
"""
@ -123,7 +123,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
inductor_out = compiled_matmul_cat_col(*inputs)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_allreduce_inductor_cudagraph_trees(self):
"""
@ -169,7 +169,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
op = torch.ops.c10d_functional.all_reduce.default
self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_eager_allreduce_inductor_wait(self):
def eager_func(a, b, c, d, *, tag, ranks, group_size):
@ -208,7 +208,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
print(f"inductor_out, {inductor_out}")
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_inductor_allreduce_eager_wait(self):
def inductor_func(a, b, c, d, *, tag, ranks, group_size):
@ -243,7 +243,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
def test_allreduce_input_buffer_reuse(self):
@ -261,7 +261,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_permute_tensor(self):
def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
@ -287,7 +287,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
self.assertEqual(out, expected)
self.assertEqual(correct, expected)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
def test_allgather_output_buffer_reuse(self):
@ -311,7 +311,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
correct = model(inp, self.world_size, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_allgather_contiguous_input(self):
class Model(torch.nn.Module):
@ -335,7 +335,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
correct = model(inp, self.world_size, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_allgather_into_tensor_inductor(self):
"""
@ -366,7 +366,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
inductor_out = compiled_matmul_cat_col(*inputs)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_reduce_scatter_tensor_inductor(self):
def example(a, b, *, tag, ranks, group_size):
@ -393,7 +393,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
inductor_out = compiled_fn(*inputs)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_all_to_all_single_inductor(self):
@ -462,7 +462,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
inductor_out = compiled_fn(*inputs, **trs)
self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
def test_all_to_all_single_inductor_split_sizes_none(self):
def example(inp, *, tag, ranks, group_size):
@ -518,7 +518,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
"group_size": world_size,
}
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(debug=True)
def test_inductor_single_op(self):
def func(inp, *, tag, ranks, group_size):
@ -547,7 +547,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(debug=True)
def test_inductor_steal_buffer(self):
"""
@ -582,7 +582,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
def test_inductor_doesnt_mutate_shared(self):
"""
@ -1030,7 +1030,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
self.assertEqual(x.size(), out.size())
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
def test_inductor_all_gather_coalesced(self):
"""
@ -1076,7 +1076,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = func(inputs, **self.get_world_trs())
assert same(out, correct), f"{out} va {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
def test_inductor_reduce_scatter_coalesced(self):
"""

View File

@ -1319,9 +1319,7 @@ class CudaReproTests(TestCase):
self.assertEqual(expect, actual)
# Expect the code iterates in contiguous order, and is not tiled
lines = code[0].split("\n")
start = lines.index("@triton.jit")
kernel_code = "\n".join(lines[start : start + 14])
kernel_code = "\n".join(code[0].split("\n")[60:74])
self.assertExpectedInline(
kernel_code,
"""\

View File

@ -22,9 +22,10 @@ from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_cpp_code
from torch.export import Dim
from torch.utils._triton import has_triton
@unittest.skipIf(not HAS_CUDA, "Inductor+gpu needs triton and CUDA")
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@config.patch(memory_planning=True)
class TestMemoryPlanning(TestCase):
def _generate(self, *, device):

View File

@ -1,34 +0,0 @@
# Owner(s): ["module: inductor"]
from torch._inductor import config
from torch._inductor.test_case import run_tests
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils._triton import has_triton
try:
from . import test_torchinductor
except ImportError:
import test_torchinductor
if has_triton():
import triton
TRITON_HAS_CPU = "cpu" in triton.backends.backends
else:
TRITON_HAS_CPU = False
if HAS_CPU and TRITON_HAS_CPU:
@config.patch(cpu_backend="triton")
class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest):
pass
@config.patch(cpu_backend="triton")
class CpuTritonTests(test_torchinductor.CpuTests):
pass
if __name__ == "__main__":
if HAS_CPU and TRITON_HAS_CPU:
run_tests(needs="filelock")
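Note: the file above (the Triton-CPU test module) is deleted outright by the revert. Its backend probe is self-contained; a condensed sketch of the same check, assuming a triton build that exposes the triton.backends registry:

    try:
        import triton.backends
        TRITON_HAS_CPU = "cpu" in triton.backends.backends  # dict of registered backends
    except ImportError:
        TRITON_HAS_CPU = False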

View File

@ -38,10 +38,10 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
)
from torch.testing._internal.inductor_utils import HAS_GPU
import pytest
from torch.utils._triton import has_triton
SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()
_IS_SM8X = False
@ -981,7 +981,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
torch.backends.cuda.matmul.allow_tf32 = orig
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
@inference_dtypes
def test_conversions(self, device, dtype):
@ -1009,7 +1009,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
for r, c in shapes:
run_test(r, c, device, dtype)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
@inference_dtypes
def test_conversions_all_patterns(self, device, dtype):
r, c = 32, 128

View File

@ -1,7 +1,5 @@
# mypy: allow-untyped-defs
import inspect
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
import torch
@ -297,51 +295,6 @@ class XpuInterface(DeviceInterface):
return torch.xpu.is_bf16_supported()
@dataclass
class CpuDeviceProperties:
multi_processor_count: int
class CpuInterface(DeviceInterface):
class Event(_EventBase):
def __init__(self, enable_timing=True):
self.time = 0.0
def elapsed_time(self, end_event) -> float:
return (end_event.time - self.time) * 1000
def record(self):
self.time = time.perf_counter()
@staticmethod
def is_available() -> bool:
return True
@staticmethod
def get_compute_capability(device: _device_t = None) -> str:
return ""
@staticmethod
def get_raw_stream(device_idx) -> int:
return 0
@staticmethod
def current_device():
return 0
@staticmethod
def synchronize(device: _device_t = None):
pass
class Worker:
@staticmethod
def get_device_properties(device: _device_t = None):
import multiprocessing
cpu_count = multiprocessing.cpu_count()
return CpuDeviceProperties(cpu_count)
device_interfaces: Dict[str, Type[DeviceInterface]] = {}
_device_initialized = False
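Note: the removed CpuInterface timed events with time.perf_counter and reported milliseconds (hence the * 1000). A self-contained usage sketch of that Event API:

    import time

    class Event:  # condensed from the removed CpuInterface.Event above
        def __init__(self, enable_timing=True):
            self.time = 0.0

        def record(self):
            self.time = time.perf_counter()

        def elapsed_time(self, end_event) -> float:
            # perf_counter returns seconds; scale to ms to match CUDA event semantics
            return (end_event.time - self.time) * 1000

    start, end = Event(), Event()
    start.record()
    sum(range(1_000_000))  # stand-in CPU workload
    end.record()
    print(f"{start.elapsed_time(end):.3f} ms")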
@ -380,6 +333,4 @@ def init_device_reg():
for i in range(torch.xpu.device_count()):
register_interface_for_device(f"xpu:{i}", XpuInterface)
register_interface_for_device("cpu", CpuInterface)
_device_initialized = True

View File

@ -574,7 +574,7 @@ class TestBenchmarkRequest(BenchmarkRequest):
return self.value
class GPUDeviceBenchmarkMixin:
class GPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
@ -601,17 +601,7 @@ class GPUDeviceBenchmarkMixin:
return out
class CPUDeviceBenchmarkMixin:
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
return benchmarker.benchmark_cpu(fn)
class TritonBenchmarkRequest(BenchmarkRequest):
class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!
def __init__(
@ -656,13 +646,9 @@ class TritonBenchmarkRequest(BenchmarkRequest):
if "warmup" in inspect.signature(run_method).parameters:
warmup_arg["warmup"] = False
if output_tensor.device.type == "cpu":
stream = 0
else:
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
stream = get_raw_stream(self.output_tensor_meta.device.index)
if torch.version.hip and self.matrix_instr_nonkdim != 0:
return functools.partial(
run_method,
*input_tensors,
@ -670,7 +656,17 @@ class TritonBenchmarkRequest(BenchmarkRequest):
*self.extra_args,
grid=self.grid,
**warmup_arg,
stream=stream,
stream=get_raw_stream(self.output_tensor_meta.device.index),
)
else:
return functools.partial(
run_method,
*input_tensors,
output_tensor,
*self.extra_args,
grid=self.grid,
**warmup_arg,
stream=get_raw_stream(self.output_tensor_meta.device.index),
)
def precompile(self):
@ -681,15 +677,7 @@ class TritonBenchmarkRequest(BenchmarkRequest):
return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
class TritonGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, TritonBenchmarkRequest):
pass
class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest):
pass
class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!
@ -806,7 +794,17 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
class CPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
return benchmarker.benchmark_cpu(fn)
class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put Tensors in here!
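Note: taken together, the hunks in this file swap a mixin composition back to a linear class hierarchy. A skeleton of both shapes, using class names from this diff (one renamed with an Old suffix so both shapes can coexist in a single sketch):

    class BenchmarkRequest: ...

    # Shape removed by the revert: device behavior composed in via mixins.
    class GPUDeviceBenchmarkMixin: ...
    class CPUDeviceBenchmarkMixin: ...
    class TritonBenchmarkRequestOld(BenchmarkRequest): ...  # pre-revert TritonBenchmarkRequest
    class TritonGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, TritonBenchmarkRequestOld): ...
    class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequestOld): ...

    # Shape restored by the revert: device behavior baked into base classes.
    class GPUDeviceBenchmarkRequest(BenchmarkRequest): ...
    class CPUDeviceBenchmarkRequest(BenchmarkRequest): ...
    class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): ...
    class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): ...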

View File

@ -240,11 +240,7 @@ def init_backend_registration():
from .wrapper import PythonWrapperCodegen
if get_scheduling_for_device("cpu") is None:
cpu_backends = {
"cpp": CppScheduling,
"halide": HalideScheduling,
"triton": TritonScheduling,
}
cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
register_backend_for_device(
"cpu",
lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
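Note: with the revert applied, the CPU registry maps only "cpp" and "halide" (matching the config comment change further down). Backend selection stays a one-line config setting; a sketch:

    import torch._inductor.config as inductor_config

    # "cpp" is the default CPU codegen backend; "halide" is experimental.
    # "triton" stops being a valid value once this revert lands.
    inductor_config.cpu_backend = "cpp"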
@ -306,7 +302,6 @@ def get_device_op_overrides(device: str):
assert isinstance(device, str)
if not device_op_overrides_dict.keys():
from . import cpu_device_op_overrides # noqa: F401
from .cuda import device_op_overrides # noqa: F401
from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401

View File

@ -4652,7 +4652,7 @@ class KernelGroup:
def call_kernel(self, wrapper, kernel_name):
_, call_args, arg_types = self.args.cpp_argdefs()
wrapper.generate_kernel_call(
kernel_name, call_args, gpu=False, triton=False, arg_types=arg_types
kernel_name, call_args, gpu=False, arg_types=arg_types
)

View File

@ -107,9 +107,7 @@ class CppTemplateKernel(CppKernel):
def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
wrapper = V.graph.wrapper_code
_, call_args, arg_types = self.args.cpp_argdefs()
wrapper.generate_kernel_call(
name, call_args, triton=False, gpu=False, arg_types=arg_types
)
wrapper.generate_kernel_call(name, call_args, gpu=False, arg_types=arg_types)
def dtype(self, node: ir.Buffer) -> str:
return DTYPE_TO_CPP[node.get_dtype()]

View File

@ -1,26 +0,0 @@
# mypy: allow-untyped-defs
from textwrap import dedent
from .common import DeviceOpOverrides, register_device_op_overrides
class CpuDeviceOpOverrides(DeviceOpOverrides):
def import_get_raw_stream_as(self, name):
return dedent(
"""
def get_raw_stream(_):
return 0
"""
)
def set_device(self, device_idx):
return "pass"
def synchronize(self):
return "pass"
def device_guard(self, device_idx):
return "pass"
register_device_op_overrides("cpu", CpuDeviceOpOverrides())

View File

@ -1639,7 +1639,6 @@ class HalideKernel(SIMDKernel):
name,
call_args,
gpu=False, # grid/stream is handled internally in halide
triton=False,
)
def generate_assert(self, check):

View File

@ -7,18 +7,14 @@ from ctypes import byref, c_int, c_size_t, c_void_p
from typing import Any, Callable, Iterable, List, Optional, Union
import torch
from torch._inductor.autotune_process import (
BenchmarkRequest,
GPUDeviceBenchmarkMixin,
TensorMeta,
)
from torch._inductor.autotune_process import GPUDeviceBenchmarkRequest, TensorMeta
from torch._inductor.codecache import DLLWrapper, ROCmCodeCache
log = logging.getLogger(__name__)
class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
class ROCmBenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!

View File

@ -2689,11 +2689,6 @@ class TritonKernel(SIMDKernel):
if name is None:
code.splice(gen_common_triton_imports())
device_type = V.graph.scheduler.get_current_device_or_throw().type
if device_type == "cpu":
code.splice("triton_helpers.set_driver_to_cpu()")
else:
code.splice("triton_helpers.set_driver_to_gpu()")
if config.benchmark_kernel:
code.splice(self.imports_for_benchmark_kernel())
@ -2931,7 +2926,7 @@ class TritonKernel(SIMDKernel):
call_args,
grid,
current_device.index,
gpu=current_device.type != "cpu",
gpu=True,
triton=True,
arg_types=arg_types,
grid_fn=self._get_grid_fn_str(),
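Note: the splice deleted here meant every generated Triton kernel module began by pinning the driver to the matching backend. An illustrative prologue under the removed code path (reconstructed, not verbatim compiler output; the import path is an assumption):

    from torch._inductor.runtime import triton_helpers

    triton_helpers.set_driver_to_cpu()   # emitted for cpu devices
    # triton_helpers.set_driver_to_gpu() # emitted for every other device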

View File

@ -1658,26 +1658,17 @@ class PythonWrapperCodegen(CodeGen):
gpu: Defines whether the backend is GPU. Otherwise the backend is CPU.
triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True,
and C++ when gpu=False.
triton: Defines whether the GPU backend uses Triton for codegen.
Otherwise it uses the CUDA language for codegen.
Only valid when gpu == True.
"""
if not (triton or gpu):
self.writeline(self.wrap_kernel_call(kernel_name, call_args))
return
if gpu:
device_index, call_args_str = self.prepare_triton_kernel_call(
device_index, call_args
)
call_args_str = ", ".join(call_args_str)
stream_name = self.write_get_raw_stream(device_index, V.graph)
if not triton:
stream_ptr = f"c_void_p({stream_name})"
self.writeline(
f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})"
)
return
if triton:
self.write_triton_header_once()
if grid is None:
grid_str = grid_fn
@ -1688,7 +1679,9 @@ class PythonWrapperCodegen(CodeGen):
grid_str = f"{grid_fn}({grid_str})"
# add debug printer code for triton kernel calls at (jit) inductor level
debug_printer_manager = V.graph.wrapper_code.debug_printer
debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None)
debug_printer_manager.set_printer_args(
call_args, kernel_name, arg_types, None
)
with debug_printer_manager:
self.writeline(
f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
@ -1734,7 +1727,9 @@ class PythonWrapperCodegen(CodeGen):
else:
arg_str = tensor_args[arg]
else:
arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i)
arg_str = self.generate_example_arg_value(
arg, arg_type, raw_arg, i
)
all_args.append(arg_str if key is None else f"{key}={arg_str}")
if grid is None:
@ -1754,6 +1749,13 @@ class PythonWrapperCodegen(CodeGen):
f"del {', '.join(arg for arg in tensor_args.values())}\n",
)
self.kernel_autotune_names.add(kernel_name)
else:
stream_ptr = f"c_void_p({stream_name})"
self.writeline(
f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})"
)
else:
self.writeline(self.wrap_kernel_call(kernel_name, call_args))
def writeline(self, line):
self.lines.append(line)
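Note: reassembled for orientation, the restored generate_kernel_call dispatch reduces to three paths. A toy mirror of the branching (emitted calls shown as strings, names taken from the hunk):

    def dispatch_kernel_call(gpu: bool, triton: bool) -> str:
        if gpu:
            if triton:
                return "kernel.run(call_args, grid=grid_str, stream=stream_name)"
            return "kernel.kernel(call_args, c_void_p(stream_name))"
        return "wrap_kernel_call(kernel_name, call_args)  # C++ path"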

View File

@ -1151,7 +1151,7 @@ class rocm:
use_preselected_instances: bool = False
# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
# Backend to use for CPU codegen either "cpp" or "halide" (experimental)
cpu_backend = "cpp"
# Backend to use for CUDA codegen either "triton" or "halide" (experimental)

View File

@ -77,7 +77,7 @@ class Benchmarker:
def benchmark(
self: Self,
fn: Callable[..., Any],
fn_args: Tuple[Any, ...],
fn_args: Tuple[Any],
fn_kwargs: Dict[str, Any],
**kwargs: Any,
) -> float:
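A typing note on the hunk above: Tuple[Any] is a fixed one-element tuple type, while Tuple[Any, ...] is variable-length; the revert restores the former annotation. The difference in two lines:

    from typing import Any, Tuple

    one: Tuple[Any] = ("only",)            # exactly one element
    many: Tuple[Any, ...] = (1, "a", 3.0)  # any length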

View File

@ -1,7 +1,5 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
import triton
import triton.language as tl
@ -33,32 +31,6 @@ except ImportError:
raise NotImplementedError
def set_driver_to_cpu():
driver = triton.runtime.driver
if backend := triton.backends.backends.get("cpu", None):
if isinstance(driver.active, backend.driver):
# Don't re-initialize backend if it is already active
return
driver.set_active(backend.driver())
return
# This can be a hard error once triton-cpu is merged into fbcode
warnings.warn(
"Could not find an active CPU backend. Generated kernels will not be executable!"
)
def set_driver_to_gpu():
driver = triton.runtime.driver
for name, backend in triton.backends.backends.items():
if backend.driver.is_active() and name != "cpu":
if isinstance(driver.active, backend.driver):
# Don't re-initialize backend if it is already active
return
driver.set_active(backend.driver())
return
raise RuntimeError("Could not find an active GPU backend")
@triton.jit
def promote_to_tensor(x):
# Addition promotes to tensor for us

View File

@ -57,8 +57,6 @@ if triton is not None:
from triton.runtime.autotuner import OutOfResources
from triton.runtime.jit import KernelInterface
from . import triton_helpers
try:
from triton.compiler.compiler import ASTSource
except ImportError:
@ -69,14 +67,11 @@ if triton is not None:
except ImportError:
GPUTarget = None
else:
from types import ModuleType
Config = object
KernelInterface = object
OutOfResources = object
ASTSource = None
GPUTarget = None
triton_helpers = ModuleType("triton_helpers")
try:
autograd_profiler = torch.autograd.profiler
@ -379,11 +374,6 @@ class CachingAutotuner(KernelInterface):
compile_meta["device_type"] = self.device_props.type
compile_meta["cc"] = self.device_props.cc
if self.device_props.type == "cpu":
triton_helpers.set_driver_to_cpu()
else:
triton_helpers.set_driver_to_gpu()
if ASTSource:
compile_args = (
ASTSource(
@ -686,9 +676,6 @@ class CachingAutotuner(KernelInterface):
return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
if self.device_props.type == "cpu":
return benchmarker.benchmark_cpu(kernel_call)
return benchmarker.benchmark_gpu(kernel_call, rep=40)
def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
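Note: one removed fallback in this file is worth isolating. When triton is not importable, #133408 substituted an empty module object so that module-level references to triton_helpers still resolved. The trick on its own:

    from types import ModuleType

    triton_helpers = ModuleType("triton_helpers")  # empty stand-in module
    # Name lookups on the module succeed, but attribute access (e.g. the
    # removed set_driver_to_cpu) raises AttributeError if anything touches it.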

View File

@ -3473,11 +3473,12 @@ class Scheduler:
self.current_device.type
):
V.graph.wrapper_code.codegen_device_guard_exit()
self.current_device = device
if device_need_guard(device.type):
assert device.index is not None, "device should have an index"
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
self.current_device = device
self.buffer_names_to_free.update(node.last_usage)
if node.is_template():

View File

@ -15,7 +15,7 @@ import time
from collections import namedtuple
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import sympy
@ -27,12 +27,7 @@ from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, identity, preserve_rng_state
from . import config, ir
from .autotune_process import (
TensorMeta,
TritonBenchmarkRequest,
TritonCPUBenchmarkRequest,
TritonGPUBenchmarkRequest,
)
from .autotune_process import TensorMeta, TritonBenchmarkRequest
from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import IndentedBuffer, KernelTemplate
from .codegen.triton import (
@ -582,7 +577,6 @@ class TritonTemplateKernel(TritonKernel):
grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
arg_types=arg_types,
triton_meta=self.triton_meta,
gpu="cpu" not in V.graph.device_types,
)
@ -752,12 +746,7 @@ class TritonTemplate(KernelTemplate):
),
kwargs,
)
bmreq_cls: Type[TritonBenchmarkRequest]
if layout.device.type == "cpu":
bmreq_cls = TritonCPUBenchmarkRequest
else:
bmreq_cls = TritonGPUBenchmarkRequest
bmreq = bmreq_cls(
bmreq = TritonBenchmarkRequest(
module_path=mod.__file__,
module_cache_key=mod.key,
kernel_name=kernel_name,

View File

@ -1099,13 +1099,7 @@ def use_triton_template(layout, *, enable_int32=False, enable_float8=False):
if enable_float8:
layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2])
return (
(
(
layout.device.type == "cuda"
and _use_template_for_cuda(layout, layout_dtypes)
)
or (layout.device.type == "cpu" and layout.dtype in layout_dtypes)
)
_use_template_for_cuda(layout, layout_dtypes)
and _use_autotune_backend("TRITON")
and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
)

View File

@ -17,27 +17,15 @@ def has_triton_package() -> bool:
@functools.lru_cache(None)
def has_triton() -> bool:
if not has_triton_package():
return False
from torch._dynamo.device_interface import get_interface_for_device
def cuda_extra_check(device_interface):
return device_interface.Worker.get_device_properties().major >= 7
def cpu_extra_check(device_interface):
import triton.backends
return "cpu" in triton.backends.backends
def _return_true(device_interface):
return True
triton_supported_devices = {
"cuda": cuda_extra_check,
"xpu": _return_true,
"cpu": cpu_extra_check,
}
triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true}
def is_device_compatible_with_triton():
for device, extra_check in triton_supported_devices.items():
@ -46,7 +34,7 @@ def has_triton() -> bool:
return True
return False
return is_device_compatible_with_triton()
return is_device_compatible_with_triton() and has_triton_package()
@functools.lru_cache(None)
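Note: condensed, the restored has_triton() recognizes only cuda (compute capability 7.0+, i.e. properties().major >= 7) and xpu, and checks package availability last. A sketch assembled from the hunk above (the device-probe loop body is elided in the diff and left elided here):

    import functools

    from torch._dynamo.device_interface import get_interface_for_device
    from torch.utils._triton import has_triton_package

    @functools.lru_cache(None)
    def has_triton() -> bool:
        def cuda_extra_check(device_interface):
            return device_interface.Worker.get_device_properties().major >= 7

        triton_supported_devices = {"cuda": cuda_extra_check, "xpu": lambda di: True}

        def is_device_compatible_with_triton():
            for device, extra_check in triton_supported_devices.items():
                ...  # probe via get_interface_for_device(device); elided in the diff
            return False

        return is_device_compatible_with_triton() and has_triton_package()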