Files
pytorch/test/inductor/test_cutlass_backend.py
2024-12-08 18:51:36 +00:00

942 lines
37 KiB
Python

# Owner(s): ["module: inductor"]
import logging
import math
import os
import unittest
from typing import Callable, List, Optional
from unittest import mock
try:
from test_aot_inductor_utils import AOTIRunnerUtil
except ImportError:
from .test_aot_inductor_utils import AOTIRunnerUtil
import torch
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.codegen.cuda.cutlass_utils import get_max_alignment
from torch._inductor.ir import ChoiceCaller, FixedLayout
from torch._inductor.select_algorithm import NoValidChoicesError
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM75OrLater, SM80OrLater, SM90OrLater
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
# Module-wide numeric setting applied at import time.
torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    # Note: this check uses HAS_CUDA as imported, i.e. before it is re-bound
    # below to exclude HIP builds.
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

# CUTLASS sources are looked up in third_party/cutlass, two directory levels
# above the directory that contains this file.
_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")

log = logging.getLogger(__name__)

# Re-bind the imported capability flags so that each of them is falsy whenever
# torch.version.hip is set; the skipIf decorators below read these re-bound names.
HAS_CUDA = HAS_CUDA and not torch.version.hip
SM75OrLater = SM75OrLater and not torch.version.hip
SM80OrLater = SM80OrLater and not torch.version.hip
SM90OrLater = SM90OrLater and not torch.version.hip
# Truthy only when the device capability is exactly (8, 0); the capability
# query is short-circuited away when SM80OrLater is falsy.
SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)
def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.

    Every PATH entry containing ``/opt/cache/bin`` is dropped; the remaining
    entries keep their original order. Entries are split and re-joined with
    ``os.pathsep`` instead of a hard-coded ``":"``, which is identical on
    POSIX and also keeps the value intact on platforms that use a different
    separator.

    Returns:
        The filtered PATH value, or an empty string when PATH is unset or
        every entry was filtered out.
    """
    entries = os.environ.get("PATH", "").split(os.pathsep)
    return os.pathsep.join(
        entry for entry in entries if "/opt/cache/bin" not in entry
    )
@instantiate_parametrized_tests
class TestCutlassBackend(TestCase):
def setUp(self):
    """
    Run the parent ``setUp`` with automatic inductor cache refresh disabled,
    then seed the torch RNG.

    The new inductor cache refresh mechanism introduced with
    https://github.com/pytorch/pytorch/pull/122661 interacts badly with
    persistent subprocesses during autotuning, so
    ``INDUCTOR_TEST_DISABLE_FRESH_CACHE`` is set to "1" only for the duration
    of the parent call.
    """
    env_key = "INDUCTOR_TEST_DISABLE_FRESH_CACHE"
    # None means "was not set at all"; that case must be restored by removing
    # the variable rather than by writing an empty string back.
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = "1"
    try:
        super().setUp()
    finally:
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value
    torch.random.manual_seed(1234)
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_threshold(self):
    """
    Make sure Cutlass GEMM threshold works as intended.

    The algorithm selector used by the mm lowering is replaced with a spy
    that records its arguments and raises ``NoValidChoicesError`` (so the
    lowering falls back to ATen). The recorded choice list must contain only
    ``ChoiceCaller`` instances and no ``CUDATemplateCaller``.

    The previous version indexed the mock object itself
    (``mocked_select_algorithm[0][1]``), which produces another MagicMock
    that iterates as empty, so both ``all(...)`` checks passed vacuously.
    """
    if torch.version.hip:
        return

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    def mm(a, b):
        return a @ b

    def select_no_algorithm(*args, **kwargs):
        raise NoValidChoicesError

    # NOTE(review): M * N * K is exactly 100000 here, i.e. equal to the
    # configured cutlass_backend_min_gemm_size. Confirm that the threshold
    # comparison treats an equal size as "too small".
    a = torch.randn(100, 10).cuda().half()
    b = torch.randn(10, 100).cuda().half()

    with config.patch(
        {
            "max_autotune": True,
            "autotune_in_subproc": True,
            "max_autotune_gemm_backends": "CUTLASS,ATen",
            # The spy raises NoValidChoicesError, so the ATen fallback is
            # what produces the compiled result.
            "autotune_fallback_to_aten": True,
            "compile_threads": 4,
            "cuda.cutlass_backend_min_gemm_size": 100000,
            "cuda.cutlass_dir": _CUTLASS_DIR,
            "cuda.cutlass_max_profiling_configs": 2,
        }
    ):
        # Patch the name the mm lowering actually calls, as the allowlist /
        # denylist tests in this class do.
        with mock.patch(
            "torch._inductor.kernel.mm.autotune_select_algorithm",
            wraps=select_no_algorithm,
        ) as mocked_select_algorithm:
            Y_compiled = torch.compile(mm, dynamic=False)(a, b)
            Y = mm(a, b)

            mocked_select_algorithm.assert_called()
            positional_args, _ = mocked_select_algorithm.call_args
            passed_choice_callers: List[ChoiceCaller] = positional_args[1]
            assert all(
                isinstance(cc, ChoiceCaller) for cc in passed_choice_callers
            ), "Argument 1 to autotune_select_algorithm should be a list of ChoiceCaller instances"
            # We expect that no Cutlass Kernels are considered, due to the threshold
            assert all(
                not isinstance(cc, CUDATemplateCaller)
                for cc in passed_choice_callers
            ), "Cutlass Kernels should have been filtered, GEMM size is too small"
        torch.testing.assert_close(Y_compiled, Y)
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_precompile(self):
    """
    Compile a small fp16 matmul with subprocess autotuning across the
    CUTLASS, Triton and ATen backends and compare it against eager.
    """
    if torch.version.hip:
        return

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    def mm(a, b):
        return a @ b

    lhs = torch.randn(100, 10).cuda().half()
    rhs = torch.randn(10, 100).cuda().half()

    autotune_settings = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": "CUTLASS,Triton,ATen",
        "compile_threads": 4,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,
    }
    with config.patch(autotune_settings):
        compiled_mm = torch.compile(mm, dynamic=False)
        Y_compiled = compiled_mm(lhs, rhs)
        Y = mm(lhs, rhs)
        torch.testing.assert_close(Y_compiled, Y)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_aoti_rerun_with_different_shapes(self):
    """
    Build one AOTInductor artifact with symbolic M/N/K dimensions and run it
    on two differently shaped input pairs, comparing each result with eager.
    """
    max_autotune_gemm_backends = "CUTLASS"

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    class MyModel(torch.nn.Module):
        def forward(self, a, b):
            return a @ b

    model = MyModel()
    # Two input pairs with different M, K and N.
    a = torch.randn(128, 16).cuda().half()
    b = torch.randn(16, 512).cuda().half()
    x = torch.randn(256, 32).cuda().half()
    y = torch.randn(32, 256).cuda().half()
    input_pairs = [(a, b), (x, y)]

    settings = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 3,
    }
    with config.patch(settings):
        from torch.export import Dim

        M = Dim("M", min=1, max=1024)
        N = Dim("N", min=1, max=1024)
        K = Dim("K", min=1, max=1024)
        dynamic_shapes = {
            "a": {0: M, 1: K},
            "b": {0: K, 1: N},
        }

        actual = AOTIRunnerUtil.run_multiple(
            "cuda",
            model,
            input_pairs,
            dynamic_shapes=dynamic_shapes,
        )
        expected = [model(a, b), model(x, y)]
        for index in (0, 1):
            torch.testing.assert_close(actual[index], expected[index])
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False, True))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_diff_matmul_share_same_kernel(self, dynamic):
    """
    Compile a module with two matmuls that share the left operand and check,
    via FileCheck on the generated code, that the string
    "cuda_fused_0.cuda_fused_0" occurs exactly twice.
    """
    max_autotune_gemm_backends = "CUTLASS"

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    class MyModel(torch.nn.Module):
        def forward(self, a, b, c):
            return a @ b, a @ c

    model = MyModel()
    a = torch.randn(128, 16).cuda().half()
    b = torch.randn(16, 128).cuda().half()
    c = torch.randn(16, 512).cuda().half()

    settings = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 1,
    }
    with config.patch(settings):
        from torch._inductor.utils import run_and_get_code

        compiled = torch.compile(model, dynamic=dynamic)
        expected = model(a, b, c)
        actual, codes = run_and_get_code(compiled, a, b, c)
        torch.testing.assert_close(actual, expected)

        checker = FileCheck().check_count("cuda_fused_0.cuda_fused_0", 2)
        checker.run(codes[0])
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@parametrize("dynamic", (False, True))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
@parametrize("use_aoti", (False, True))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_regular_mm(
    self, dynamic: bool, max_autotune_gemm_backends: str, use_aoti: bool
):
    """
    Compare a plain fp16 matmul module against eager, compiled either with
    torch.compile or with AOTInductor, autotuning in-process.

    ``dynamic`` is only forwarded on the torch.compile path.
    """
    cutlass_only = max_autotune_gemm_backends == "CUTLASS"
    if cutlass_only and torch.version.hip:
        return

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    class MyModel(torch.nn.Module):
        def forward(self, a, b):
            return a @ b

    model = MyModel()
    a = torch.randn(128, 16).cuda().half()
    b = torch.randn(16, 128).cuda().half()

    settings = {
        "max_autotune": True,
        "autotune_in_subproc": False,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,
    }
    with config.patch(settings):
        Y = model(a, b)
        Y_compiled = (
            AOTIRunnerUtil.run("cuda", model, (a, b))
            if use_aoti
            else torch.compile(model, dynamic=dynamic)(a, b)
        )
        torch.testing.assert_close(Y_compiled, Y)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_regular_mm_streamk(
    self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
):
    """
    Run fp16 matmuls of several sizes with the CUTLASS op allowlist
    restricted to "stream_k" and compare each against eager.
    """
    if torch.version.hip and max_autotune_gemm_backends == "CUTLASS":
        return

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    def mm(a, b):
        return a @ b

    # These two tensors are replaced inside the loop below; the draws are
    # kept so the random stream seen by the loop is unchanged.
    a = torch.randn(128, 16).cuda().half()
    b = torch.randn(16, 128).cuda().half()

    problem_sizes = (
        (128, 16, 128),
        (1024, 256, 1024),
        (16384, 1024, 16384),
        (16384, 1408, 16384),
    )
    settings = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,
        "cuda.cutlass_op_allowlist_regex": "stream_k",  # only stream-k GEMM Kernels
    }
    with config.patch(settings):
        for M, K, N in problem_sizes:
            a = torch.randn(M, K).cuda().half()
            b = torch.randn(K, N).cuda().half()
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            # Tolerances are loosened because of the size of these matmuls:
            # many small differences in the additions accumulate.
            torch.testing.assert_close(Y_compiled, Y, atol=0.01, rtol=0.01)
def _test_max_autotune_cutlass_backend_epilogue_fusion(
    self,
    dynamic: bool = False,
    max_autotune_gemm_backends: str = "CUTLASS",
    mixed_precision=False,
    fp16=True,
    expected_fuse_count=0,
    mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    batch_size: Optional[int] = None,
):
    """
    Shared driver for the epilogue-fusion tests.

    Compiles ``mm`` with max-autotune, checks the value of
    ``counters["inductor"]["cuda_epilogue_fusion_counter"]`` after
    compilation and compares the compiled result against eager.

    Args:
        dynamic: Forwarded to ``torch.compile``.
        max_autotune_gemm_backends: Value for the inductor config key of the
            same name.
        mixed_precision: Value assigned to
            ``torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction``
            while the test body runs.
        fp16: Convert both inputs to half precision when true.
        expected_fuse_count: Required value of the fusion counter.
        mm: Callable under test, taking the two input tensors. Required.
        batch_size: When given, inputs get a leading batch dimension of this
            size (3-D inputs); otherwise they are 2-D.

    Raises:
        ValueError: If ``mm`` is not provided.
        AssertionError: If the counter or the numerical comparison is off.
    """
    if mm is None:
        # torch.compile(None) returns a decorator rather than failing, so the
        # missing argument would otherwise surface as a confusing error later.
        raise ValueError("mm must be provided")

    matmul_backend = torch.backends.cuda.matmul
    # This flag is process-global; remember it so that a mixed_precision=True
    # run does not leak into tests that execute afterwards.
    previous_flag = matmul_backend.allow_fp16_reduced_precision_reduction
    matmul_backend.allow_fp16_reduced_precision_reduction = mixed_precision
    try:
        # Note: The ops that are available
        # also depend on the alignment of the shapes
        # so if these shapes don't all align to at least 8 elements
        # it can happen that no Cutlass 3.x op is available
        # that allows fusions
        if batch_size is None:
            a = torch.randn(256, 32).cuda()
            b = torch.randn(32, 256).cuda()
        else:
            a = torch.randn(batch_size, 256, 32).cuda()
            b = torch.randn(batch_size, 32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.version": "12.2",  # required to enable the Kernels we need
            }
        ):
            # Reset so the check below only sees fusions from this compile.
            counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
            assert (
                actual_count == expected_fuse_count
            ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)
    finally:
        matmul_backend.allow_fp16_reduced_precision_reduction = previous_flag
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
    """fp16 matmul followed by a scalar multiply; expects a fusion count of 0."""

    def mm(a, b):
        product = a @ b
        return product * 3.0

    # The pointwise ops appear to get pre-fused into a single Pointwise node.
    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=False,
        expected_fuse_count=0,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
    """
    Same scalar-multiply epilogue as the fp16 variant, run with
    mixed_precision=True; expects a fusion count of 0.
    """

    def mm(a, b):
        product = a @ b
        return product * 3.0

    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=True,
        expected_fuse_count=0,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
    """fp16 matmul followed by multiply-then-subtract; expects a fusion count of 0."""

    def mm(a, b):
        scaled = (a @ b) * 3.3
        return scaled - 1.234

    # The pointwise ops appear to get pre-fused into a single Pointwise node.
    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=False,
        expected_fuse_count=0,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self):
    """
    Multiply-then-subtract epilogue run with mixed_precision=True; expects a
    fusion count of 0.
    """

    def mm(a, b):
        scaled = (a @ b) * 3.3
        return scaled - 1.234

    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=True,
        expected_fuse_count=0,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
    """fp16 matmul, affine pointwise step, then ReLU; expects a fusion count of 0."""

    def mm(a, b):
        pre_activation = (a @ b) * 3.3 - 1.234
        return torch.nn.functional.relu(pre_activation)

    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=False,
        expected_fuse_count=0,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self):
    """
    Affine-then-ReLU epilogue run with mixed_precision=True; expects a fusion
    count of 0.
    """

    def mm(a, b):
        pre_activation = (a @ b) * 3.3 - 1.234
        return torch.nn.functional.relu(pre_activation)

    # The pointwise ops appear to get pre-fused into a single Pointwise node.
    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=True,
        expected_fuse_count=0,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self):
    """
    ReLU followed by clamp(max=6.0) on the matmul output, run with
    mixed_precision=True; expects a fusion count of 0.
    """

    def mm(a, b):
        activated = torch.nn.functional.relu(a @ b)
        return torch.clamp(activated, max=6.0)

    # The pointwise ops appear to get pre-fused into a single Pointwise node.
    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=True,
        expected_fuse_count=0,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self):
    """
    The matmul output is cast to float32 before the multiply, so the output
    dtype differs from the matmul dtype; expects a fusion count of 0.
    """

    def mm(a, b):
        # Must not be fused: the result dtype no longer matches the matmul dtype.
        widened = (a @ b).to(torch.float32)
        return widened * 0.00001

    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=True,
        expected_fuse_count=0,
    )
# Same guards as the other tests that call the epilogue-fusion driver: without
# them this test would run (and fail) on pre-sm_90 devices, HIP and fbcode.
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_simple_bmm(self):
    """
    Run a batched fp16 matmul (batch size 10) through the epilogue-fusion
    driver; expects a fusion count of 0.
    """

    def bmm(a, b):
        return torch.bmm(a, b)

    self._test_max_autotune_cutlass_backend_epilogue_fusion(  # test bmm
        mixed_precision=False,
        fp16=True,
        expected_fuse_count=0,
        mm=bmm,
        batch_size=10,
    )
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self):
    """
    Divide the matmul output by a value read from an input shape
    (``b.size(1)``), with mixed_precision=True; expects a fusion count of 0.
    """

    def mm(a, b):
        divisor = b.size(1)
        return (a @ b) / divisor

    self._test_max_autotune_cutlass_backend_epilogue_fusion(
        mm=mm,
        fp16=True,
        mixed_precision=True,
        expected_fuse_count=0,
    )
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_mm_bias(
    self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
):
    """
    Compile ``torch.nn.functional.linear`` with a bias under subprocess
    autotuning and compare against eager with loose tolerances.
    """
    if torch.version.hip and max_autotune_gemm_backends == "CUTLASS":
        return

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    def mm(a, b, bias):
        return torch.nn.functional.linear(a, b, bias)

    # The same tensor is passed as both the input and the weight.
    a = torch.randn(2048, 4096).cuda().half()
    bias = torch.randn(2048).cuda().half()

    settings = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,
    }
    with config.patch(settings):
        Y = mm(a, a, bias)
        compiled_mm = torch.compile(mm, dynamic=dynamic)
        Y_compiled = compiled_mm(a, a, bias)
        torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_addmm(
    self, dynamic, max_autotune_gemm_backends
):
    """
    Compile ``torch.addmm`` for three shapes of the additive input (full,
    1-D, and column) and compare each result against eager.
    """
    if torch.version.hip and max_autotune_gemm_backends == "CUTLASS":
        return

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    def addmm(x, a, b, alpha, beta):
        return torch.addmm(x, a, b, alpha=alpha, beta=beta)

    def compare_results(
        m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
    ) -> None:
        x = torch.randn(x_shape).cuda().half()
        a = torch.randn(m, k).cuda().half()
        b = torch.randn(k, n).cuda().half()
        y_expected = addmm(x, a, b, alpha, beta)

        y = torch.compile(addmm, dynamic=dynamic)(x, a, b, alpha, beta)
        torch.testing.assert_close(y, y_expected)

    settings = {
        "max_autotune": True,
        # Some Cutlass Kernels fail with IMA on this example, which leads to
        # unrecoverable CUDA errors unless tuning happens in a subprocess.
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 4,
        "cuda.cutlass_op_allowlist_regex": "",
        "cuda.cutlass_op_denylist_regex": "pingpong",  # Pingpong Kernels can lead to numerical issues
    }
    x_shapes = (
        [4096, 2048],  # No broadcast
        [2048],  # Broadcast first dim.
        [4096, 1],  # Broadcast last dim.
    )
    with config.patch(settings):
        for x_shape in x_shapes:
            compare_results(4096, 25728, 2048, 2.0, 0.4, x_shape)
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM80OrLater, "need sm_80")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,ATen"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_int_mm(
    self, dynamic: bool, max_autotune_gemm_backends: str
):
    """
    Compile ``torch._int_mm`` on int8 inputs under subprocess autotuning and
    compare the result against eager.
    """
    uses_cutlass = "CUTLASS" in max_autotune_gemm_backends.upper()
    if uses_cutlass and torch.version.hip:
        return

    def mm(a, b):
        return torch._int_mm(a, b)

    # CUTLASS only supports the row-major/column-major layout combination
    # for this operation, hence tensor b is created transposed. (Triton
    # currently does not support that combination, so it is left out of the
    # backend list.) CUTLASS alignment additionally requires the number of
    # columns of both tensors to be divisible by 16.
    a = torch.randint(0, 5, (100, 16), dtype=torch.int8).cuda()
    b = torch.randint(0, 5, (32, 16), dtype=torch.int8).cuda().T

    settings = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,
    }
    with config.patch(settings):
        compiled_mm = torch.compile(mm, dynamic=dynamic)
        Y_compiled = compiled_mm(a, b)
        Y = mm(a, b)
        torch.testing.assert_close(Y_compiled, Y)
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@unittest.skipIf(not SM90OrLater, "need sm_90")
def test_force_cutlass_backend_aoti_dynamic(self):
    """
    Run a matmul module through AOTInductor with CUTLASS as the only GEMM
    backend and the ATen fallback disabled, then compare against eager.
    """
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    class MyModel(torch.nn.Module):
        def forward(self, x, w):
            return x @ w

    settings = {
        "max_autotune": True,
        "autotune_in_subproc": False,
        "max_autotune_gemm_backends": "CUTLASS",
        "autotune_fallback_to_aten": False,
        "cuda.cutlass_dir": _CUTLASS_DIR,
    }
    with config.patch(settings):
        model = MyModel()
        M, N, K = 16, 32, 64
        # NOTE(review): the entries below are plain ints rather than
        # torch.export Dim objects, unlike
        # test_aoti_rerun_with_different_shapes. Confirm this spec is really
        # meant to exercise dynamic shapes.
        dynamic_shapes = {
            "x": {0: M, 1: K},
            "w": {0: K, 1: N},
        }

        x = torch.randn(M, K).cuda().half()
        w = torch.randn(K, N).cuda().half()

        actual = AOTIRunnerUtil.run(
            "cuda",
            model,
            (x, w),
            dynamic_shapes=dynamic_shapes,
        )
        expected = model(x, w)
        torch.testing.assert_close(expected, actual)
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM80, "need sm_80 exactly")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_mixed_mm(
    self, dynamic: bool, max_autotune_gemm_backends: str
):
    """
    Compile an fp16 x int8 matmul, compare it against eager, and check in
    the local autotune cache that at least one "cutlass_gemm" kernel has a
    finite timing.
    """
    if torch.version.hip and max_autotune_gemm_backends == "CUTLASS":
        return

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    def mm(a, b):
        return torch.mm(a, b.to(torch.half))

    # CUTLASS only supports the row-major/column-major layout combination
    # for this operation, hence tensor b is created transposed. CUTLASS
    # alignment additionally requires the number of columns of the first
    # tensor to be divisible by 16.
    m, n, k = 100, 16, 100
    a = torch.randn(m, k).cuda().half()
    b = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda().T

    settings = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": max_autotune_gemm_backends,
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,
        "use_mixed_mm": True,
        "autotune_local_cache": True,
    }
    with config.patch(settings):
        Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
        Y = mm(a, b)
        torch.testing.assert_close(Y_compiled, Y)

    cache = torch._inductor.codecache.LocalCache().lookup("mixed_mm")
    cache_key = (
        f"[('cuda', 'torch.float16', {m}, {k}, {k}, 1, 0), "
        f"('cuda', 'torch.int8', {k}, {n}, 1, {k}, 0)]"
    )
    high = cache[cache_key]["high"]
    # Count cached entries that are CUTLASS kernels with a finite timing.
    cutlass_kernels_count = sum(
        1
        for kernel_name, elapsed in high.items()
        if kernel_name.startswith("cutlass_gemm") and not math.isinf(elapsed)
    )
    assert cutlass_kernels_count > 0
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM80, "need sm_80 exactly")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_sparse_semi_structured_mm(
    self, dynamic: bool, max_autotune_gemm_backends: str
):
    """
    Compile a semi-structured-sparse x dense fp16 matmul, compare it against
    the dense eager result, and check in the local autotune cache that at
    least one "cutlass_gemm" kernel has a finite timing.

    ``SparseSemiStructuredTensor._FORCE_CUTLASS`` is set to True only for the
    duration of the test and restored afterwards, so the class-level flag
    does not leak into other tests.
    """
    if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
        return

    def mm(a, b):
        return torch.mm(a, b)

    m, n, k = 32, 8, 64

    # patch.object restores the previous attribute value on exit, even when
    # an assertion below fails.
    with mock.patch.object(SparseSemiStructuredTensor, "_FORCE_CUTLASS", True):
        # Zero two of every four consecutive elements along each row.
        mask = torch.tensor([0, 0, 1, 1]).tile(m, k // 4).cuda().half()
        a = torch.rand(m, k).cuda().half() * mask
        a_sparse = to_sparse_semi_structured(a)
        b = torch.rand(k, n).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
                "autotune_local_cache": True,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a_sparse, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

        cache = torch._inductor.codecache.LocalCache().lookup(
            "sparse_semi_structured_mm"
        )
        high = cache[
            f"[('cuda', 'torch.float16', {m}, {k // 2}, {k // 2}, 1, 0), "
            f"('cuda', 'torch.int16', {m}, {k // 16}, {k // 16}, 1, 0), "
            f"('cuda', 'torch.float16', {k}, {n}, {n}, 1, 0)]"
        ]["high"]
        cutlass_kernels_count = sum(
            1
            for kernel, time in high.items()
            if kernel.startswith("cutlass_gemm") and not math.isinf(time)
        )
        assert cutlass_kernels_count > 0
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_op_denylist(
    self,
):
    """
    Check that ``cuda.cutlass_op_denylist_regex`` removes matching ops.

    The algorithm selector used by the mm lowering is wrapped by a spy that
    records its arguments and raises ``NoValidChoicesError``, so nothing is
    benchmarked. Every recorded ``CUDATemplateCaller`` choice must have an
    ``op_conf_name`` without "pingpong", and at least one such choice must
    exist.
    """

    def my_addmm(x, a, b, alpha, beta):
        # alpha and beta were previously cross-wired (alpha=beta, beta=alpha).
        return torch.addmm(x, a, b, alpha=alpha, beta=beta)

    x = torch.randn((128, 128)).cuda().half()
    a = torch.randn(128, 128).cuda().half()
    b = torch.randn(128, 128).cuda().half()

    def select_no_algorithm(*args, **kwargs):
        raise NoValidChoicesError

    with fresh_inductor_cache():
        with config.patch(
            {
                "max_autotune": True,
                # Selection is intercepted by the spy below before anything
                # is benchmarked, so in-process tuning is sufficient here.
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": "CUTLASS,ATen",
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
                "cuda.cutlass_op_allowlist_regex": "",
                "cuda.cutlass_op_denylist_regex": "pingpong",  # Pingpong Kernels can lead to numerical issues
            }
        ):
            with mock.patch(
                "torch._inductor.kernel.mm.autotune_select_algorithm",
                wraps=select_no_algorithm,
            ) as sa:
                torch.compile(my_addmm, dynamic=False)(x, a, b, 1.0, 2.0)
                args, _ = sa.call_args
                # Exactly four positional arguments are expected; the first
                # two are the op name and the list of candidate choices.
                op_name, choices, _, __ = args
                assert op_name == "addmm"
                cuda_template_count = 0
                for choice in choices:
                    if isinstance(choice, CUDATemplateCaller):
                        choice_info = choice.info_dict()
                        assert (
                            "pingpong" not in choice_info["op_conf_name"]
                        ), "All pingpong Kernels should have been filtered"
                        cuda_template_count += 1
                assert cuda_template_count > 0, "No CUDATemplateCaller choices"
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_op_allowlist(
    self,
):
    """
    Check that ``cuda.cutlass_op_allowlist_regex`` keeps only matching ops:
    every ``CUDATemplateCaller`` choice handed to the (spied) algorithm
    selector must have "pingpong" in its ``op_conf_name``, and at least one
    such choice must exist.
    """

    def addmm(x, a, b, alpha, beta):
        return torch.addmm(x, a, b, alpha=alpha, beta=beta)

    def select_no_algorithm(*args, **kwargs):
        raise NoValidChoicesError

    x = torch.randn((128, 128)).cuda().half()
    a = torch.randn(128, 128).cuda().half()
    b = torch.randn(128, 128).cuda().half()

    settings = {
        "max_autotune": True,
        # Autotuning stays in-process; the selector is replaced by a spy
        # that raises before any kernel is chosen.
        "autotune_in_subproc": False,
        "max_autotune_gemm_backends": "CUTLASS,ATen",
        "cuda.cutlass_dir": _CUTLASS_DIR,
        "cuda.cutlass_max_profiling_configs": 2,
        "cuda.cutlass_op_allowlist_regex": "pingpong",
        "cuda.cutlass_op_denylist_regex": None,  # no denylist in this test
    }
    with fresh_inductor_cache(), config.patch(settings):
        with mock.patch(
            "torch._inductor.kernel.mm.autotune_select_algorithm",
            wraps=select_no_algorithm,
        ) as sa:
            torch.compile(addmm, dynamic=False)(x, a, b, 1.0, 1.0)
            positional, _keyword = sa.call_args
            op_name, choices, _, __ = positional
            assert op_name == "addmm"

            cutlass_choices = [
                choice for choice in choices if isinstance(choice, CUDATemplateCaller)
            ]
            for choice in cutlass_choices:
                choice_info = choice.info_dict()
                assert (
                    "pingpong" in choice_info["op_conf_name"]
                ), "Only pingpong Kernels should have been allowed"
            cuda_template_count = len(cutlass_choices)
            assert cuda_template_count > 0, "No CUDATemplateCaller choices"
@unittest.skipIf(not SM80OrLater, "need sm_80")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_get_max_alignment(self):
    """
    Check ``get_max_alignment`` on a table of ``FixedLayout`` instances that
    vary in stride order, offset and dtype.
    """
    # Each row: (layout, expected alignment, failure message).
    cases = (
        (
            FixedLayout("cpu", torch.half, size=(1, 2, 4), stride=(0, 4, 1)),
            4,
            "Wrong max alignment. Should have been 4. (simple, contiguous case)",
        ),
        (
            FixedLayout("cpu", torch.half, size=(1, 4, 2), stride=(0, 1, 4)),
            4,
            "Wrong max alignment. Should have been 4. Did not deal with strides correctly",
        ),
        (
            FixedLayout("cpu", torch.half, size=(2, 4, 2), stride=(23, 1, 4)),
            1,
            "Wrong max alignment. Should have been 1. Did not take stride into account correctly",
        ),
        (
            FixedLayout(
                "cpu", torch.half, size=(1, 2, 4), stride=(0, 4, 1), offset=6
            ),
            2,
            "Wrong max alignment. Should have been 2. (due to choice of offset)",
        ),
        (
            FixedLayout(
                "cpu", torch.half, size=(2, 2, 8), stride=(32, 8, 1), offset=24
            ),
            8,
            "Wrong max alignment. Should have been 8.",
        ),
        (
            FixedLayout(
                "cpu", torch.float32, size=(2, 2, 8), stride=(32, 8, 1), offset=24
            ),
            4,
            "Wrong max alignment. Should have been 4 (due to float32 dtype ).",
        ),
    )
    for layout, expected_alignment, failure_message in cases:
        self.assertEqual(
            get_max_alignment(layout), expected_alignment, failure_message
        )
if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Run the suite only when the (HIP-excluding) HAS_CUDA flag and HAS_CPU
    # are both set and is_big_gpu() returns true; otherwise nothing is run.
    if HAS_CUDA and HAS_CPU and is_big_gpu():
        run_tests()