Compare commits


1 commit

SHA1         Message       Date
41852cdbf3   add old way   2025-09-11 16:43:31 -04:00
11 changed files with 73 additions and 106 deletions

View File

@ -1 +1,2 @@
cc99baf14dacc2497d0c5ed84e076ef2c37f6a4d
f510715882304796a96e33028b4f6de1b026c2c7

View File

@ -481,7 +481,7 @@ class TestPoolingNN(NNTestCase):
def test_max_unpool3d_input_check(self):
x = torch.ones(1, 3, 1, 1, 1)
with self.assertRaises(RuntimeError):
with self.assertRaises(AssertionError):
F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])
def test_quantized_max_pool1d_empty_kernel(self):

View File

@ -15,7 +15,7 @@ import torch
from torch import _VF
import torch.jit
import torch.nn.functional as F
from torch.nn.modules.utils import _single, _pair
from torch.nn.modules.utils import _ntuple, _pair, _single
from hypothesis import settings, HealthCheck
from hypothesis import assume, given, note
@ -5311,10 +5311,11 @@ class TestQuantizedConv(TestCase):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
# Padded input size should be at least as big as dilated kernel
kernels = _single(kernels)
strides = _single(strides)
pads = _single(pads)
dilations = _single(dilations)
input_dimension_function = _ntuple(len(input_feature_map_shape))
kernels = input_dimension_function(kernels)
strides = input_dimension_function(strides)
pads = input_dimension_function(pads)
dilations = input_dimension_function(dilations)
for i in range(len(kernels)):
assume(input_feature_map_shape[i] + 2 * pads[i]
>= dilations[i] * (kernels[i] - 1) + 1)
@ -7846,10 +7847,11 @@ class TestQuantizedConv(TestCase):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
# Padded input size should be at least as big as dilated kernel
kernels = _single(kernels)
strides = _single(strides)
pads = _single(pads)
dilations = _single(dilations)
input_dimension_function = _ntuple(len(input_feature_map_shape))
kernels = input_dimension_function(kernels)
strides = input_dimension_function(strides)
pads = input_dimension_function(pads)
dilations = input_dimension_function(dilations)
for i in range(len(kernels)):
assume(input_feature_map_shape[i] + 2 * pads[i]
>= dilations[i] * (kernels[i] - 1) + 1)
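
The `assume` above encodes the validity condition spelled out in the comment: along every spatial dimension the padded input must be at least as large as the dilated kernel, i.e. input + 2*pad >= dilation*(kernel - 1) + 1. A small standalone check of that inequality (plain Python, not the test harness):

def padded_input_covers_dilated_kernel(input_shape, pads, kernels, dilations):
    # True when, for every spatial dim, the padded input is at least as large
    # as the dilated kernel extent dilation * (kernel - 1) + 1.
    return all(
        size + 2 * pad >= dilation * (kernel - 1) + 1
        for size, pad, kernel, dilation in zip(input_shape, pads, kernels, dilations)
    )

print(padded_input_covers_dilated_kernel((5, 5), (1, 1), (3, 3), (2, 2)))  # True: 7 >= 5
print(padded_input_covers_dilated_kernel((3, 3), (0, 0), (3, 3), (2, 2)))  # False: 3 < 5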

View File

@ -8957,9 +8957,9 @@ class TestPad(TestCaseMPS):
# pad dims == input dims
helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
# input.numel() == 0 but output.numel() > 0
helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
helper((0, 3, 3), 1, nn.ConstantPad2d)
# pad dims < input dims - 2
helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)
helper((1, 2, 3, 4, 5), (1, 2, 0, 0), nn.ConstantPad2d)
# 3D Padding
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
@ -8972,7 +8972,7 @@ class TestPad(TestCaseMPS):
# input size < pad size
helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
# check the workaround for the right padding bug in Monterey
helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d)
helper((1, 2, 2, 2, 2), (0, 1, 0, 1, 0, 1), nn.ConstantPad3d)
def test_constant_pad_nd_preserves_memory_format(self):
nchw_tensor = torch.rand((1, 2, 5, 3))
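
The updated padding arguments match the two forms the ConstantPadNd modules document: a single int applied to every side, or a tuple with two entries per padded dimension (4 values for 2d, 6 for 3d). A 6-tuple handed to nn.ConstantPad2d, as in the removed line, presumably no longer slips through once argument lengths are checked. A quick shape check with stock torch:

import torch
import torch.nn as nn

x = torch.zeros(1, 3, 3)                              # (C, H, W) input
print(nn.ConstantPad2d(1, 0.0)(x).shape)              # torch.Size([1, 5, 5]) -- int pads all four sides
print(nn.ConstantPad2d((1, 2, 0, 0), 0.0)(x).shape)   # torch.Size([1, 3, 6]) -- (left, right, top, bottom)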

View File

@ -7480,14 +7480,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
def test_fractional_max_pool2d_invalid_output_ratio(self):
arg_1 = [2, 1]
arg_2 = [0.5, 0.5, 0.6]
arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
arg_3_0_tensor = torch.rand([20, 16, 50, 32], dtype=torch.float32)
arg_3_0 = arg_3_0_tensor.clone()
arg_3 = [arg_3_0,]
with self.assertRaisesRegex(ValueError,
"fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
res = arg_class(*arg_3)
with self.assertRaisesRegex(AssertionError, "Expected an iterable of length 2, but got length 3"):
arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
def test_max_pool1d_invalid_output_size(self):
arg_1 = 3

View File

@ -3,7 +3,6 @@ import functools
import logging
import os
import pathlib
from typing import Any, Optional, Union
from torch._inductor.ir import MultiTemplateBuffer
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
@ -15,7 +14,6 @@ from ..runtime.benchmarking import benchmarker
from ..utils import cache_on_self, IndentedBuffer
from ..virtualized import V
from .common import TensorArg, WorkspaceArg
from .triton import TritonKernel
log = logging.getLogger(__name__)
@ -29,11 +27,11 @@ class MultiKernelState:
V.graph.wrapper_code has a reference to MultiKernelState instance.
"""
def __init__(self) -> None:
self.subkernel_to_kernel_name: dict[tuple[str, ...], str] = {}
self.kernel_defs: IndentedBuffer = IndentedBuffer()
def __init__(self):
self.subkernel_to_kernel_name = {}
self.kernel_defs = IndentedBuffer()
def define_kernel(self, kernels: list[TritonKernel]) -> str:
def define_kernel(self, kernels):
"""
Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}".
This has some minor issue.
@ -51,7 +49,7 @@ class MultiKernelState:
# Prevent circular import
from ..select_algorithm import TritonTemplateKernel
kernel_names: tuple[str] = tuple(k.kernel_name for k in kernels)
kernel_names = tuple(k.kernel_name for k in kernels)
if kernel_names in self.subkernel_to_kernel_name:
return self.subkernel_to_kernel_name[kernel_names]
@ -70,7 +68,6 @@ class MultiKernelState:
kernels[0].output_node, MultiTemplateBuffer
):
for i, kernel in enumerate(kernels):
assert isinstance(kernel, TritonTemplateKernel)
additional_call_args, additional_arg_types = (
kernel.additional_call_args_and_types()
)
@ -127,11 +124,11 @@ class MultiKernel:
Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
"""
def __init__(self, kernels: list[TritonKernel]) -> None:
def __init__(self, kernels):
assert len(kernels) >= 2
self.kernels: list[TritonKernel] = kernels
self.kernel_name: str = V.graph.wrapper_code.multi_kernel_state.define_kernel(
self.kernels = kernels
self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
kernels
)
@ -154,9 +151,9 @@ class MultiKernel:
return [*result.values()]
@staticmethod
def merge_workspaces_inplace(kernels: list[TritonKernel]) -> Optional[list[WorkspaceArg]]:
def merge_workspaces_inplace(kernels):
if len(kernels) < 2:
return None
return
# All kernels must share the same workspace
workspace_args = functools.reduce(
MultiKernel._merge_workspace_args,
@ -166,7 +163,7 @@ class MultiKernel:
kernel.args.workspace_args = workspace_args
return workspace_args
def call_kernel(self, kernel_name: str) -> None:
def call_kernel(self, kernel_name):
"""
Collect the union of arguments from all subkernels as the arguments
for the multi-kernel.
@ -227,7 +224,7 @@ class MultiKernel:
for ws in reversed(self.kernels[0].args.workspace_args):
V.graph.wrapper_code.generate_workspace_deallocation(ws)
def codegen_nan_check(self) -> None:
def codegen_nan_check(self):
wrapper = V.graph.wrapper_code
seen: OrderedSet[str] = OrderedSet()
for k in self.kernels:
@ -243,16 +240,16 @@ class MultiKernel:
wrapper.writeline(line)
@property
def removed_buffers(self) -> OrderedSet[str]:
def removed_buffers(self):
return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels])
@property
def inplaced_to_remove(self) -> OrderedSet[str]:
def inplaced_to_remove(self):
return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels])
@property
@cache_on_self
def inplace_update_buffers(self) -> dict[str, str]:
def inplace_update_buffers(self):
"""
Make sure all kernels have the same inplace update mappings.
"""
@ -260,7 +257,7 @@ class MultiKernel:
assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers
return self.kernels[0].inplace_update_buffers
def warn_mix_layout(self, kernel_name: str) -> None:
def warn_mix_layout(self, kernel_name: str):
pass
@ -269,42 +266,35 @@ class MultiKernelCall:
This class is called at run time to actually run the kernel
"""
def __init__(
self,
multi_kernel_name: str,
kernels: list[TritonKernel],
arg_index: dict[int, list[slice]],
shape_specialize: bool = False,
) -> None:
def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False):
assert len(kernels) >= 2
self._kernels: list[TritonKernel] = kernels
self.multi_kernel_name: str = multi_kernel_name
self._kernels = kernels
self.multi_kernel_name = multi_kernel_name
self.disable_cache: bool = os.environ.get(
self.disable_cache = os.environ.get(
"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE"
) == "1" or is_metric_table_enabled("persistent_red_perf")
self.picked_kernel: Optional[int] = None
self.arg_index: dict[int, list[slice]] = arg_index
self.picked_kernel = None
self.arg_index = arg_index
if config.triton.multi_kernel > 1:
# manually force a subkernel to ease perf testing
picked_by_config: int = config.triton.multi_kernel - 2
picked_by_config = config.triton.multi_kernel - 2
assert picked_by_config < len(self._kernels)
self.picked_kernel = picked_by_config
elif not self.disable_cache:
self.load_cache()
self._recorded: bool = False
self._recorded = False
# This means for each unique shape we will do a separate assessment
# for which kernel is the best. This is particularly useful for matmul
# kernels where the best kernel can vary based on very small differences
# in shape.
self._shape_specialize: bool = shape_specialize
# Maps tuple of inputs shapes -> chosen kernel index
self._shape_cache: dict[tuple[tuple[int, ...], ...], int] = {}
self._shape_specialize = shape_specialize
self._shape_cache = {}
def cache_file_path(self) -> pathlib.Path:
def cache_file_path(self):
key = code_hash(
",".join(
[
@ -316,7 +306,7 @@ class MultiKernelCall:
_, _, path = get_path(key, "picked_kernel")
return pathlib.Path(path)
def load_cache(self) -> None:
def load_cache(self):
assert self.picked_kernel is None
path = self.cache_file_path()
if path.exists():
@ -329,7 +319,7 @@ class MultiKernelCall:
"Load picked kernel %d from cache file %s", self.picked_kernel, path
)
def store_cache(self) -> None:
def store_cache(self):
assert self.picked_kernel is not None
path = self.cache_file_path()
path.parent.mkdir(parents=True, exist_ok=True)
@ -338,7 +328,7 @@ class MultiKernelCall:
log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path)
@property
def kernels(self) -> list[TritonKernel]:
def kernels(self):
"""
Read results from future.
@ -352,7 +342,7 @@ class MultiKernelCall:
return self._kernels
def benchmark_sub_kernels(self, *args: Any, **kwargs: Any) -> list[float]:
def benchmark_sub_kernels(self, *args, **kwargs):
"""
Benchmark all the sub kernels and return the execution time
(in milliseconds) for each of time.
@ -361,8 +351,8 @@ class MultiKernelCall:
be picked.
"""
def wrap_fn(kernel: Any, index: int) -> Any:
def inner() -> Any:
def wrap_fn(kernel, index):
def inner():
filtered_args = self._get_filtered_args(args, index)
args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
return kernel.run(*args_clone, **kwargs_clone)
@ -374,7 +364,7 @@ class MultiKernelCall:
for index, kernel in enumerate(self.kernels)
]
def _get_filtered_args(self, args: tuple[Any, ...], index: int) -> Union[list[Any], tuple[Any, ...]]:
def _get_filtered_args(self, args, index):
"""
We pass in all arguments to all kernels into the MultiKernelCall
so when invoking a particular kernel we need to filter to only the
@ -398,7 +388,7 @@ class MultiKernelCall:
# path for the cache file. Also reading the cache file need do some IO
# which can be slower.
@staticmethod
def record_choice(multi_kernel_name: str, picked_kernel_name: str) -> None:
def record_choice(multi_kernel_name: str, picked_kernel_name: str):
"""
Record the multi-kernel choice for cpp-wrapper after autotuning
@ -424,7 +414,7 @@ class MultiKernelCall:
# there should be no miss
return V.graph.multi_kernel_to_choice[multi_kernel_name]
def run(self, *args: Any, **kwargs: Any) -> None:
def run(self, *args, **kwargs):
if self._shape_specialize:
cache_key = self._get_shape_cache_key(*args, **kwargs)
cached_choice = self._get_cached_shape_choice(cache_key)
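
The _shape_specialize / _shape_cache comments above describe per-shape kernel selection: benchmark every sub-kernel once per unique input-shape signature, then reuse the winner. A minimal sketch of that idea with illustrative names (the real code clones arguments and goes through the inductor benchmarker rather than timing in place):

import time

def pick_fastest(kernels, args, shape_cache):
    # Map a tuple of input shapes to the index of the fastest kernel;
    # on a cache hit, skip benchmarking entirely.
    key = tuple(tuple(a.shape) for a in args if hasattr(a, "shape"))
    if key in shape_cache:
        return shape_cache[key]
    timings = []
    for idx, kernel in enumerate(kernels):
        start = time.perf_counter()
        kernel(*args)
        timings.append((time.perf_counter() - start, idx))
    shape_cache[key] = min(timings)[1]
    return shape_cache[key]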

View File

@ -4972,7 +4972,6 @@ class ChoiceCaller:
# An additional description used to describe the choice (useful for
# knowing what autotuning is choosing)
self.description = description
self.failed: bool = False
def benchmark(self, *args: Any, out: torch.Tensor) -> float:
algo = self.to_callable()
@ -5010,14 +5009,6 @@ class ChoiceCaller:
def autoheuristic_id(self) -> str:
return "unsupported_choice"
def mark_failed(self) -> None:
"""
Mark the choice as failed so that it can be
removed later. Useful for when we decouple
compilation and tuning.
"""
self.failed = True
class TritonTemplateCallerBase(ChoiceCaller):
def get_make_kernel_render(self) -> Any:
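
For context, the deletions above remove the mark-failed half of a pattern whose other half disappears in the next file: choices whose precompilation raised were flagged and then pruned before benchmarking. A toy sketch of the removed flow (illustrative names only):

class ChoiceSketch:
    def __init__(self, name):
        self.name = name
        self.failed = False

    def precompile(self):
        if self.name == "bad":                  # stand-in for a real compile error
            raise RuntimeError("compile failed")

choices = [ChoiceSketch("good"), ChoiceSketch("bad")]
for c in choices:
    try:
        c.precompile()
    except RuntimeError:
        c.failed = True                         # previously: futures[future].mark_failed()

choices = [c for c in choices if not c.failed]  # previously: prune before benchmarking
assert [c.name for c in choices] == ["good"]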

View File

@ -2421,19 +2421,16 @@ class AlgorithmSelectorCache(PersistentCache):
N = input_nodes[-1].get_size()[-1]
append_to_log(mm_file_name, {"invoke": str((M, K, N))})
def create_no_valid_choices() -> NoValidChoicesError:
if len(choices) == 0:
backend_config = (
"max_autotune_gemm_backends"
if name != "convolution"
else "max_autotune_conv_backends"
)
return NoValidChoicesError(
raise NoValidChoicesError(
f"No choices to select, please consider adding ATEN into {backend_config} "
"config (defined in torch/_inductor/config.py) to allow at least one choice. "
)
if len(choices) == 0:
raise create_no_valid_choices()
log.debug("Max autotune selects from %s choices.", str(len(choices)))
if len(choices) == 1:
@ -2490,10 +2487,6 @@ class AlgorithmSelectorCache(PersistentCache):
precompile_fn()
precompile_elapse = time.time() - precompile_start_ts
log.debug("Precompilation elapsed time: %.02fs", precompile_elapse)
# Prune anything that failed to compile
choices = [c for c in choices if not c.failed]
if len(choices) == 0:
raise create_no_valid_choices()
candidates = self.prescreen_choices(
choices, name, inputs_key, self.prescreening_cache
@ -2830,7 +2823,6 @@ class AlgorithmSelectorCache(PersistentCache):
futures[future],
exc_info=e,
)
futures[future].mark_failed()
else:
log.exception( # noqa: G202
"Exception %s for benchmark choice %s",
@ -2838,7 +2830,6 @@ class AlgorithmSelectorCache(PersistentCache):
futures[future],
exc_info=e,
)
futures[future].mark_failed()
else:
counters["inductor"]["select_algorithm_num_precompiles"] += 1
log.info(
@ -3032,13 +3023,8 @@ class AlgorithmSelectorCache(PersistentCache):
# only benchmark triton kernel in sub process for now.
# ATen/Extern kernel are still benchmarked in the current process.
extern = []
triton = []
for c in choices:
if isinstance(c, TritonTemplateCaller):
triton.append(c)
else:
extern.append(c)
extern = [c for c in choices if isinstance(c, ExternKernelCaller)]
triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
timings = cls.benchmark_in_current_process(
extern, input_nodes, layout, input_gen_fns, hint_override=hint_override
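
The restored comprehensions and the removed loop do not partition identically: the comprehensions send everything that is not an ExternKernelCaller to the triton list, while the loop sent only TritonTemplateCaller instances there and everything else to extern. A toy illustration with hypothetical stand-in classes:

class ExternCallerToy: ...
class TritonTemplateCallerToy: ...
class OtherCallerToy: ...

choices = [ExternCallerToy(), TritonTemplateCallerToy(), OtherCallerToy()]

# Restored partition: anything non-extern is treated as a Triton choice.
extern = [c for c in choices if isinstance(c, ExternCallerToy)]
triton = [c for c in choices if not isinstance(c, ExternCallerToy)]
assert any(isinstance(c, OtherCallerToy) for c in triton)

# Removed partition: only template callers went to triton, the rest to extern.
extern2, triton2 = [], []
for c in choices:
    (triton2 if isinstance(c, TritonTemplateCallerToy) else extern2).append(c)
assert any(isinstance(c, OtherCallerToy) for c in extern2)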

View File

@ -357,10 +357,6 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
GemmConfig(32, 64, 64, 6, 2),
GemmConfig(32, 128, 64, 6, 4),
GemmConfig(32, 256, 64, 6, 4),
GemmConfig(64, 16, 256, 5, 4),
GemmConfig(64, 32, 256, 5, 4),
GemmConfig(64, 128, 128, 3, 4),
GemmConfig(128, 256, 128, 4, 8),
]
self.scaled_persistent_mm_configs: list[BaseConfig] = [
@ -373,10 +369,6 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
GemmConfig(128, 128, 128, 5, 8),
GemmConfig(128, 128, 128, 6, 8),
GemmConfig(128, 128, 64, 4, 8),
GemmConfig(64, 32, 256, 5, 4),
GemmConfig(128, 256, 128, 3, 8),
GemmConfig(64, 128, 256, 4, 4),
GemmConfig(64, 256, 128, 4, 4),
]
# TODO: Unify with other gemm patterns, mm_plus_mm currently follows

View File

@ -768,7 +768,7 @@ class _ConvTransposeNd(_ConvNd):
dilation: Optional[list[int]] = None,
) -> list[int]:
if output_size is None:
ret = _single(self.output_padding) # converting to list if was not already
ret = list(self.output_padding) # converting to list if was not already
else:
has_batch_dim = input.dim() == num_spatial_dims + 2
num_non_spatial_dims = 2 if has_batch_dim else 1
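
A hedged reading of the swap above, consistent with the length-checked _ntuple shown at the end of this compare: output_padding is already a per-dimension tuple, so a strict length-1 parse would reject it, whereas list(...) only changes the container.

output_padding = (0, 0)        # already one value per spatial dim (2d case)
ret = list(output_padding)     # [0, 0] -- container change only
# A length-checked _single(output_padding) would assert here, since the
# iterable has length 2 rather than 1 (see the _ntuple hunk below).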

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
import collections
import collections.abc
from itertools import repeat
from typing import Any
@ -10,7 +10,18 @@ __all__ = ["consume_prefix_in_state_dict_if_present"]
def _ntuple(n, name="parse"):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return tuple(x)
ret = tuple(x)
# If the iterable is length 1, automatically expand to fill. This
# matches the behavior of expand_param_if_needed.
if len(ret) == 1:
return tuple(repeat(ret[0], n))
# Otherwise assert the correct length.
assert len(ret) == n, (
f"Expected an iterable of length {n}, but got length {len(ret)}"
)
return ret
return tuple(repeat(x, n))
parse.__name__ = name
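
A usage sketch of the length-checked parse above. This is a standalone copy of the logic shown in this hunk (plus the trailing return parse, which the excerpt cuts off), so it runs without the patched checkout:

import collections.abc
from itertools import repeat

def _ntuple(n, name="parse"):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            ret = tuple(x)
            if len(ret) == 1:               # length-1 iterables expand to fill
                return tuple(repeat(ret[0], n))
            assert len(ret) == n, (
                f"Expected an iterable of length {n}, but got length {len(ret)}"
            )
            return ret
        return tuple(repeat(x, n))          # scalars are repeated n times
    parse.__name__ = name
    return parse

_pair = _ntuple(2)
print(_pair(3))        # (3, 3)  scalar repeated
print(_pair([3]))      # (3, 3)  length-1 iterable expanded
print(_pair((1, 2)))   # (1, 2)  correct length passes through
try:
    _pair((1, 2, 3))
except AssertionError as e:
    print(e)           # Expected an iterable of length 2, but got length 3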