Compare commits

..

8 Commits

Author SHA1 Message Date
e3d00beddd Fix triu_/tril_ overlap handling 2025-10-21 07:54:24 -07:00
21131a2444 Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit ffa90d46e61650834d5f926008f48f50c6a7e87a.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/jeffdaily due to timeouts after merge ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3426898171))
2025-10-21 14:15:55 +00:00
1009790ad8 [pytree][dynamo] trace on native optree functions for community pytree support (#165860)
Resolves #164972

- #164972

All `torch.utils._cxx_pytree` functions are built on top of `optree` functions with `none_is_leaf=True` and `namespace="torch"` hardcoded. This PR rewrites the polyfills against the generic `optree` functions with those arguments left as parameters, so `torch.utils._cxx_pytree` functions remain traceable while community code that calls `optree` directly additionally gains Dynamo support.
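As a rough illustration (a minimal sketch assuming `optree` is installed; the example tree and function names are mine, not taken from the PR), the relationship between the two APIs and the newly traceable native-optree path looks like this:

```python
import optree
import torch
import torch.utils._cxx_pytree as cxx_pytree

tree = {"a": torch.zeros(2), "b": [torch.ones(3), None]}

# torch.utils._cxx_pytree pins none_is_leaf=True and namespace="torch" ...
torch_leaves, _ = cxx_pytree.tree_flatten(tree)
# ... which corresponds to calling the generic optree function with those
# arguments passed explicitly instead of hardcoded.
optree_leaves, _ = optree.tree_flatten(tree, none_is_leaf=True, namespace="torch")
assert len(torch_leaves) == len(optree_leaves) == 3  # None counts as a leaf here

@torch.compile(backend="eager")
def fn(t):
    # After this PR, calls into native optree like this one are also traceable.
    return optree.tree_map(lambda x: x + 1, t)

fn({"a": torch.zeros(2), "b": torch.ones(3)})
```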

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
2025-10-21 14:13:08 +00:00
410e6a4321 Better error handling in torch/csrc/jit/frontend/* (#165213)
Refactor error handling to use TORCH_CHECK for clearer messages in the constants and scope-management code in several files under torch/csrc/jit/frontend/*.

Fixes parts of issue https://github.com/pytorch/pytorch/issues/148114.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165213
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 13:54:59 +00:00
23c55c5b66 [Code Clean] Replace assert statements with explicit if/raise patterns (#165735)
Fix part of #164878

Replace 75 assert statements with explicit if/raise patterns in `torch/ao/ns`, covering the files below (a minimal before/after sketch of the pattern follows the list):

- `torch/ao/ns/_numeric_suite_fx.py`  - 5 asserts

- `torch/ao/ns/fx/graph_matcher.py` - 6 asserts

- `torch/ao/ns/fx/graph_passes.py` - 12 asserts

- `torch/ao/ns/fx/n_shadows_utils.py` - 20 asserts

- `torch/ao/ns/fx/pattern_utils.py` - 2 asserts

- `torch/ao/ns/fx/utils.py` - 21 asserts

- `torch/ao/ns/fx/weight_utils.py` - 19 asserts
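
A minimal sketch of the before/after pattern (the helper name and error message are hypothetical, not taken from the PR):

```python
import torch

def _check_is_fx_node(node: object) -> None:
    # Before: assert isinstance(node, torch.fx.Node), "expected an fx.Node"
    # After: an explicit check that still runs under `python -O` and raises a
    # typed, descriptive error instead of a bare AssertionError.
    if not isinstance(node, torch.fx.Node):
        raise TypeError(f"expected torch.fx.Node, got {type(node).__name__}")
```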

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165735
Approved by: https://github.com/albanD
2025-10-21 11:21:57 +00:00
1290b077f2 [dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)
Audit: replace direct `UserFunctionVariable` construction with `VariableTracker.build` to prevent future issues with functools.partial and other callable objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707
Approved by: https://github.com/Lucaskabela
2025-10-21 09:27:41 +00:00
9f9ab881b2 [ROCm][inductor] heuristic improvements for reduction kernels (#161280)
Improvements to reduction kernel heuristics for MI350.

Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161280
Approved by: https://github.com/jansel, https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jeffdaily
2025-10-21 07:48:54 +00:00
f2bb22ff84 [Inductor-FX] Support Tensor.item (#165599)
# Feature
This PR supports compiling `Tensor.item` with Inductor's FX backend. This maps to a custom WrapperCodeGen method called `codegen_dynamic_scalar`.

# Implementation
The implementation is fairly mechanical, following the usual flow for these types of PRs.
1. Introduce a new Wrapper IR line for this, called `DynamicScalarLine`.
2. Split `PythonWrapperCodegen.codegen_dynamic_scalar` into 2 parts: a public method which generates the Wrapper IR line, and a private one generating Python from Wrapper IR.
3. Implement an FX codegen method for the wrapper IR line. This one calls `aten.where.Scalar` to handle code like `1 if x.item() else 0`, which is a bit tricky. It also calls `aten.item.default` to convert tensors to scalars.

# Test plan
Added CI tests mirroring the AOTI ones. They test float, int and bool types, the latter taking a distinct codegen path.
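
As a usage-level sketch (my own example, not the PR's test; selecting the FX wrapper backend is handled by Inductor configuration that is not shown here), the kind of program this enables looks like:

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True  # allow .item() to be captured

@torch.compile
def scale_by_first(x: torch.Tensor) -> torch.Tensor:
    s = x[0].item()  # Tensor.item becomes a dynamic scalar in the wrapper codegen
    return x * s

print(scale_by_first(torch.arange(4.0) + 1.0))
```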

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165599
Approved by: https://github.com/angelayi, https://github.com/jansel
2025-10-21 07:09:56 +00:00
35 changed files with 646 additions and 1406 deletions

View File

@ -54,17 +54,12 @@ self-hosted-runner:
- windows-11-arm64
- windows-11-arm64-preview
# Organization-wide AMD-hosted runners
# MI2xx non-ARC runners
# MI2xx runners
- linux.rocm.gpu
- linux.rocm.gpu.mi250
- linux.rocm.gpu.2
- linux.rocm.gpu.4
- linux.rocm.gpu.mi250
- linux.rocm.gpu.gfx1100
# MI2xx ARC runners
- linux.rocm.gpu.mi250.1
- linux.rocm.gpu.mi250.2
- linux.rocm.gpu.mi250.4
# gfx942 ARC runners
# gfx942 runners
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4

View File

@ -36,12 +36,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
secrets: inherit

View File

@ -141,6 +141,8 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}
checkTrilTriuMemoryOverlap(result, self);
bool inplace_op = self.is_same(result);
bool inplace_update = false;

View File

@ -1,3 +1,4 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>
@ -54,4 +55,13 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}
static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}
} // namespace at::native

View File

@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -110,6 +111,8 @@ __global__ void triu_tril_kernel(
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,

View File

@ -424,7 +424,7 @@ from user code:
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten(d)
optree.tree_flatten_with_path(d)
return torch.sin(x)
fn(torch.randn(4))
@ -434,10 +434,10 @@ from user code:
first_graph_break,
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)

View File

@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import torch.utils._cxx_pytree as cxx_pytree
pytree_modules["cxx"] = cxx_pytree
pytree_modules["native_optree"] = cxx_pytree.optree
else:
cxx_pytree = None
@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
torch.ones(3, 2),
1,
]
new_tree = pytree.tree_unflatten(new_leaves, treespec)
if pytree.__name__ == "optree":
# `None` is a internal node rather than leaf in default OpTree / JAX PyTree
new_leaves.pop()
# The treespec argument comes first in OpTree / JAX PyTree
new_tree = pytree.tree_unflatten(treespec, new_leaves)
else:
new_tree = pytree.tree_unflatten(new_leaves, treespec)
return leaves, new_tree
x = torch.randn(3, 2)
@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):
# OpTree and JAX PyTree do not have `tree_map_only`
return
def fn(xs):
def mapper(x):
return x.clone()

View File

@ -74,7 +74,14 @@ class TestAsyncCompile(TestCase):
return (a @ b).to(torch.float32).sum(dim=1)
# Fake name to make sure the lookup table is name agnostic
func_def = """
# When codegen/triton.py is changed, func_def must be updated
loop_header = (
"for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"
if torch.version.hip
else "for r0_offset in tl.range(0, r0_numel, R0_BLOCK):"
)
func_def = f"""
def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1024
r0_numel = 11776
@ -87,7 +94,7 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons
rbase = r0_base
x0 = xindex
_tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
{loop_header}
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset

View File

@ -1,572 +0,0 @@
# Owner(s): ["module: inductor"]
"""
Tests for custom operation autotuning with PyTorch Inductor.
Users can register custom ops with multiple decomposition implementations and let
Inductor automatically select the best performing variant. Key features tested:
- Name-based input generators (use argument names instead of indices)
- Dynamic shape handling across multiple compilations
- Parametric tuning with tuning_knob for combinatorial parameter exploration
- Numerical correctness and performance validation
"""
import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import (
CustomOpConfig,
register_custom_op_autotuning,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU
torch.set_float32_matmul_precision("high")
class TestCustomOpAutoTune(TestCase):
"""Test custom operation autotuning functionality."""
def setUp(self) -> None:
"""Set up test environment with appropriate device and dtype."""
super().setUp()
self.device = "cuda" if HAS_GPU else "cpu"
self.dtype = torch.float16 if self.device == "cuda" else torch.float32
def _create_test_configs(self):
"""Create common test configurations for different sizes."""
return [
{"batch_size": 1, "seq_len": 32, "hidden_dim": 128},
{"batch_size": 2, "seq_len": 64, "hidden_dim": 256},
]
def _run_autotune_test(self, op_object, inputs, expected, test_name):
"""Shared test infrastructure for autotuning tests."""
@torch.compile
def test_model(*args):
return op_object(*args)
torch._dynamo.reset()
autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"
with config.patch(
max_autotune=True,
max_autotune_gemm_backends=autotune_backends,
fx_graph_cache=False,
benchmark_kernel=True,
):
compiled_result = test_model(*inputs)
self.assertEqual(
compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
)
torch.testing.assert_close(
compiled_result,
expected,
rtol=2e-1,
atol=5e-1,
msg=f"{test_name} numerical mismatch",
)
def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
"""Utility to assert that all implementations produce equivalent results."""
implementations = [(func.__name__, func) for func in decompositions]
results = {}
for name, impl in implementations:
result = impl(*inputs)
results[name] = result
# Basic sanity checks
self.assertTrue(
torch.isfinite(result).all(),
f"{op_name} {name} produced non-finite values",
)
# Verify numerical equivalence
reference_name, reference_result = next(iter(results.items()))
for name, result in results.items():
if name != reference_name:
rtol = 1e-1 if "Approximated" in name else 1e-2
atol = 1e-1 if "Approximated" in name else 1e-2
torch.testing.assert_close(
result,
reference_result,
rtol=rtol,
atol=atol,
msg=f"{op_name} {name} differs from {reference_name}",
)
def _create_rmsnorm_inputs(self, batch_size=8, seq_len=1024, hidden_dim=512):
"""Create test inputs for RMSNorm operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
return input_tensor, weight
def _create_mlp_inputs(
self,
batch_size=2,
seq_len=32,
hidden_dim=512,
intermediate_dim=1024,
output_dim=256,
):
"""Create test inputs for MLP operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
gate_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
up_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
down_weight = torch.randn(
intermediate_dim,
output_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
return input_tensor, gate_weight, up_weight, down_weight
@skipIfXpu
def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
"""Test RMSNorm autotuning decomposition variants compared to fallback default with dynamic shapes."""
test_op_name = f"test_lib::rmsnorm_{id(self)}"
def rmsnorm_decomposition1(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""Variance-based approach: compute variance then rsqrt."""
variance = x.pow(2).mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(variance + eps)
return x * rstd * weight
def rmsnorm_decomposition2(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""vLLM-style RMSNorm implementation - variance computation first approach."""
x_var = x # In vLLM, this could be sliced for variance_size_override
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
if weight is not None:
x = x * weight
return x
def rmsnorm_decomposition3(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""vLLM-style RMSNorm with extended variance computation pattern."""
x_squared = x.pow(2)
variance = x_squared.mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(variance + eps)
normalized = x * rstd
# Apply weight scaling
if weight is not None:
normalized = normalized * weight
return normalized
@torch.library.custom_op(test_op_name, mutates_args=())
def test_rmsnorm_op(
input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
return torch.nn.functional.rms_norm(
input_tensor, input_tensor.shape[-1:], weight, eps=eps
)
@test_rmsnorm_op.register_fake
def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
return torch.empty_like(input_tensor)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
decompositions = [
rmsnorm_decomposition1,
rmsnorm_decomposition2,
rmsnorm_decomposition3,
]
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(rmsnorm_decomposition1) for decomp in decompositions
],
name="test_rmsnorm_autotuned",
input_gen_fns={
"x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
"weight": lambda weight: torch.ones_like(weight, device=self.device),
},
)
# Test multiple shapes to verify dynamic shape handling
test_shapes = [(2, 16, 128), (8, 32, 256)]
for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
# Test numerical equivalence for all decompositions
self._assert_implementations_equivalent(
decompositions, (input_tensor, weight), f"RMSNorm_{i}"
)
# Test autotuning
expected = rmsnorm_decomposition1(input_tensor, weight)
self._run_autotune_test(
op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
)
@skipIfXpu
def test_mlp_custom_op_autotune(self):
"""Test MLP autotuning with method parameter controlling different decomposition variants"""
test_op_name = f"test_lib::mlp_{id(self)}"
def mlp_variants(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
"""MLP implementation with different computational approaches controlled by method parameter."""
if method == 0:
# Separate matmuls: standard implementation with torch.matmul
gate_proj = torch.matmul(input_tensor, gate_weight)
up_proj = torch.matmul(input_tensor, up_weight)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
elif method == 1:
# Batched approach: uses torch.mm with reshaped tensors
batch_shape = input_tensor.shape[:-1]
hidden_dim = input_tensor.shape[-1]
output_dim = down_weight.shape[-1]
input_2d = input_tensor.view(-1, hidden_dim)
gate_proj = torch.mm(input_2d, gate_weight)
up_proj = torch.mm(input_2d, up_weight)
gated = torch.relu(gate_proj) * up_proj
output_2d = torch.mm(gated, down_weight)
return output_2d.view(*batch_shape, output_dim)
elif method == 2:
# Fused weights approach: concatenate then split weights
# Concatenate gate and up weights for one matrix multiply
fused_weight = torch.cat([gate_weight, up_weight], dim=1)
fused_proj = torch.matmul(input_tensor, fused_weight)
intermediate_dim = gate_weight.shape[1]
gate_proj, up_proj = fused_proj.split(
[intermediate_dim, intermediate_dim], dim=-1
)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
@torch.library.custom_op(test_op_name, mutates_args=())
def test_mlp_op(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
) -> torch.Tensor:
return mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)
@test_mlp_op.register_fake
def _(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
):
return torch.empty(
input_tensor.shape[:-1] + (down_weight.shape[-1],),
device=input_tensor.device,
dtype=input_tensor.dtype,
)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
# Use explicit configs with method parameter as tuning knob
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(mlp_variants, method=1), # Batched approach
CustomOpConfig(mlp_variants, method=2), # Fused weights
],
name="test_mlp_autotuned",
input_gen_fns={
"input_tensor": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"gate_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"up_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"down_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
},
)
# Create test inputs using the original helper method
input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()
# Test that all method variants produce numerically equivalent results
expected = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)
for method in [1, 2]:
result = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=method
)
torch.testing.assert_close(
result,
expected,
rtol=1e-5,
atol=1e-5,
msg=f"Method {method} not equivalent to method 0",
)
# Test autotuning - all should be mathematically equivalent
self._run_autotune_test(
op_object,
(input_tensor, gate_weight, up_weight, down_weight),
expected,
"MLP",
)
def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
"""Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
# Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256
a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
return a, b
@skipIfXpu
def test_decompose_k_custom_op_autotune(self):
"""Test decompose_k autotuning with parameter tuning for k_splits values."""
test_op_name = f"test_lib::decompose_k_{id(self)}"
def decompose_k_implementation(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - parameter-tuned implementation."""
m = a.shape[0]
n = b.shape[1]
k = a.shape[1]
k_parts = k // k_splits
B = k_splits
a_reshaped = torch.permute(
a.reshape(m, B, k_parts), (1, 0, 2)
) # [B, m, k_parts]
b_reshaped = b.reshape(B, k_parts, n) # [B, k_parts, n]
result = torch.bmm(a_reshaped, b_reshaped) # [B, m, n]
return torch.sum(result, dim=0) # [m, n]
@torch.library.custom_op(test_op_name, mutates_args=())
def test_decompose_k_op(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
return decompose_k_implementation(a, b, k_splits)
@test_decompose_k_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
# Use parameter tuning to test different k_splits values
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(decompose_k_implementation, k_splits=2),
CustomOpConfig(decompose_k_implementation, k_splits=32),
CustomOpConfig(decompose_k_implementation, k_splits=64),
CustomOpConfig(decompose_k_implementation, k_splits=128),
CustomOpConfig(decompose_k_implementation, k_splits=256),
],
name="test_decompose_k_autotuned",
input_gen_fns={
"a": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1, # Matrix A
"b": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1, # Matrix B
},
)
a, b = self._create_decompose_k_inputs()
expected = a @ b
self._run_autotune_test(op_object, (a, b), expected, "DecomposeK")
@skipIfXpu
def test_multi_parameter_tuning(self):
"""Test autotuning with multiple parameters using scale_mode and chunk_size."""
op_name = f"test_lib::multi_param_{id(self)}"
def multi_param_scaling(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
"""Different scaling approaches controlled by scale_mode parameter."""
if scale_mode == 1:
# Simple broadcasting
return x * factor
elif scale_mode == 2:
# Process in chunks
batch_size, seq_len = x.shape[:2]
chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
chunk = x[:, start:end]
chunks.append(chunk * factor)
return torch.cat(chunks, dim=1)
elif scale_mode == 3:
# Using einsum for scaling
return torch.einsum("...i,i->...i", x, factor)
@torch.library.custom_op(op_name, mutates_args=())
def multi_param_op(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
return multi_param_scaling(x, factor, scale_mode, chunk_size)
@multi_param_op.register_fake
def _(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
):
return torch.empty_like(x)
lib_name, op_suffix = op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_suffix)
# Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(multi_param_scaling, scale_mode=1), # Broadcast
CustomOpConfig(
multi_param_scaling, scale_mode=2, chunk_size=16
), # Chunked 16
CustomOpConfig(
multi_param_scaling, scale_mode=2, chunk_size=32
), # Chunked 32
CustomOpConfig(multi_param_scaling, scale_mode=3), # Einsum
],
name="multi_param_autotuned",
input_gen_fns={
"x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
"factor": lambda t: torch.ones(
t.shape[-1], device=self.device, dtype=t.dtype
),
},
)
# Create test inputs
test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0
# Verify numerical equivalence across all approaches
expected_result = test_x * test_factor
# Test each scale_mode variant
configs = [
(1, 16), # broadcast, chunk_size ignored
(2, 16), # chunked with size 16
(2, 32), # chunked with size 32
(3, 16), # einsum, chunk_size ignored
]
for i, (scale_mode, chunk_size) in enumerate(configs):
result = multi_param_scaling(
test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
)
torch.testing.assert_close(
result,
expected_result,
rtol=1e-5,
atol=1e-5,
msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
)
# Test autotuning
self._run_autotune_test(
op_object, (test_x, test_factor), expected_result, "MultiParam"
)
if __name__ == "__main__":
run_tests()

View File

@ -1034,6 +1034,22 @@ def forward(self, arg0_1, arg1_1, arg2_1):
x = torch.randn(7, device=self.device)
self.check(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},))
@parametrize("dynamic", (False, True))
@parametrize("input_", (1.5, 2, False))
def test_item(self, input_, dynamic: bool):
"""
Test calling Tensor.item.
"""
class M(torch.nn.Module):
def forward(self, x):
return x[1].item()
x = torch.tensor((input_,) * 10)
d = Dim("s0", min=1)
dynamic_shapes = ({0: 2 * d},) if dynamic else None
self.check(M(), (x,), dynamic_shapes=dynamic_shapes)
@parametrize("pred", (False, True))
def test_mismatched_branch_dynamic(self, pred: bool):
"""

View File

@ -14295,8 +14295,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertTrue(torch.all(result < 2560).item())
code_str = "\n".join(code)
if torch.version.hip:
triton_str = "tl.minimum"
else:
triton_str = "triton_helpers.minimum"
self.assertIn(
"triton_helpers.minimum",
triton_str,
code_str,
"Generated Triton code should use triton_helpers.minimum for clamping",
)

View File

@ -9986,6 +9986,20 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)
@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C
import torch.utils._cxx_pytree as cxx_pytree
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
del __func
del __name
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> bool:
if tree is None or (is_leaf is not None and is_leaf(tree)):
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
return True
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
return True
return False
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
def tree_iter(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> Iterable[Any]:
stack = [tree]
while stack:
node = stack.pop()
if tree_is_leaf(node, is_leaf=is_leaf):
if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
yield node
continue
children, *_ = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
none_is_leaf=none_is_leaf,
namespace=namespace,
)
stack.extend(reversed(children))
__all__ += ["tree_iter"]
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
def tree_leaves(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> list[Any]:
return list(tree_iter(tree, is_leaf=is_leaf))
return list(
tree_iter(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
)
__all__ += ["tree_leaves"]
@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
_metadata: Any
_entries: tuple[Any, ...]
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
none_is_leaf: bool
namespace: str
num_nodes: int = field(init=False)
num_leaves: int = field(init=False)
num_children: int = field(init=False)
none_is_leaf: Literal[True] = field(init=False)
namespace: Literal["torch"] = field(init=False)
def __post_init__(self) -> None:
if self._type is None:
@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
object.__setattr__(self, "none_is_leaf", True)
object.__setattr__(self, "namespace", "torch")
def __repr__(self) -> str:
def helper(treespec: PyTreeSpec) -> str:
@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
]
if (
treespec.type in BUILTIN_TYPES
or (treespec.type is type(None) and not self.none_is_leaf)
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
f"[{', '.join(children_representations)}])"
)
return (
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
)
inner = [
str(helper(self)),
*(["NoneIsLeaf"] if self.none_is_leaf else []),
f"namespace={self.namespace!r}",
]
return f"PyTreeSpec({', '.join(inner)})"
def __len__(self) -> int:
return self.num_leaves
@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if len(children) != treespec.num_children:
raise ValueError(
@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
# node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if (
node_type
@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten,
optree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
if tree_is_leaf(node, is_leaf=is_leaf):
if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
leaves.append(node)
return _LEAF_SPEC
return PyTreeSpec(
(),
None,
None,
(),
None,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
(
children,
@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
) = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
none_is_leaf=none_is_leaf,
namespace=namespace,
)
# Recursively flatten the children
subspecs = tuple(helper(child, leaves) for child in children)
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
return PyTreeSpec(
subspecs,
type(node),
metadata,
entries,
unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
) # type: ignore[arg-type]
leaves: list[Any] = []
treespec = helper(tree, leaves)
@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_flatten"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_structure,
optree.tree_structure,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_structure(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTreeSpec:
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
return tree_flatten( # type: ignore[return-value]
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)[1]
__all__ += ["tree_structure"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_unflatten,
optree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_unflatten"]
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
def tree_map(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"]
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree
__all__ += ["tree_map_"]
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr]
@substitute_in_graph( # type: ignore[arg-type]
_none_unflatten,
can_constant_fold_through=True,
skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
if len(list(children)) != 0:
raise ValueError("Expected no children.")
return None

View File

@ -200,9 +200,10 @@ class SuperVariable(VariableTracker):
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
fn_vt = VariableTracker.build(
tx, unpatched_nn_module_init, source=source
)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented_v2(
gb_type="Unsupported super().__init__() call",
@ -230,9 +231,8 @@ class SuperVariable(VariableTracker):
elif isinstance(inner_fn, staticmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
return variables.UserFunctionVariable(
inner_fn.__func__, source=source
).call_function(tx, args, kwargs)
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
return fn_vt.call_function(tx, args, kwargs)
elif isinstance(inner_fn, classmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
tx, self.objvar.value_type, cls_source
)
return variables.UserFunctionVariable(
inner_fn.__func__, source=AttrSource(source, "__func__")
).call_function(tx, [cls_variable, *args], kwargs)
fn_vt = VariableTracker.build(
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
)
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
elif isinstance(inner_fn, types.FunctionType):
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, [self.objvar] + args, kwargs)
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
elif isinstance(inner_fn, types.MethodType):
return variables.UserMethodVariable(
inner_fn.__func__, self.objvar, source=source
@ -574,10 +574,8 @@ class ComptimeVariable(VariableTracker):
from ..comptime import comptime
# To support the comptime.print_graph convenience accessors
from .functions import UserFunctionVariable
return UserFunctionVariable(
getattr(comptime, name), source=AttrSource(self.source, name)
return VariableTracker.build(
tx, getattr(comptime, name), source=AttrSource(self.source, name)
)
def call_function(
@ -771,9 +769,8 @@ class AutogradFunctionVariable(VariableTracker):
sig = inspect.signature(fn)
if len(args) - 1 == len(sig._parameters):
args = args[1:] # Don't use context
return variables.UserFunctionVariable(fn, source=source).call_function(
tx, args, kwargs
)
fn_vt = VariableTracker.build(tx, fn, source=source)
return fn_vt.call_function(tx, args, kwargs)
elif isinstance(fn, types.MethodType):
return variables.UserMethodVariable(
fn.__func__,
@ -799,9 +796,8 @@ class AutogradFunctionVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)
fn_source = AttrSource(self.source, "backward")
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx, args, kwargs
)
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
return fn_vt.call_function(tx, args, kwargs)
def call_function(self, tx: "InstructionTranslator", args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
@ -1026,10 +1022,12 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
assert tx.one_graph or tx.error_on_graph_break, (
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
)
return variables.UserFunctionVariable(
fn_vt = VariableTracker.build(
tx,
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
source=self.source,
).call_function(
)
return fn_vt.call_function(
tx,
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
kwargs,

View File

@ -1485,7 +1485,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
else:
self.writeline(f"{arg.inner} = {cexpr(arg.inner_expr)};")
def codegen_dynamic_scalar(self, node):
def _codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw")

View File

@ -1,6 +1,6 @@
import itertools
import logging
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union
import torch
import torch._inductor.config as config
@ -8,7 +8,6 @@ from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
Buffer,
FixedLayout,
get_free_symbols,
get_symbolic_inputs,
gm_original_output_strides,
@ -111,12 +110,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
bm_func([*sym_inputs, *args])
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
# Use appropriate benchmarker based on layout device type
if self.layout.device.type == "cpu":
return benchmarker.benchmark_cpu(lambda: bm_func([*sym_inputs, *args]))
else:
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
def hash_key(self) -> str:
return "-".join(
@ -203,152 +197,3 @@ class SubgraphTemplate(KernelTemplate):
description=description,
make_fx_graph=make_fx_graph,
)
def generate_custom_op_choices(
self,
name: str,
decompositions: list[Callable[..., Any]],
input_nodes: list[Buffer],
kwargs: Optional[dict[str, Any]] = None,
default_impl: Optional[Callable[..., Any]] = None,
) -> list[SubgraphChoiceCaller]:
"""
Generate multiple SubgraphChoiceCaller instances for custom op autotuning.
This method extends SubgraphTemplate to support custom op decompositions,
allowing multiple implementations to compete in autotuning.
Args:
name: Base name for the choices
decompositions: List of decomposition functions to compare
input_nodes: Input nodes for the operation
kwargs: Additional arguments for decomposition functions
default_impl: Default implementation for layout inference
Returns:
List of SubgraphChoiceCaller instances for autotuning
"""
if not decompositions:
return []
kwargs = kwargs or {}
# Infer layouts and ensure stride consistency for fair autotuning comparison
layouts = [
self._infer_custom_op_layout(input_nodes, [decomp], kwargs, default_impl)
for decomp in decompositions
]
self._validate_stride_consistency(name, decompositions, layouts)
# Assert single output layout - assumes custom ops have one output tensor
assert len(layouts) > 0, f"No layouts inferred for custom op '{name}'"
assert all(
layout.device == layouts[0].device
and layout.dtype == layouts[0].dtype
and layout.size == layouts[0].size
for layout in layouts
), f"All decompositions for '{name}' must produce equivalent output layouts"
layout = layouts[0] # All layouts have equivalent stride/shape/dtype now
choices = []
for decomp in decompositions:
# Create make_fx_graph function for this decomposition
def make_fx_graph(*args: Any, decomp: Callable[..., Any] = decomp) -> Any:
import functools
from torch.fx.experimental.proxy_tensor import make_fx
# Ensure kwargs is not None for unpacking
decomp_kwargs = kwargs if kwargs is not None else {}
return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)
choice = self.generate(
name=f"{name}_{decomp.__name__}",
input_nodes=input_nodes,
layout=layout,
make_fx_graph=make_fx_graph,
description=f"CustomOp {decomp.__name__}",
)
choices.append(choice)
return choices
def _validate_stride_consistency(
self,
op_name: str,
decompositions: list[Callable[..., Any]],
layouts: list[Layout],
) -> None:
"""Ensure all decompositions produce compatible strides for fair autotuning."""
if not layouts:
return
strides = [layout.stride for layout in layouts]
reference = strides[0]
for i, stride in enumerate(strides[1:]):
if stride != reference:
raise AssertionError(
f"Stride mismatch in custom op '{op_name}' autotuning: "
f"'{decompositions[i].__name__}' produces stride {stride}, "
f"but '{decompositions[0].__name__}' produces {reference}. "
f"All decompositions must have identical output strides."
)
def _infer_custom_op_layout(
self,
input_nodes: list[Buffer],
decompositions: list[Callable[..., Any]],
kwargs: dict[str, Any],
default_impl: Optional[Callable[..., Any]] = None,
) -> Layout:
"""Infer output layout for custom ops using the default implementation when available.
Note that the Subgraph assumes custom ops return exactly one tensor so far.
TODO: Add support for multiple output custom ops.
"""
import functools
from torch._inductor.virtualized import V
# Assert kwargs contain only non-tensor arguments for functools.partial
for key, value in kwargs.items():
assert not isinstance(value, (torch.Tensor, Buffer)), (
f"kwargs['{key}'] contains tensor {type(value)}. "
f"Tensor arguments should be in input_nodes, not kwargs. "
f"Only scalar/non-tensor parameters should be in kwargs."
)
# Use default_impl if available, otherwise use first decomposition
impl = default_impl if default_impl is not None else decompositions[0]
with V.fake_mode:
example_inputs = []
for inp in input_nodes:
raw_shape = inp.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
fake_tensor = torch.empty(
concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
)
example_inputs.append(fake_tensor)
fn = functools.partial(
impl, **kwargs
) # kwargs must be non-tensor for partial
output = fn(*example_inputs)
# Assert single output
assert isinstance(output, torch.Tensor), (
f"Expected single tensor output, got {type(output)}. "
f"Multi-output custom ops not yet supported in autotuning."
)
return FixedLayout(
device=output.device,
dtype=output.dtype,
size=output.shape,
stride=output.stride(),
)

View File

@ -1224,11 +1224,17 @@ class TritonOverrides(OpOverrides):
@staticmethod
def minimum(a, b):
return f"triton_helpers.minimum({a}, {b})"
if torch.version.hip:
return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
else:
return f"triton_helpers.minimum({a}, {b})"
@staticmethod
def maximum(a, b):
return f"triton_helpers.maximum({a}, {b})"
if torch.version.hip:
return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
else:
return f"triton_helpers.maximum({a}, {b})"
@staticmethod
def where(a, b, c):
@ -1601,7 +1607,10 @@ class TritonOverrides(OpOverrides):
@staticmethod
@maybe_upcast_float32()
def rsqrt(x):
return f"libdevice.rsqrt({x})"
if torch.version.hip:
return f"tl.rsqrt({x})"
else:
return f"libdevice.rsqrt({x})"
@staticmethod
@maybe_upcast_float32()
@ -4504,8 +4513,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
loop_end = (
"rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
)
num_stages = ", num_stages = 2" if torch.version.hip else ""
self.body.writeline(
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
)
with self.body.indent(offset=level + 1):
self.iteration_ranges_codegen_header(tree, self.body)

View File

@ -415,6 +415,19 @@ class CommentLine(WrapperLine):
return converter._generate_comment
@dataclasses.dataclass
class DynamicScalarLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.DynamicScalar
def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._codegen_dynamic_scalar(self.node)
@staticmethod
def codegen_fx(converter: FxConverter) -> FxConversionFunc:
return converter._generate_dynamic_scalar
@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
wrapper: PythonWrapperCodegen
@ -2060,6 +2073,9 @@ class PythonWrapperCodegen(CodeGen):
self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))
def codegen_dynamic_scalar(self, node):
self.writeline(DynamicScalarLine(self, node))
def _codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
if len(node.keypath) == 0:
self.writeline(f"{node.sym} = {data}.item()")

View File

@ -29,6 +29,7 @@ from torch._library.triton import wrap_triton
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import (
CallMethodKey,
ConvertIntKey,
DivideByKey,
free_unbacked_symbols,
)
@ -54,6 +55,7 @@ from .wrapper import (
CommBufferFreeLine,
CommentLine,
ConditionalLine,
DynamicScalarLine,
EnterDeviceContextManagerLine,
EnterSubgraphLine,
ExitDeviceContextManagerLine,
@ -738,6 +740,39 @@ class FxConverter:
assert isinstance(line, CommentLine)
# We ignore comments in FX IR.
def _generate_dynamic_scalar(self, line: WrapperLine) -> None:
assert isinstance(line, DynamicScalarLine)
ir_node = line.node
(input_ir_node,) = ir_node.inputs
assert isinstance(input_ir_node, ir.IRNode)
input_fx_node = self._generate_buffer(input_ir_node)
keypath = ir_node.keypath
graph = self.gm.graph
def generate_item(x: Optional[torch.fx.Node]) -> torch.fx.Node:
assert x is not None
return graph.call_function(
aten.item.default,
args=(x,),
)
if len(keypath) == 0:
result_fx_node = generate_item(input_fx_node)
elif len(keypath) == 1 and isinstance(keypath[0], ConvertIntKey):
where_fx_node = graph.call_function(
aten.where.Scalar,
args=(input_fx_node, 1, 0),
)
result_fx_node = generate_item(where_fx_node)
else:
raise NotImplementedError(f"Unsupported keypath: {keypath}")
result_symbol = ir_node.sym
result_buffer = SymbolBuffer(result_symbol)
self._record_allocation(result_buffer, result_fx_node)
self._generate_size_proxy(result_fx_node, result_symbol)
def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
assert isinstance(line, EnterDeviceContextManagerLine)
# We ignore the device context in FX IR.

View File

@ -1444,7 +1444,7 @@ class triton:
# So far we see a fixed 8 spilled registers for kernels using sin/cos.
# Raise the threshold to 16 to be safe.
# We should revisit this once we understand more of the source of register spills.
spill_threshold: int = 16
spill_threshold: int = 32 if torch.version.hip else 16
# Generate code containing the newer tl.make_block_ptr() API for loads/store
use_block_ptr = False

View File

@ -1,397 +0,0 @@
# Owner(s): ["module: inductor"]
import functools
from typing import Any, Callable, Optional, Union
import torch
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
from torch._inductor.select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
)
from torch._inductor.virtualized import V
class CustomOpConfig:
"""Config for custom op autotuning - similar to triton.Config.
Specifies decomposition function with parameter values.
Each config creates exactly one variant (no Cartesian product).
Args:
decomposition: Function to autotune
**params: Parameters passed to the function
Examples:
CustomOpConfig(attention_impl, head_dim=32, method='chunked')
CustomOpConfig(fallback_impl)
"""
def __init__(self, decomposition: Callable[..., Any], **params: Any):
if not callable(decomposition):
raise TypeError(
f"decomposition must be callable, got {type(decomposition)}"
)
self.decomposition = decomposition
self.params = params
# Generate descriptive name
if self.params:
param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(self.params.items()))
self.name = f"{decomposition.__name__}_{param_suffix}"
else:
self.name = decomposition.__name__
def create_variant(self) -> Callable[..., Any]:
"""Create callable with parameters pre-applied using functools.partial."""
if self.params:
variant = functools.partial(self.decomposition, **self.params)
variant.__name__ = self.name # type: ignore[attr-defined]
return variant
return self.decomposition
def __repr__(self) -> str:
if self.params:
params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
return f"CustomOpConfig({self.decomposition.__name__}, {params_str})"
return f"CustomOpConfig({self.decomposition.__name__})"
__all__ = [
"autotune_custom_op",
"register_custom_op_autotuning",
"CustomOpConfig",
]
def _extract_tensor_inputs(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Extract tensor inputs from mixed args/kwargs.
Separates tensors (for autotuning input_nodes) from non-tensor parameters.
Non-tensor kwargs are later functools.partial'd into decomposition functions.
Args:
args: Positional arguments (mix of tensors and scalars)
kwargs: Keyword arguments (mix of tensors and scalars)
Returns:
Tuple of (tensor_inputs_list, non_tensor_kwargs)
"""
tensor_inputs = []
non_tensor_kwargs = {}
# Process args and kwargs: separate tensor inputs and non tensor args
for i, arg in enumerate(args):
if isinstance(arg, (TensorBox, Buffer)):
tensor_inputs.append(arg)
else:
# Add non-tensor positional args to kwargs with generated names
non_tensor_kwargs[f"arg_{i}"] = arg
for key, value in kwargs.items():
if isinstance(value, (TensorBox, Buffer)):
tensor_inputs.append(value)
else:
non_tensor_kwargs[key] = value
return tensor_inputs, non_tensor_kwargs
def _create_user_input_gen_fns(
inputs: list[Any],
arg_names: list[str],
user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
) -> dict[int, Callable[[Any], torch.Tensor]]:
"""Convert user input generators from name-based to index-based format.
Inductor autotune's input_gen_fns expects index of arg_names as key.
Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
"""
from torch._inductor import config
name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based_fns = {}
for name, gen_fn in user_input_gen_fns.items():
if name in name_to_index:
index_based_fns[name_to_index[name]] = gen_fn
else:
print(f"Warning: Unknown argument name '{name}' in input_gen_fns")
def create_internal_input_gen_fn(
user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
) -> Callable[[Any], torch.Tensor]:
"""Create internal input generator that converts IR buffer to user's fake tensor."""
def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
raw_shape = ir_buffer.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
fake_tensor = torch.empty(
concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
)
return user_function(fake_tensor)
return internal_input_gen_fn
return {
i: create_internal_input_gen_fn(
user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
)
for i, user_gen_fn in index_based_fns.items()
if i < len(inputs)
}
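A user-side generator receives a fake (meta) tensor describing the benchmark input and returns a real tensor. A sketch, assuming a decomposition whose first parameter is named "query":

def gen_query(fake: torch.Tensor) -> torch.Tensor:
    # fake carries only shape/dtype; allocate real data for benchmarking
    return torch.randn_like(fake, device="cuda")

# {"query": gen_query} is remapped internally to {0: <wrapper>} because
# "query" is parameter index 0 in the decomposition's signature.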
def _create_fallback_choice(
name: str,
default_impl: Callable[..., Any],
fake_output: torch.Tensor,
kwargs: dict[str, Any],
) -> ExternKernelChoice:
"""Create fallback choice for default implementation."""
def fallback_wrapper(*args: Any) -> Any:
return default_impl(*args, **kwargs)
return ExternKernelChoice(
kernel=fallback_wrapper,
name=f"{name}_fallback_default",
has_out_variant=False,
op_overload=default_impl,
use_fallback_kernel=True,
)
def _create_parameter_variants(
decompositions: list[Callable[..., Any]],
tuning_knob: dict[str, list[Any]],
) -> list[Any]: # Returns partial objects which are callable
"""Create parameter variants for decompositions using tuning knob.
Args:
decompositions: Base implementation functions
tuning_knob: Parameter tuning dict with parameter names and value lists
Returns:
List of variant functions with all parameter combinations
"""
# Validate parameter values
for param_name, param_values in tuning_knob.items():
if not param_values or not isinstance(param_values, (list, tuple)):
raise TypeError(
f"Parameter values for '{param_name}' must be a list or tuple, got {type(param_values)}"
)
# Generate all combinations of parameter values using Cartesian product
import itertools
param_names = list(tuning_knob.keys())
param_values_lists = list(tuning_knob.values())
param_combinations = list(itertools.product(*param_values_lists))
# Create variants for each decomposition with each parameter combination
variants = []
for decomp_fn in decompositions:
for param_combo in param_combinations:
# Create kwargs dict for this combination
param_kwargs = dict(zip(param_names, param_combo))
# Create partial function with all parameters
variant = functools.partial(decomp_fn, **param_kwargs)
param_suffix = "_".join(
f"{name}_{value}" for name, value in param_kwargs.items()
)
variant.__name__ = f"{decomp_fn.__name__}_{param_suffix}" # type: ignore[attr-defined]
variants.append(variant)
return variants
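For illustration (reusing the hypothetical my_attention from above), a two-knob sweep expands into the full Cartesian product:

variants = _create_parameter_variants(
    decompositions=[my_attention],  # hypothetical impl
    tuning_knob={"head_dim": [32, 64], "method": ["chunked", "tiled"]},
)
# len(variants) == 4, with names:
#   my_attention_head_dim_32_method_chunked, my_attention_head_dim_32_method_tiled,
#   my_attention_head_dim_64_method_chunked, my_attention_head_dim_64_method_tiled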
def autotune_custom_op(
name: str,
decompositions: list[Callable[..., Any]],
inputs: list[Any],
kwargs: Optional[dict[str, Any]] = None,
default_impl: Optional[Callable[..., Any]] = None,
user_input_gen_fns: Optional[
dict[str, Callable[[torch.Tensor], torch.Tensor]]
] = None,
) -> Union[TensorBox, Any]:
"""Autotune custom operations by comparing multiple decomposition implementations.
Currently supports SINGLE OUTPUT custom ops only.
TODO: Add support for multiple output custom ops (tuple/list returns).
This function generates multiple implementation choices for a custom operation and
uses Inductor's autotuning system to select the best performing variant at runtime.
Args:
name: Unique identifier for the autotuning operation
decompositions: List of alternative implementation functions to benchmark
inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
kwargs: Non-tensor parameters to pass to decomposition functions
default_impl: Original custom op implementation used as fallback
user_input_gen_fns: Optional custom input generators for benchmarking.
Maps argument names to functions that take fake tensors
and return real tensors for performance measurement.
Returns:
IR node representing the optimized operation result
Raises:
TypeError: If decompositions is not a list/tuple
RuntimeError: If no inputs or no valid choices generated
"""
if kwargs is None:
kwargs = {}
if not isinstance(decompositions, (list, tuple)):
raise TypeError(
f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
)
if not inputs:
raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")
template = SubgraphTemplate(name=name)
choices = template.generate_custom_op_choices(
name=name,
decompositions=list(decompositions),
input_nodes=list(inputs),
kwargs=kwargs,
)
# Add default implementation as fallback
if default_impl and hasattr(default_impl, "_op"):
fallback_name = f"{name}_fallback_default"
from torch._inductor.select_algorithm import extern_kernels
# Skip if extern_kernel already registered to avoid duplicate registration error
if not hasattr(extern_kernels, fallback_name):
with V.fake_mode:
fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
fake_output = default_impl(*fake_inputs, **kwargs)
fallback_choice = _create_fallback_choice(
name, default_impl, fake_output, kwargs
)
fallback_choice.maybe_append_choice(
choices=choices,
input_nodes=list(inputs),
layout=FixedLayout(
device=fake_output.device,
dtype=fake_output.dtype,
size=fake_output.shape,
stride=fake_output.stride(),
),
)
if not choices:
raise RuntimeError(f"No valid choices generated for {name}")
# Convert user input generation functions to internal format
input_gen_fns = {}
if user_input_gen_fns:
import inspect
arg_names = (
list(inspect.signature(decompositions[0]).parameters.keys())
if decompositions
else []
)
input_gen_fns = _create_user_input_gen_fns(
inputs, arg_names, user_input_gen_fns
)
return autotune_select_algorithm(
name=name,
choices=choices,
input_nodes=list(inputs),
layout=choices[0].layout,
input_gen_fns=input_gen_fns,
)
def register_custom_op_autotuning(
custom_op: torch._ops.OpOverload,
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
name: Optional[str] = None,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
) -> None:
"""Register custom op for autotuning with explicit configs.
Uses a config-based API where each config specifies a decomposition function
together with its parameter values.
Args:
custom_op: Custom operation to register
configs: List of CustomOpConfig objects or callable functions
name: Operation name (default: "{op_name}_autotuned")
input_gen_fns: Custom input generators for benchmarking
Examples:
register_custom_op_autotuning(
torch.ops.mylib.attention.default,
configs=[
CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
CustomOpConfig(fallback_impl), # No params
],
input_gen_fns={
"query": lambda fake: torch.randn_like(fake, device='cuda'),
"key": lambda fake: torch.randn_like(fake, device='cuda'),
"value": lambda fake: torch.randn_like(fake, device='cuda'),
}
)
"""
if not isinstance(configs, (list, tuple)):
raise TypeError(f"configs must be a list or tuple, got {type(configs)}")
if not configs:
raise ValueError("At least one config must be provided")
# Convert configs to decomposition functions
final_decompositions = []
for config in configs:
if isinstance(config, CustomOpConfig):
# CustomOpConfig object
final_decompositions.append(config.create_variant())
elif callable(config):
# Direct callable function
final_decompositions.append(config)
else:
raise TypeError(
f"Each config must be a CustomOpConfig object or callable function, "
f"got {type(config)}"
)
if name is None:
name = f"{custom_op._name}_autotuned"
@functools.wraps(custom_op)
def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
"""Inductor lowering function that replaces custom op calls with autotuned versions."""
# Extract tensor inputs and non-tensor parameters
tensor_inputs, non_tensor_kwargs = _extract_tensor_inputs(args, kwargs)
result = autotune_custom_op(
name=name,
decompositions=final_decompositions,
inputs=tensor_inputs,
kwargs=non_tensor_kwargs,
default_impl=custom_op,
user_input_gen_fns=input_gen_fns,
)
validate_ir(result)
return result
lowerings[custom_op] = autotuning_lowering
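End to end, registration looks roughly like the sketch below. It is a hedged illustration: mylib.attention and both implementations are hypothetical, the op is assumed to be registered with torch.library beforehand, and CustomOpConfig / register_custom_op_autotuning come from this module.

def attention_chunked(query, key, value, chunk_size=256):
    ...  # hypothetical decomposition

def attention_tiled(query, key, value, tile_size=64):
    ...  # hypothetical decomposition

register_custom_op_autotuning(
    torch.ops.mylib.attention.default,  # hypothetical custom op
    configs=[
        CustomOpConfig(attention_chunked, chunk_size=128),
        CustomOpConfig(attention_chunked, chunk_size=256),
        CustomOpConfig(attention_tiled),
    ],
    input_gen_fns={
        "query": lambda fake: torch.randn_like(fake, device="cuda"),
    },
)
# During torch.compile, calls to mylib.attention go through autotuning_lowering,
# each config is benchmarked, and the fastest variant (or the original op as
# fallback) is selected.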

View File

@ -860,7 +860,7 @@ class CachingAutotuner(KernelInterface):
# for some (complicated) custom Triton kernels, a register-spilling
# config may yield the best latency.
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
"spill_threshold", 16
"spill_threshold", 32 if torch.version.hip else 16
):
log.debug(
"Skip config %s because of register spilling: %d",
@ -2393,6 +2393,7 @@ def triton_config_reduction(
num_stages=1,
num_warps=None,
register_intensive=False,
waves_per_eu=None,
dynamic_scale_rblock=True,
reduction_hint=None,
) -> Config:
@ -2446,13 +2447,19 @@ def triton_config_reduction(
cfg = _get_config({"x": x, **rnumels})
check_max_block(cfg)
check_config(cfg, xnumel=size_hints["x"])
return InductorConfig(
config = InductorConfig(
cfg,
num_warps=num_warps,
num_stages=num_stages,
dynamic_scale_rblock=dynamic_scale_rblock,
)
if torch.version.hip:
if waves_per_eu is not None:
config.kwargs["waves_per_eu"] = waves_per_eu
return config
def _get_config(numels: dict[str, int]) -> dict[str, int]:
"""
@ -2463,7 +2470,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
def triton_config_tiled_reduction(
size_hints, x, y, r, num_stages=1, register_intensive=False
size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
):
"""
Construct a tile reduction triton config with some adjustment
@ -2500,7 +2507,11 @@ def triton_config_tiled_reduction(
)
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
check_max_block(cfg)
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
if torch.version.hip:
if waves_per_eu is not None:
config.kwargs["waves_per_eu"] = waves_per_eu
return config
def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
@ -2748,6 +2759,11 @@ def _reduction_configs(
# Convert reductions to 1D, to simplify heuristics.
rnumel = get_total_reduction_numel(size_hints)
# Is max autotune enabled
max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
"max_autotune_pointwise"
)
register_intensive = False
MAX_R0_BLOCK = 2048
loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
@ -2790,6 +2806,7 @@ def _reduction_configs(
num_stages=1,
register_intensive=False,
dynamic_scale_rblock=True,
waves_per_eu=None,
):
# For 3D case with tiling scores, create an adapted version
if "y" in size_hints:
@ -2802,6 +2819,7 @@ def _reduction_configs(
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu,
)
else:
# For other cases, use the original function
@ -2812,6 +2830,7 @@ def _reduction_configs(
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu,
dynamic_scale_rblock=dynamic_scale_rblock,
reduction_hint=reduction_hint,
)
@ -2893,12 +2912,12 @@ def _reduction_configs(
)
configs.append(c)
result_configs = []
# For 3d tiling, default to more autotuning initially
if "y" in size_hints:
pass
elif inductor_meta.get("max_autotune") or inductor_meta.get(
"max_autotune_pointwise"
):
elif max_autotune_enabled:
pass # skip all these cases
elif reduction_hint == ReductionHint.INNER:
return configs + [contiguous_config]
@ -2907,7 +2926,10 @@ def _reduction_configs(
elif reduction_hint == ReductionHint.OUTER_TINY:
return configs + [tiny_config]
return configs + [
# We continue here under the following conditions:
# - max_autotune_enabled is True
# - max_autotune_enabled is False and reduction_hint is NOT one of the above cases
result_configs = configs + [
contiguous_config,
outer_config,
tiny_config,
@ -2919,6 +2941,16 @@ def _reduction_configs(
make_config(64, 4, num_warps=8),
]
if torch.version.hip:
result_configs.extend(
[
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
]
)
return result_configs
def match_target_block_product(
size_hints, tiling_scores, target_block_product, min_block_size=1
@ -2975,6 +3007,7 @@ def adapt_config_for_tiling(
num_stages=1,
register_intensive=False,
persistent_reduction=False,
waves_per_eu=None,
) -> Config:
"""
Create an adapted configuration based on tiling scores,
@ -2993,6 +3026,7 @@ def adapt_config_for_tiling(
block_sizes["r0_"],
num_stages=num_stages,
register_intensive=register_intensive,
waves_per_eu=waves_per_eu,
)
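The ROCm-specific plumbing added above boils down to attaching waves_per_eu as an extra kwarg on the Triton config when running on HIP. A minimal sketch, assuming Triton's triton.Config and hypothetical block-size keys:

import torch
from triton import Config

def make_rocm_config(block_sizes, num_warps, num_stages, waves_per_eu=None):
    cfg = Config(block_sizes, num_warps=num_warps, num_stages=num_stages)
    # Only meaningful on AMD GPUs; forwarded to the kernel as a launch/compile hint.
    if torch.version.hip and waves_per_eu is not None:
        cfg.kwargs["waves_per_eu"] = waves_per_eu
    return cfg

cfg = make_rocm_config({"XBLOCK": 1024, "R0_BLOCK": 8}, num_warps=4, num_stages=1, waves_per_eu=2)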

View File

@ -2919,6 +2919,7 @@ class AlgorithmSelectorCache(PersistentCache):
)
timings = do_autotuning(choices, precompile_fn)
# if timings is empty, we really have no choice but to return a semi-random
# choice. returning the first `ExternKernelCaller` is probably the safest bet
# in this case, since it will generally be the ATen kernel. if there are no
@ -3523,7 +3524,6 @@ class AlgorithmSelectorCache(PersistentCache):
dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes])
if config.autotune_num_choices_displayed == 0:
return
# when autotune_num_choices_displayed is None, [:None] means all
n = config.autotune_num_choices_displayed
top_k = sorted(timings, key=timings.__getitem__)[:n]

View File

@ -264,7 +264,8 @@ class OutputComparisonLogger(OutputLogger):
# fmt: on
if not self.enabled:
return x
assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
if not isinstance(x, torch.Tensor):
raise AssertionError("non-tensor inputs not yet supported")
if self.save_activations:
# save the activation, for debugging
self.stats.append(x.detach())
@ -595,9 +596,8 @@ def _extract_logger_info_one_model(
key = mod.ref_name
if key not in results:
results[key] = {}
assert mod.model_name not in results[key], (
f"{mod.model_name} is already present in results"
)
if mod.model_name in results[key]:
raise AssertionError(f"{mod.model_name} is already present in results")
if mod.results_type not in results[key]:
results[key][mod.results_type] = {}
if mod.model_name not in results[key][mod.results_type]:
@ -809,12 +809,10 @@ def extend_logger_results_with_comparison(
"""
for results_type_to_results in results.values():
for model_name_to_results in results_type_to_results.values():
assert model_name_1 in model_name_to_results, (
f"{model_name_1} not found in results"
)
assert model_name_2 in model_name_to_results, (
f"{model_name_2} not found in results"
)
if model_name_1 not in model_name_to_results:
raise AssertionError(f"{model_name_1} not found in results")
if model_name_2 not in model_name_to_results:
raise AssertionError(f"{model_name_2} not found in results")
results_1 = model_name_to_results[model_name_1]
results_2 = model_name_to_results[model_name_2]
@ -832,7 +830,8 @@ def extend_logger_results_with_comparison(
):
result_1 = cur_result_1
break
assert result_1 is not None
if result_1 is None:
raise AssertionError("Expected result_1 to be not None")
values_1 = result_1["values"]
values_2 = result_2["values"]

View File

@ -150,7 +150,8 @@ class _NSGraphMatchableSubgraphsIterator:
if node.op == "call_function":
return node.target not in self.non_matchable_functions
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
target_mod = getattr_from_fqn(self.gm, node.target)
return not any(
isinstance(target_mod, t) # type: ignore[arg-type]
@ -228,16 +229,19 @@ def _get_subgraph_relationship_type(
else:
return SubgraphTypeRelationship.NOT_RELATED
elif node_a.op == "call_module":
assert (
subgraph_a.base_op_node == subgraph_a.start_node
and subgraph_b.base_op_node == subgraph_b.start_node
), (
"Matching call_module patterns where base_op_node != start_node is not supported yet"
)
if (
subgraph_a.base_op_node != subgraph_a.start_node
or subgraph_b.base_op_node != subgraph_b.start_node
):
raise AssertionError(
"Matching call_module patterns where base_op_node != start_node is not supported yet"
)
# for call_module, we need to look up the modules to do the type check
assert isinstance(node_a.target, str)
if not isinstance(node_a.target, str):
raise AssertionError(f"Expected str, got {type(node_a.target)}")
mod_a = getattr_from_fqn(gm_a, node_a.target)
assert isinstance(node_b.target, str)
if not isinstance(node_b.target, str):
raise AssertionError(f"Expected str, got {type(node_b.target)}")
mod_b = getattr_from_fqn(gm_b, node_b.target)
key = (type(mod_a), type(mod_b))
@ -312,7 +316,8 @@ def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetT
if node.op in ("call_function", "call_method"):
return node.target
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
mod = getattr_from_fqn(gm, node.target)
return type(mod)
return None
@ -452,9 +457,10 @@ of subgraphs, and each pair of subgraphs is related to each other."""
key_name_b = _get_name_for_subgraph(
cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
)
assert key_name_a == key_name_b, (
f"Subgraph names {key_name_a} and {key_name_b} do not match"
)
if key_name_a != key_name_b:
raise AssertionError(
f"Subgraph names {key_name_a} and {key_name_b} do not match"
)
results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
continue
elif cur_subgraph_a is None and cur_subgraph_b is None:

View File

@ -32,7 +32,8 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
# an observer, get the fqn of the node being observed.
node_to_use_for_fqn = node
if node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
module = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(module):
node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
@ -348,7 +349,8 @@ def _insert_dtype_cast_after_node(
new_dtype_cast_name,
)
else:
assert dtype_cast_mod_cls
if not dtype_cast_mod_cls:
raise AssertionError("Expected dtype_cast_mod_cls to be not None")
dtype_cast_mod = dtype_cast_mod_cls()
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
return graph_c.create_node(
@ -373,7 +375,8 @@ def _insert_dtype_cast_after_node(
)
results.append(new_dtype_cast_node)
else:
assert dtype_cast_mod_cls
if not dtype_cast_mod_cls:
raise AssertionError("Expected dtype_cast_mod_cls to be not None")
dtype_cast_mod = dtype_cast_mod_cls()
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
new_dtype_cast_node = graph_c.create_node(
@ -412,10 +415,8 @@ def _copy_node_from_a_to_c(
)
return node_a_copy
elif node_a.op == "call_method":
assert node_a.target in (
"dequantize",
"to",
), f"target {node_a.target} is not implemented"
if node_a.target not in ("dequantize", "to"):
raise AssertionError(f"target {node_a.target} is not implemented")
if node_a.target == "dequantize":
arg_copy = _copy_node_from_a_to_c(
get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
@ -535,7 +536,8 @@ def _insert_copy_of_subgraph_a_after_input_node_c(
"""
TODO(before land): real docblock
"""
assert isinstance(input_node_c, (Node, list))
if not isinstance(input_node_c, (Node, list)):
raise AssertionError(f"Expected Node or list, got {type(input_node_c)}")
# create a sequential list of the subgraphs' nodes from start to end,
# because we need to add the nodes to graph C in non-reverse order
@ -621,7 +623,8 @@ def _insert_copy_of_node_a_after_input_node_c(
if isinstance(input_node_c, Node):
graph_c = input_node_c.graph
else:
assert isinstance(input_node_c, list)
if not isinstance(input_node_c, list):
raise AssertionError(f"Expected list, got {type(input_node_c)}")
graph_c = input_node_c[0].graph
norm_args_kwargs = node_a.normalized_arguments(
@ -645,9 +648,10 @@ def _insert_copy_of_node_a_after_input_node_c(
return arg
elif isinstance(kwarg_val, (list, tuple)):
for el in kwarg_val:
assert not isinstance(el, Node), (
"handling of Node inside list is not implemented"
)
if isinstance(el, Node):
raise AssertionError(
"handling of Node inside list is not implemented"
)
return arg
else:
raise AssertionError(
@ -684,7 +688,8 @@ def _insert_copy_of_node_a_after_input_node_c(
# if target is a module, we point to the module from gm_b
new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
# fetch the corresponding module from gm_a
assert isinstance(node_a.target, str)
if not isinstance(node_a.target, str):
raise AssertionError(f"Expected str, got {type(node_a.target)}")
mod_a = getattr_from_fqn(gm_a, node_a.target)
setattr(gm_b, new_mod_copy_name, mod_a)
node_a_shadows_c = graph_c.create_node(
@ -696,7 +701,8 @@ def _insert_copy_of_node_a_after_input_node_c(
)
return node_a_shadows_c
else:
assert node_a.op in ("call_function", "call_method")
if node_a.op not in ("call_function", "call_method"):
raise AssertionError(f"Unexpected op: {node_a.op}")
node_a_shadows_c = graph_c.create_node(
node_a.op,
node_a.target,
@ -791,7 +797,8 @@ def create_a_shadows_b(
ref_node_type_b,
) = start_node_b_to_matched_subgraph_a_and_name[node_b]
else:
assert node_b_is_end_node
if not node_b_is_end_node:
raise AssertionError("Expected node_b_is_end_node to be not false")
(
subgraph_a,
ref_name,
@ -1001,7 +1008,10 @@ def create_a_shadows_b(
)
input_logger: Union[Node, list[Node]] = dtype_cast_node
else:
assert isinstance(dtype_cast_node, list)
if not isinstance(dtype_cast_node, list):
raise AssertionError(
f"Expected list, got {type(dtype_cast_node)}"
)
new_loggers = []
for dtype_cast_idx, dtype_cast_node_inner in enumerate(
dtype_cast_node
@ -1083,7 +1093,10 @@ def create_a_shadows_b(
input_logger_mod.ref_node_name = cur_node.name
else:
# pyrefly: ignore # unbound-name
assert isinstance(input_logger, list)
if not isinstance(input_logger, list):
raise AssertionError(
f"Expected list, got {type(input_logger)}"
)
# pyrefly: ignore # unbound-name
for input_logger_inner in input_logger:
input_logger_mod = getattr(gm_b, input_logger_inner.name)

View File

@ -144,9 +144,11 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
seen_nodes.add(node_or_tuple)
else:
assert isinstance(node_or_tuple, tuple)
if not isinstance(node_or_tuple, tuple):
raise AssertionError(f"Expected tuple, got {type(node_or_tuple)}")
for node in node_or_tuple:
assert isinstance(node, Node)
if not isinstance(node, Node):
raise AssertionError(f"Expected Node, got {type(node)}")
if node in seen_nodes:
was_seen = True
seen_nodes.add(node)
@ -160,7 +162,10 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
if len(cur_match[1]) == 1:
list_of_nodes = cur_match[1]
else:
assert len(cur_match[1]) == 2
if len(cur_match[1]) != 2:
raise ValueError(
f"Expected cur_match[1] to have length 2, got {len(cur_match[1])}"
)
# either (a, b), or ((a, b), c) or (c, (a, b))
# cannot make any assumptions on order, not clear what the
# _find_matches function is doing to populate this
@ -181,13 +186,12 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
last_node = n
else:
mid_node = n
assert (
first_node is not None
and mid_node is not None
and last_node is not None
)
assert mid_node.args[0] is first_node
assert last_node.args[0] is mid_node
if first_node is None or mid_node is None or last_node is None:
raise AssertionError("Expected all nodes to be non-None")
if mid_node.args[0] is not first_node:
raise AssertionError("Expected mid_node.args[0] to be first_node")
if last_node.args[0] is not mid_node:
raise AssertionError("Expected last_node.args[0] to be mid_node")
return [last_node, mid_node, first_node]
if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
@ -377,7 +381,10 @@ def create_submodule_from_subgraph(
# the current implementation is simplistic and cannot handle
# ops with two or more arguments which need to be passed from
# the previous op, so we assert them out
assert cur_node_orig.target not in BINARY_FUNCTIONS
if cur_node_orig.target in BINARY_FUNCTIONS:
raise AssertionError(
f"Unexpected binary function target: {cur_node_orig.target}"
)
# at this point in the code, cur_node_copy is pointing to the copy
# of the previous node
@ -435,9 +442,10 @@ def create_submodule_from_subgraph(
break
# go to next node
assert len(cur_node_orig.users.keys()) == 1, (
f"{cur_node_orig} has more than 1 users, not supported yet"
)
if len(cur_node_orig.users.keys()) != 1:
raise AssertionError(
f"{cur_node_orig} has more than 1 users, not supported yet"
)
cur_node_orig = next(iter(cur_node_orig.users.keys()))
cur_iteration += 1
if cur_iteration > iteration_limit:
@ -494,7 +502,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
)
attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, logger_mod_orig)
with mt.graph.inserting_after(last_node):
new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
@ -537,9 +546,10 @@ def create_one_transformed_and_logged_copy_of_subgraph(
"prepare_custom_config",
"qconfig_mapping",
]:
assert kwarg_name not in custom_prepare_kwargs, (
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
)
if kwarg_name in custom_prepare_kwargs:
raise AssertionError(
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
)
prepare_kwargs: dict[str, Any] = {
"example_inputs": example_inputs,
"qconfig_mapping": qconfig_mapping,
@ -551,7 +561,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
# attach the wrapper to the model
attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, orig_mod_copy_wrapped)
# add a call to the wrapper module from the parent graph
@ -600,7 +611,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
)
attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
assert not hasattr(mt, attr_name)
if hasattr(mt, attr_name):
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
setattr(mt, attr_name, logger_mod_orig)
with mt.graph.inserting_after(new_node):
logger = mt.graph.call_module(
@ -824,7 +836,8 @@ def create_add_loggers_graph(
):
new_shadow_mod = maybe_shadow_mod
break
assert new_shadow_mod is not None
if new_shadow_mod is None:
raise AssertionError("Expected new_shadow_mod to be non-None")
orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod
@ -850,7 +863,10 @@ def create_add_loggers_graph(
fqn,
)
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
assert not hasattr(model, attr_name)
if hasattr(model, attr_name):
raise AssertionError(
f"Unexpected attribute '{attr_name}' found in {model}"
)
setattr(model, attr_name, logger_mod_orig)
insertion_point = last_node
with model.graph.inserting_after(insertion_point):
@ -887,9 +903,15 @@ def create_add_loggers_graph(
# since now only linear subgraphs are supported, all nodes
# except the last one must have only one user
if cur_node_orig != last_node:
assert len(cur_node_orig.users.keys()) == 1
if len(cur_node_orig.users.keys()) != 1:
raise AssertionError(
f"Expected exactly 1, but got {len(cur_node_orig.users)}"
)
cur_node_orig = next(iter(cur_node_orig.users.keys()))
assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
if cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX):
raise AssertionError(
"cur_node_orig should not start with SHADOW_NODE_NAME_PREFIX"
)
insertion_point = cur_node_copy
# add a comparison logger after last_node's copy
@ -905,7 +927,10 @@ def create_add_loggers_graph(
fqn,
)
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
assert not hasattr(model, attr_name)
if hasattr(model, attr_name):
raise AssertionError(
f"Unexpected attribute '{attr_name}' found in {model}"
)
setattr(model, attr_name, logger_mod_orig)
with model.graph.inserting_after(insertion_point):
logger = model.graph.call_module(
@ -979,7 +1004,8 @@ def create_add_loggers_graph(
return prev_shadow_output
cur_shadow_input = orig_first_node_to_shadow_in_node[first_node]
assert cur_shadow_input is not None
if cur_shadow_input is None:
raise AssertionError("Expected cur_shadow_input to be non-None")
cur_shadow_input.args = tree_map(
maybe_remap_node_to_shadow, cur_shadow_input.args
)
@ -1019,7 +1045,8 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
# we have `w2_0`, and are navigating this subgraph
# to get `_input_scale_1` and `_input_zero_point_1`
assert len(shadow_n.users) == 1
if len(shadow_n.users) != 1:
raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}")
quant_node = next(iter(shadow_n.users.keys()))
new_args: Any = None
if quant_node.target == torch.quantize_per_channel:
@ -1028,7 +1055,10 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
new_args = (scale_val, zp_val, axis, dtype)
else:
assert quant_node.target == torch.quantize_per_tensor
if quant_node.target != torch.quantize_per_tensor:
raise AssertionError(
f"Expected torch.quantize_per_tensor, but got {quant_node.target}"
)
_weight, scale_node, zp_node, dtype = quant_node.args
scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)

View File

@ -167,7 +167,8 @@ def end_node_matches_reversed_fusion(
elif cur_node.op == "call_module":
fusion_el_is_mod = isinstance(cur_fusion_el, type)
if fusion_el_is_mod:
assert isinstance(cur_node.target, str)
if not isinstance(cur_node.target, str):
raise AssertionError(f"Expected str, got {type(cur_node.target)}")
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_el, type):
return False
@ -190,7 +191,10 @@ def end_node_matches_reversed_fusion(
if cur_node.target != cur_fusion_el:
return False
else:
assert isinstance(cur_fusion_el, tuple)
if not isinstance(cur_fusion_el, tuple):
raise AssertionError(
f"Expected tuple, got {type(cur_fusion_el)}"
)
if cur_node.target != cur_fusion_el[0]:
return False
elif len(cur_node.args) < 2:

View File

@ -61,7 +61,8 @@ def get_node_first_input_and_output_type(
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -73,8 +74,11 @@ def get_node_first_input_and_output_type(
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
elif node.op == "call_module":
assert node.op == "call_module"
assert isinstance(node.target, str)
if node.op != "call_module":
raise AssertionError(f"Expected call_module, got '{node.op}'")
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, but got {type(node.target)}")
mod = getattr_from_fqn(gm, node.target)
is_known_fp32_or_int8_input_module = any(
isinstance(mod, target_type) # type: ignore[arg-type]
@ -87,7 +91,8 @@ def get_node_first_input_and_output_type(
# A logger or observer's input and output type is the output
# type of the preceding node.
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -116,7 +121,8 @@ def get_node_first_input_and_output_type(
# So, we look up the output type of the previous node and return that
# as the input type of this node instance.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
if not isinstance(prev_node, Node):
raise AssertionError(f"Expected Node, got {type(prev_node)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -131,7 +137,8 @@ def get_node_first_input_and_output_type(
# as the input type of this node instance. We also look up the target
# of to and return the correct output type.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
if not isinstance(prev_node, Node):
raise AssertionError(f"Expected Node, got {type(prev_node)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -140,15 +147,17 @@ def get_node_first_input_and_output_type(
)
cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
assert cur_node_dtype_target is torch.float16, (
f"{cur_node_dtype_target} handling needs to be added"
)
if cur_node_dtype_target is not torch.float16:
raise AssertionError(
f"{cur_node_dtype_target} handling needs to be added"
)
return (prev_node_output_type, NodeInputOrOutputType.FP16)
elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
if not isinstance(first_arg, Node):
raise AssertionError(f"Expected Node, got {type(first_arg)}")
(
_prev_node_input_type,
prev_node_output_type,
@ -181,8 +190,14 @@ def get_node_input_qparams(
def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
if not isinstance(scale_node, Node):
raise AssertionError(f"Expected Node, got {type(scale_node)}")
if not isinstance(scale_node.target, str):
raise AssertionError(f"Expected str, got {type(scale_node.target)}")
if not isinstance(zp_node, Node):
raise AssertionError(f"Expected Node, got {type(zp_node)}")
if not isinstance(zp_node.target, str):
raise AssertionError(f"Expected str, got {type(zp_node.target)}")
scale_obj = getattr_from_fqn(gm, scale_node.target)
zp_obj = getattr_from_fqn(gm, zp_node.target)
return (scale_obj, zp_obj)
@ -200,7 +215,8 @@ def get_node_input_qparams(
elif prev_node.op == "call_module":
# get type of the module
assert isinstance(prev_node.target, str)
if not isinstance(prev_node.target, str):
raise AssertionError(f"Expected str, got {type(prev_node.target)}")
module_obj = getattr_from_fqn(gm, prev_node.target)
if isinstance(
module_obj,
@ -259,15 +275,24 @@ def return_first_non_observer_node(
if node.op == "call_module":
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
if len(node.args) != 1:
raise AssertionError(
f"Expected node.args to have length 1, got {len(node.args)}"
)
if not isinstance(node.args[0], Node):
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
node = node.args[0]
# code duplication intended, not worth refactoring
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
node_obj = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
if len(node.args) != 1:
raise AssertionError(
f"Expected node.args to have length 1, got {len(node.args)}"
)
if not isinstance(node.args[0], Node):
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
node = node.args[0]
return node
@ -331,7 +356,8 @@ def get_target_type_str(node: Node, gm: GraphModule) -> str:
if node.op in ("call_function", "call_method"):
target_type = torch.typename(node.target)
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
target_mod = getattr_from_fqn(gm, node.target)
target_type = torch.typename(target_mod)
return target_type
@ -365,7 +391,8 @@ def rekey_logger_info_on_node_name_of_model(
for model_name_to_results in result_type_to_results.values():
for cur_model_name, list_of_results in model_name_to_results.items():
if cur_model_name == model_name:
assert len(list_of_results)
if len(list_of_results) == 0:
raise AssertionError("Expected list_of_results to be not empty")
new_layer_name = list_of_results[0]["ref_node_name"]
else:
continue
@ -519,14 +546,20 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
)
if norm_args_and_kwargs is not None:
norm_args, norm_kwargs = norm_args_and_kwargs
assert len(norm_args) + len(norm_kwargs) > idx
if len(norm_args) + len(norm_kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(norm_args) + len(norm_kwargs)}"
)
if idx < len(norm_args):
return norm_args[idx]
else:
# note: in Python 3.7+ dicts are ordered
return list(norm_kwargs.values())[idx]
else:
assert len(node.args) + len(node.kwargs) > idx
if len(node.args) + len(node.kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
)
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:
@ -536,7 +569,10 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
# this RuntimeError happens when node argument normalization
# requires typehints to proceed, such as for torch.add where
# either the first, second or both arguments could be tensors
assert len(node.args) + len(node.kwargs) > idx
if len(node.args) + len(node.kwargs) <= idx:
raise AssertionError(
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
) from None
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:

View File

@ -77,7 +77,8 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
res.append(param_value)
return res
else:
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
if not isinstance(mod, nnqd.LSTM):
raise AssertionError(f"type {type(mod)} not handled yet")
res = []
for weight_value in mod._all_weight_values:
res.append(
@ -92,10 +93,13 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# traverse backwards from the weight arg, accounting for any observers
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
if not isinstance(weight_arg_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
weight_node = return_first_non_observer_node(weight_arg_node, gm)
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()
@ -103,8 +107,10 @@ def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# qconv state is arg 1
qconv_state_node = node.args[1]
assert isinstance(qconv_state_node, Node)
assert qconv_state_node.op == "get_attr"
if not isinstance(qconv_state_node, Node):
raise AssertionError(f"Expected Node, got {type(qconv_state_node)}")
if qconv_state_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {qconv_state_node.op}")
qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type]
return qconv_state_obj.weight()
@ -115,34 +121,44 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# weight -> obs -> linear
# weight -> to(torch.float16) -> dequantize -> linear
linear_second_arg = node.args[1]
assert isinstance(linear_second_arg, Node)
if not isinstance(linear_second_arg, Node):
raise AssertionError(f"Expected Node, got {type(linear_second_arg)}")
if linear_second_arg.op == "call_module":
# weight -> obs -> linear
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
if not isinstance(weight_arg_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
weight_node = weight_arg_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()
elif linear_second_arg.op == "call_method":
# weight -> to(torch.float16) -> dequantize -> linear
assert linear_second_arg.op == "call_method"
if linear_second_arg.op != "call_method":
raise AssertionError(f"Expected call_method, got {linear_second_arg.op}")
dequant_node = node.args[1]
assert isinstance(dequant_node, Node)
if not isinstance(dequant_node, Node):
raise AssertionError(f"Expected Node, got {type(dequant_node)}")
to_fp16_node = dequant_node.args[0]
assert isinstance(to_fp16_node, Node)
if not isinstance(to_fp16_node, Node):
raise AssertionError(f"Expected Node, got {type(to_fp16_node)}")
# extract the dtype, so we can cast to it before returning
target_dtype = to_fp16_node.args[1]
weight_node = to_fp16_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
if not isinstance(weight_node, Node):
raise AssertionError(f"Expected Node, got {type(weight_node)}")
if weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {weight_node.op}")
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
# return the weight with fp16 cast
return weight.detach().to(target_dtype)
else:
assert linear_second_arg.op == "get_attr"
if linear_second_arg.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {linear_second_arg.op}")
weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type]
return weight.detach()
@ -150,8 +166,10 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# packed weight is arg 1
packed_weight_node = node.args[1]
assert isinstance(packed_weight_node, Node)
assert packed_weight_node.op == "get_attr"
if not isinstance(packed_weight_node, Node):
raise AssertionError(f"Expected Node, got {type(packed_weight_node)}")
if packed_weight_node.op != "get_attr":
raise AssertionError(f"Expected get_attr, got {packed_weight_node.op}")
packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore[arg-type]
# TODO(future PR): why does packed_weight.unpack() not work?
(weight, _bias), _name = packed_weight.__getstate__()
@ -264,7 +282,8 @@ def extract_weight_from_node(
elif node.op == "call_module":
# for call_module, we need to look up the modules to do the type check
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError(f"Expected str, got {type(node.target)}")
mod = getattr_from_fqn(gm, node.target)
module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
for target_mod_type, weight_extraction_fn in module_mapping.items():

View File

@ -3259,7 +3259,7 @@ struct to_ir {
case TK_IN:
return aten::__contains__;
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}
@ -3306,7 +3306,7 @@ struct to_ir {
case TK_RSHIFT:
return "__rshift__";
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}
@ -4120,8 +4120,7 @@ struct to_ir {
} else if (kind == aten::ge) {
return aten::le;
}
throw std::runtime_error(
"reverseComparision: unsupported NodeKind. File a bug");
TORCH_CHECK(false, "reverseComparision: unsupported NodeKind. File a bug");
}
// any expression that can produce a SugaredValue is handled here

View File

@ -94,7 +94,7 @@ C10_EXPORT std::string kindToString(int kind) {
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("Unknown kind: " + std::to_string(kind));
TORCH_CHECK(false, "Unknown kind: ", kind);
}
}

View File

@ -167,12 +167,12 @@ Value* TracingState::getValue(const IValue& var) {
// Didn't find it. Bake in a constant
if (ten.requires_grad()) {
pauseTracing();
std::ostringstream oss;
oss << "Cannot insert a Tensor that requires grad as a constant. "
<< "Consider making it a parameter or input, or detaching the gradient\n"
<< "Tensor:\n"
<< ten;
throw std::runtime_error(oss.str());
TORCH_CHECK(
false,
"Cannot insert a Tensor that requires grad as a constant. ",
"Consider making it a parameter or input, or detaching the gradient\n",
"Tensor:\n",
ten);
}
Value* constant = graph->insertConstant(ten);
@ -208,15 +208,19 @@ Value* TracingState::getValue(const IValue& var) {
}
}
std::ostringstream oss;
if (var.isFuture()) {
oss << "Tried to trace Future or Object that the tracer was not aware of.";
TORCH_CHECK(
false,
"Tried to trace Future or Object that the tracer was not aware of.");
} else {
oss << "Tried to trace " << var
<< " but it is not part of the active trace. Modules that are called during a trace"
<< " must be registered as submodules of the thing being traced.";
TORCH_CHECK(
false,
"Tried to trace ",
var,
" but it is not part of the active trace. Modules that are called during a trace",
" must be registered as submodules of the thing being traced.");
}
throw std::runtime_error(oss.str());
} else {
// If the values are non-tensors, we try to create constants
// and bake those constants into the traced graph
@ -225,11 +229,12 @@ Value* TracingState::getValue(const IValue& var) {
recordSourceLocation(constant.value()->node());
return *constant;
}
std::ostringstream os;
os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
<< "The below value could not be materialized as a constant:\n"
<< var;
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot get value trace for type ",
var.tagKind(),
". The below value could not be materialized as a constant:\n",
var);
}
}
bool TracingState::hasValue(const IValue& var) const {
@ -252,15 +257,14 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
auto& value_map = getTracingState()->env_stack.back();
auto it = value_map.find(iv);
if (it == value_map.end()) {
std::ostringstream os;
os << "output " << i << " (" << var
<< ") of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your "
"program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
it != value_map.end(),
"output ",
i,
" (",
var,
") of traced region did not have observable data dependence with trace inputs; ",
"this probably indicates your program cannot be understood by the tracer.");
return it->second;
} else if (iv.isTensorList()) {
if (tracing_mode_strict) {
@ -281,11 +285,10 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
graph->insertNode(tuple_node);
return tuple_node->output();
} else if (iv.isGenericDict()) {
if (tracing_mode_strict) {
throw std::runtime_error(
"Encountering a dict at the output of the tracer" +
std::string(STRICT_TRACER_MSG));
}
TORCH_CHECK(
!tracing_mode_strict,
"Encountering a dict at the output of the tracer",
STRICT_TRACER_MSG);
auto dict = iv.toGenericDict();
TypePtr key_type = dict.keyType();
TypePtr value_type = dict.valueType();
@ -304,15 +307,15 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
}
}
}
if (!key_type_valid || !value_type_valid) {
std::ostringstream os;
os << "output " << i << " (" << dict << ") of traced region "
<< "cannot be understood by the tracer, only outputs matching"
<< "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
<< "can be a dictionary output of a traced function";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
key_type_valid && value_type_valid,
"output ",
i,
" (",
dict,
") of traced region cannot be understood by the tracer, only outputs matching ",
"dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] ",
"can be a dictionary output of a traced function");
std::vector<Value*> keys;
std::vector<Value*> values;
for (const auto& entry : dict) {
@ -598,10 +601,11 @@ void TracingState::setValue(const IValue& v, Value* value) {
setValue(entry.value(), static_value);
}
} else {
std::ostringstream os;
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
<< "Supported types are tensor, tensor list, and tuple of tensors.";
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot set value trace for type ",
v.tagKind(),
". Supported types are tensor, tensor list, and tuple of tensors.");
}
}
@ -801,11 +805,10 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
recordSourceLocation(info[i]->node());
}
for (jit::Value* v : info) {
if (*v->type() != *jit::IntType::get()) {
throw std::runtime_error(
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
TORCH_CHECK(
*v->type() == *jit::IntType::get(),
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
n->addInput(
g->insertNode(g->createList(jit::IntType::get(), info))->output());

View File

@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/lexer.h>
@ -37,10 +38,10 @@ struct Tree : c10::intrusive_ptr_target {
return true;
}
virtual const SourceRange& range() const {
throw std::runtime_error("is an Atom");
TORCH_CHECK(false, "is an Atom");
}
virtual const std::string& stringValue() const {
throw std::runtime_error("stringValue can only be called on TK_STRING");
TORCH_CHECK(false, "stringValue can only be called on TK_STRING");
}
virtual const TreeList& trees() const {
static const TreeList empty_trees = {};
@ -79,13 +80,16 @@ struct Tree : c10::intrusive_ptr_target {
int lineno,
size_t expected_subtrees,
bool allow_more) const {
if (kind() != k) {
std::stringstream ss;
ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
<< "' but found '" << kindToString(kind()) << "'\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
kind() == k,
filename,
":",
lineno,
": expecting kind '",
kindToString(k),
"' but found '",
kindToString(kind()),
"'\n");
if (trees().size() < expected_subtrees ||
(!allow_more && trees().size() != expected_subtrees)) {
std::stringstream ss;
@ -93,7 +97,7 @@ struct Tree : c10::intrusive_ptr_target {
<< expected_subtrees << " subtrees, but found only " << trees().size()
<< "\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
TORCH_CHECK(false, ss.str());
}
}
~Tree() override = default;

View File

@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
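With the explicit treespec type check dropped, tree_unflatten simply forwards to optree. A usage sketch of the round trip, assuming this hunk is torch/utils/_cxx_pytree.py (it calls optree directly) and that optree is installed:

import torch.utils._cxx_pytree as pytree

leaves, treespec = pytree.tree_flatten({"a": 1, "b": (2, 3)})
# leaves == [1, 2, 3]; treespec records the dict/tuple structure
rebuilt = pytree.tree_unflatten(leaves, treespec)
assert rebuilt == {"a": 1, "b": (2, 3)}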