Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-24 07:27:32 +08:00)

Compare commits: tianren/cu ... codex/add- (8 commits)
| SHA1 |
|---|
| e3d00beddd |
| 21131a2444 |
| 1009790ad8 |
| 410e6a4321 |
| 23c55c5b66 |
| 1290b077f2 |
| 9f9ab881b2 |
| f2bb22ff84 |
.github/actionlint.yaml (vendored), 11 changed lines
@@ -54,17 +54,12 @@ self-hosted-runner:
- windows-11-arm64
- windows-11-arm64-preview
# Organization-wide AMD-hosted runners
# MI2xx non-ARC runners
# MI2xx runners
- linux.rocm.gpu
- linux.rocm.gpu.mi250
- linux.rocm.gpu.2
- linux.rocm.gpu.4
- linux.rocm.gpu.mi250
- linux.rocm.gpu.gfx1100
# MI2xx ARC runners
- linux.rocm.gpu.mi250.1
- linux.rocm.gpu.mi250.2
- linux.rocm.gpu.mi250.4
# gfx942 ARC runners
# gfx942 runners
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
.github/workflows/rocm.yml (vendored), 12 changed lines
@@ -36,12 +36,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
secrets: inherit
@@ -141,6 +141,8 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}

checkTrilTriuMemoryOverlap(result, self);

bool inplace_op = self.is_same(result);

bool inplace_update = false;
@@ -1,3 +1,4 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>

@@ -54,4 +55,13 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}

static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}

} // namespace at::native
@@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -110,6 +111,8 @@ __global__ void triu_tril_kernel(

template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,
@@ -424,7 +424,7 @@ from user code:
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten(d)
optree.tree_flatten_with_path(d)
return torch.sin(x)

fn(torch.randn(4))
@@ -434,10 +434,10 @@ from user code:
first_graph_break,
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py

Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>

For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)
@@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import torch.utils._cxx_pytree as cxx_pytree

pytree_modules["cxx"] = cxx_pytree
pytree_modules["native_optree"] = cxx_pytree.optree
else:
cxx_pytree = None

@@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
@@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
@@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
@@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
@@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
@@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
torch.ones(3, 2),
1,
]
new_tree = pytree.tree_unflatten(new_leaves, treespec)
if pytree.__name__ == "optree":
# `None` is a internal node rather than leaf in default OpTree / JAX PyTree
new_leaves.pop()
# The treespec argument comes first in OpTree / JAX PyTree
new_tree = pytree.tree_unflatten(treespec, new_leaves)
else:
new_tree = pytree.tree_unflatten(new_leaves, treespec)
return leaves, new_tree

x = torch.randn(3, 2)
@@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):

@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):
# OpTree and JAX PyTree do not have `tree_map_only`
return

def fn(xs):
def mapper(x):
return x.clone()
@@ -74,7 +74,14 @@ class TestAsyncCompile(TestCase):
return (a @ b).to(torch.float32).sum(dim=1)

# Fake name to make sure the lookup table is name agnostic
func_def = """
# When codegen/triton.py is changed, func_def must be updated
loop_header = (
"for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"
if torch.version.hip
else "for r0_offset in tl.range(0, r0_numel, R0_BLOCK):"
)

func_def = f"""
def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1024
r0_numel = 11776
@@ -87,7 +94,7 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons
rbase = r0_base
x0 = xindex
_tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
{loop_header}
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
@ -1,572 +0,0 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
"""
|
||||
Tests for custom operation autotuning with PyTorch Inductor.
|
||||
|
||||
Users can register custom ops with multiple decomposition implementations and let
|
||||
Inductor automatically select the best performing variant. Key features tested:
|
||||
|
||||
- Name-based input generators (use argument names instead of indices)
|
||||
- Dynamic shape handling across multiple compilations
|
||||
- Parametric tuning with tuning_knob for combinatorial parameter exploration
|
||||
- Numerical correctness and performance validation
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.kernel.custom_op import (
|
||||
CustomOpConfig,
|
||||
register_custom_op_autotuning,
|
||||
)
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.common_utils import skipIfXpu
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
|
||||
class TestCustomOpAutoTune(TestCase):
|
||||
"""Test custom operation autotuning functionality."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Set up test environment with appropriate device and dtype."""
|
||||
super().setUp()
|
||||
self.device = "cuda" if HAS_GPU else "cpu"
|
||||
self.dtype = torch.float16 if self.device == "cuda" else torch.float32
|
||||
|
||||
def _create_test_configs(self):
|
||||
"""Create common test configurations for different sizes."""
|
||||
return [
|
||||
{"batch_size": 1, "seq_len": 32, "hidden_dim": 128},
|
||||
{"batch_size": 2, "seq_len": 64, "hidden_dim": 256},
|
||||
]
|
||||
|
||||
def _run_autotune_test(self, op_object, inputs, expected, test_name):
|
||||
"""Shared test infrastructure for autotuning tests."""
|
||||
|
||||
@torch.compile
|
||||
def test_model(*args):
|
||||
return op_object(*args)
|
||||
|
||||
torch._dynamo.reset()
|
||||
autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"
|
||||
|
||||
with config.patch(
|
||||
max_autotune=True,
|
||||
max_autotune_gemm_backends=autotune_backends,
|
||||
fx_graph_cache=False,
|
||||
benchmark_kernel=True,
|
||||
):
|
||||
compiled_result = test_model(*inputs)
|
||||
|
||||
self.assertEqual(
|
||||
compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
compiled_result,
|
||||
expected,
|
||||
rtol=2e-1,
|
||||
atol=5e-1,
|
||||
msg=f"{test_name} numerical mismatch",
|
||||
)
|
||||
|
||||
def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
|
||||
"""Utility to assert that all implementations produce equivalent results."""
|
||||
implementations = [(func.__name__, func) for func in decompositions]
|
||||
results = {}
|
||||
for name, impl in implementations:
|
||||
result = impl(*inputs)
|
||||
results[name] = result
|
||||
|
||||
# Basic sanity checks
|
||||
self.assertTrue(
|
||||
torch.isfinite(result).all(),
|
||||
f"{op_name} {name} produced non-finite values",
|
||||
)
|
||||
|
||||
# Verify numerical equivalence
|
||||
reference_name, reference_result = next(iter(results.items()))
|
||||
for name, result in results.items():
|
||||
if name != reference_name:
|
||||
rtol = 1e-1 if "Approximated" in name else 1e-2
|
||||
atol = 1e-1 if "Approximated" in name else 1e-2
|
||||
torch.testing.assert_close(
|
||||
result,
|
||||
reference_result,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"{op_name} {name} differs from {reference_name}",
|
||||
)
|
||||
|
||||
def _create_rmsnorm_inputs(self, batch_size=8, seq_len=1024, hidden_dim=512):
|
||||
"""Create test inputs for RMSNorm operations."""
|
||||
input_tensor = torch.randn(
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_dim,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
weight = torch.randn(
|
||||
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
|
||||
)
|
||||
return input_tensor, weight
|
||||
|
||||
def _create_mlp_inputs(
|
||||
self,
|
||||
batch_size=2,
|
||||
seq_len=32,
|
||||
hidden_dim=512,
|
||||
intermediate_dim=1024,
|
||||
output_dim=256,
|
||||
):
|
||||
"""Create test inputs for MLP operations."""
|
||||
input_tensor = torch.randn(
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_dim,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
gate_weight = torch.randn(
|
||||
hidden_dim,
|
||||
intermediate_dim,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
up_weight = torch.randn(
|
||||
hidden_dim,
|
||||
intermediate_dim,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
down_weight = torch.randn(
|
||||
intermediate_dim,
|
||||
output_dim,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
return input_tensor, gate_weight, up_weight, down_weight
|
||||
|
||||
@skipIfXpu
|
||||
def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
|
||||
"""Test RMSNorm autotuning decomposition variants compared to fallback default with dynamic shapes."""
|
||||
test_op_name = f"test_lib::rmsnorm_{id(self)}"
|
||||
|
||||
def rmsnorm_decomposition1(
|
||||
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
"""Variance-based approach: compute variance then rsqrt."""
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
rstd = torch.rsqrt(variance + eps)
|
||||
return x * rstd * weight
|
||||
|
||||
def rmsnorm_decomposition2(
|
||||
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
"""vLLM-style RMSNorm implementation - variance computation first approach."""
|
||||
x_var = x # In vLLM, this could be sliced for variance_size_override
|
||||
|
||||
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
||||
|
||||
x = x * torch.rsqrt(variance + eps)
|
||||
|
||||
if weight is not None:
|
||||
x = x * weight
|
||||
return x
|
||||
|
||||
def rmsnorm_decomposition3(
|
||||
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
"""vLLM-style RMSNorm with extended variance computation pattern."""
|
||||
x_squared = x.pow(2)
|
||||
variance = x_squared.mean(dim=-1, keepdim=True)
|
||||
|
||||
rstd = torch.rsqrt(variance + eps)
|
||||
normalized = x * rstd
|
||||
|
||||
# Apply weight scaling
|
||||
if weight is not None:
|
||||
normalized = normalized * weight
|
||||
return normalized
|
||||
|
||||
@torch.library.custom_op(test_op_name, mutates_args=())
|
||||
def test_rmsnorm_op(
|
||||
input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
return torch.nn.functional.rms_norm(
|
||||
input_tensor, input_tensor.shape[-1:], weight, eps=eps
|
||||
)
|
||||
|
||||
@test_rmsnorm_op.register_fake
|
||||
def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
|
||||
return torch.empty_like(input_tensor)
|
||||
|
||||
lib_name, op_name = test_op_name.split("::")
|
||||
op_object = getattr(getattr(torch.ops, lib_name), op_name)
|
||||
|
||||
decompositions = [
|
||||
rmsnorm_decomposition1,
|
||||
rmsnorm_decomposition2,
|
||||
rmsnorm_decomposition3,
|
||||
]
|
||||
|
||||
register_custom_op_autotuning(
|
||||
op_object.default,
|
||||
configs=[
|
||||
CustomOpConfig(rmsnorm_decomposition1) for decomp in decompositions
|
||||
],
|
||||
name="test_rmsnorm_autotuned",
|
||||
input_gen_fns={
|
||||
"x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
|
||||
"weight": lambda weight: torch.ones_like(weight, device=self.device),
|
||||
},
|
||||
)
|
||||
|
||||
# Test multiple shapes to verify dynamic shape handling
|
||||
test_shapes = [(2, 16, 128), (8, 32, 256)]
|
||||
|
||||
for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
|
||||
input_tensor = torch.randn(
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_dim,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
requires_grad=False,
|
||||
)
|
||||
weight = torch.randn(
|
||||
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
|
||||
)
|
||||
|
||||
# Test numerical equivalence for all decompositions
|
||||
self._assert_implementations_equivalent(
|
||||
decompositions, (input_tensor, weight), f"RMSNorm_{i}"
|
||||
)
|
||||
|
||||
# Test autotuning
|
||||
expected = rmsnorm_decomposition1(input_tensor, weight)
|
||||
self._run_autotune_test(
|
||||
op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
|
||||
)
|
||||
|
||||
@skipIfXpu
|
||||
def test_mlp_custom_op_autotune(self):
|
||||
"""Test MLP autotuning with method parameter controlling different decomposition variants"""
|
||||
test_op_name = f"test_lib::mlp_{id(self)}"
|
||||
|
||||
def mlp_variants(
|
||||
input_tensor: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
method: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""MLP implementation with different computational approaches controlled by method parameter."""
|
||||
|
||||
if method == 0:
|
||||
# Separate matmuls: standard implementation with torch.matmul
|
||||
gate_proj = torch.matmul(input_tensor, gate_weight)
|
||||
up_proj = torch.matmul(input_tensor, up_weight)
|
||||
gated = torch.relu(gate_proj) * up_proj
|
||||
return torch.matmul(gated, down_weight)
|
||||
|
||||
elif method == 1:
|
||||
# Batched approach: uses torch.mm with reshaped tensors
|
||||
batch_shape = input_tensor.shape[:-1]
|
||||
hidden_dim = input_tensor.shape[-1]
|
||||
output_dim = down_weight.shape[-1]
|
||||
|
||||
input_2d = input_tensor.view(-1, hidden_dim)
|
||||
|
||||
gate_proj = torch.mm(input_2d, gate_weight)
|
||||
up_proj = torch.mm(input_2d, up_weight)
|
||||
|
||||
gated = torch.relu(gate_proj) * up_proj
|
||||
output_2d = torch.mm(gated, down_weight)
|
||||
|
||||
return output_2d.view(*batch_shape, output_dim)
|
||||
|
||||
elif method == 2:
|
||||
# Fused weights approach: concatenate then split weights
|
||||
# Concatenate gate and up weights for one matrix multiply
|
||||
fused_weight = torch.cat([gate_weight, up_weight], dim=1)
|
||||
fused_proj = torch.matmul(input_tensor, fused_weight)
|
||||
|
||||
intermediate_dim = gate_weight.shape[1]
|
||||
gate_proj, up_proj = fused_proj.split(
|
||||
[intermediate_dim, intermediate_dim], dim=-1
|
||||
)
|
||||
|
||||
gated = torch.relu(gate_proj) * up_proj
|
||||
|
||||
return torch.matmul(gated, down_weight)
|
||||
|
||||
@torch.library.custom_op(test_op_name, mutates_args=())
|
||||
def test_mlp_op(
|
||||
input_tensor: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return mlp_variants(
|
||||
input_tensor, gate_weight, up_weight, down_weight, method=0
|
||||
)
|
||||
|
||||
@test_mlp_op.register_fake
|
||||
def _(
|
||||
input_tensor: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
method: int = 0,
|
||||
):
|
||||
return torch.empty(
|
||||
input_tensor.shape[:-1] + (down_weight.shape[-1],),
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype,
|
||||
)
|
||||
|
||||
lib_name, op_name = test_op_name.split("::")
|
||||
op_object = getattr(getattr(torch.ops, lib_name), op_name)
|
||||
|
||||
# Use explicit configs with method parameter as tuning knob
|
||||
register_custom_op_autotuning(
|
||||
op_object.default,
|
||||
configs=[
|
||||
CustomOpConfig(mlp_variants, method=1), # Batched approach
|
||||
CustomOpConfig(mlp_variants, method=2), # Fused weights
|
||||
],
|
||||
name="test_mlp_autotuned",
|
||||
input_gen_fns={
|
||||
"input_tensor": lambda fake_tensor: torch.randn_like(
|
||||
fake_tensor, device=self.device
|
||||
)
|
||||
* 0.1,
|
||||
"gate_weight": lambda fake_tensor: torch.randn_like(
|
||||
fake_tensor, device=self.device
|
||||
)
|
||||
* 0.05,
|
||||
"up_weight": lambda fake_tensor: torch.randn_like(
|
||||
fake_tensor, device=self.device
|
||||
)
|
||||
* 0.05,
|
||||
"down_weight": lambda fake_tensor: torch.randn_like(
|
||||
fake_tensor, device=self.device
|
||||
)
|
||||
* 0.05,
|
||||
},
|
||||
)
|
||||
|
||||
# Create test inputs using the original helper method
|
||||
input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()
|
||||
|
||||
# Test that all method variants produce numerically equivalent results
|
||||
expected = mlp_variants(
|
||||
input_tensor, gate_weight, up_weight, down_weight, method=0
|
||||
)
|
||||
|
||||
for method in [1, 2]:
|
||||
result = mlp_variants(
|
||||
input_tensor, gate_weight, up_weight, down_weight, method=method
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
result,
|
||||
expected,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"Method {method} not equivalent to method 0",
|
||||
)
|
||||
|
||||
# Test autotuning - all should be mathematically equivalent
|
||||
self._run_autotune_test(
|
||||
op_object,
|
||||
(input_tensor, gate_weight, up_weight, down_weight),
|
||||
expected,
|
||||
"MLP",
|
||||
)
|
||||
|
||||
def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
|
||||
"""Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
|
||||
# Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
|
||||
k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256
|
||||
a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
|
||||
b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
|
||||
return a, b
|
||||
|
||||
@skipIfXpu
|
||||
def test_decompose_k_custom_op_autotune(self):
|
||||
"""Test decompose_k autotuning with parameter tuning for k_splits values."""
|
||||
test_op_name = f"test_lib::decompose_k_{id(self)}"
|
||||
|
||||
def decompose_k_implementation(
|
||||
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
|
||||
) -> torch.Tensor:
|
||||
"""Matrix multiply with k-way decomposition - parameter-tuned implementation."""
|
||||
m = a.shape[0]
|
||||
n = b.shape[1]
|
||||
k = a.shape[1]
|
||||
|
||||
k_parts = k // k_splits
|
||||
B = k_splits
|
||||
|
||||
a_reshaped = torch.permute(
|
||||
a.reshape(m, B, k_parts), (1, 0, 2)
|
||||
) # [B, m, k_parts]
|
||||
b_reshaped = b.reshape(B, k_parts, n) # [B, k_parts, n]
|
||||
|
||||
result = torch.bmm(a_reshaped, b_reshaped) # [B, m, n]
|
||||
|
||||
return torch.sum(result, dim=0) # [m, n]
|
||||
|
||||
@torch.library.custom_op(test_op_name, mutates_args=())
|
||||
def test_decompose_k_op(
|
||||
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
|
||||
) -> torch.Tensor:
|
||||
return decompose_k_implementation(a, b, k_splits)
|
||||
|
||||
@test_decompose_k_op.register_fake
|
||||
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
|
||||
return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
|
||||
|
||||
lib_name, op_name = test_op_name.split("::")
|
||||
op_object = getattr(getattr(torch.ops, lib_name), op_name)
|
||||
|
||||
# Use parameter tuning to test different k_splits values
|
||||
register_custom_op_autotuning(
|
||||
op_object.default,
|
||||
configs=[
|
||||
CustomOpConfig(decompose_k_implementation, k_splits=2),
|
||||
CustomOpConfig(decompose_k_implementation, k_splits=32),
|
||||
CustomOpConfig(decompose_k_implementation, k_splits=64),
|
||||
CustomOpConfig(decompose_k_implementation, k_splits=128),
|
||||
CustomOpConfig(decompose_k_implementation, k_splits=256),
|
||||
],
|
||||
name="test_decompose_k_autotuned",
|
||||
input_gen_fns={
|
||||
"a": lambda fake_tensor: torch.randn_like(
|
||||
fake_tensor, device=self.device
|
||||
)
|
||||
* 0.1, # Matrix A
|
||||
"b": lambda fake_tensor: torch.randn_like(
|
||||
fake_tensor, device=self.device
|
||||
)
|
||||
* 0.1, # Matrix B
|
||||
},
|
||||
)
|
||||
|
||||
a, b = self._create_decompose_k_inputs()
|
||||
expected = a @ b
|
||||
self._run_autotune_test(op_object, (a, b), expected, "DecomposeK")
|
||||
|
||||
@skipIfXpu
|
||||
def test_multi_parameter_tuning(self):
|
||||
"""Test autotuning with multiple parameters using scale_mode and chunk_size."""
|
||||
op_name = f"test_lib::multi_param_{id(self)}"
|
||||
|
||||
def multi_param_scaling(
|
||||
x: torch.Tensor,
|
||||
factor: torch.Tensor,
|
||||
scale_mode: int = 1,
|
||||
chunk_size: int = 16,
|
||||
) -> torch.Tensor:
|
||||
"""Different scaling approaches controlled by scale_mode parameter."""
|
||||
if scale_mode == 1:
|
||||
# Simple broadcasting
|
||||
return x * factor
|
||||
elif scale_mode == 2:
|
||||
# Process in chunks
|
||||
batch_size, seq_len = x.shape[:2]
|
||||
chunks = []
|
||||
for start in range(0, seq_len, chunk_size):
|
||||
end = min(start + chunk_size, seq_len)
|
||||
chunk = x[:, start:end]
|
||||
chunks.append(chunk * factor)
|
||||
return torch.cat(chunks, dim=1)
|
||||
elif scale_mode == 3:
|
||||
# Using einsum for scaling
|
||||
return torch.einsum("...i,i->...i", x, factor)
|
||||
|
||||
@torch.library.custom_op(op_name, mutates_args=())
|
||||
def multi_param_op(
|
||||
x: torch.Tensor,
|
||||
factor: torch.Tensor,
|
||||
scale_mode: int = 1,
|
||||
chunk_size: int = 16,
|
||||
) -> torch.Tensor:
|
||||
return multi_param_scaling(x, factor, scale_mode, chunk_size)
|
||||
|
||||
@multi_param_op.register_fake
|
||||
def _(
|
||||
x: torch.Tensor,
|
||||
factor: torch.Tensor,
|
||||
scale_mode: int = 1,
|
||||
chunk_size: int = 16,
|
||||
):
|
||||
return torch.empty_like(x)
|
||||
|
||||
lib_name, op_suffix = op_name.split("::")
|
||||
op_object = getattr(getattr(torch.ops, lib_name), op_suffix)
|
||||
|
||||
# Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
|
||||
register_custom_op_autotuning(
|
||||
op_object.default,
|
||||
configs=[
|
||||
CustomOpConfig(multi_param_scaling, scale_mode=1), # Broadcast
|
||||
CustomOpConfig(
|
||||
multi_param_scaling, scale_mode=2, chunk_size=16
|
||||
), # Chunked 16
|
||||
CustomOpConfig(
|
||||
multi_param_scaling, scale_mode=2, chunk_size=32
|
||||
), # Chunked 32
|
||||
CustomOpConfig(multi_param_scaling, scale_mode=3), # Einsum
|
||||
],
|
||||
name="multi_param_autotuned",
|
||||
input_gen_fns={
|
||||
"x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
|
||||
"factor": lambda t: torch.ones(
|
||||
t.shape[-1], device=self.device, dtype=t.dtype
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# Create test inputs
|
||||
test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
|
||||
test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0
|
||||
|
||||
# Verify numerical equivalence across all approaches
|
||||
expected_result = test_x * test_factor
|
||||
|
||||
# Test each scale_mode variant
|
||||
configs = [
|
||||
(1, 16), # broadcast, chunk_size ignored
|
||||
(2, 16), # chunked with size 16
|
||||
(2, 32), # chunked with size 32
|
||||
(3, 16), # einsum, chunk_size ignored
|
||||
]
|
||||
|
||||
for i, (scale_mode, chunk_size) in enumerate(configs):
|
||||
result = multi_param_scaling(
|
||||
test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
result,
|
||||
expected_result,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
|
||||
)
|
||||
|
||||
# Test autotuning
|
||||
self._run_autotune_test(
|
||||
op_object, (test_x, test_factor), expected_result, "MultiParam"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@@ -1034,6 +1034,22 @@ def forward(self, arg0_1, arg1_1, arg2_1):
x = torch.randn(7, device=self.device)
self.check(M(), (x,), dynamic_shapes=({0: Dim.DYNAMIC},))

@parametrize("dynamic", (False, True))
@parametrize("input_", (1.5, 2, False))
def test_item(self, input_, dynamic: bool):
"""
Test calling Tensor.item.
"""

class M(torch.nn.Module):
def forward(self, x):
return x[1].item()

x = torch.tensor((input_,) * 10)
d = Dim("s0", min=1)
dynamic_shapes = ({0: 2 * d},) if dynamic else None
self.check(M(), (x,), dynamic_shapes=dynamic_shapes)

@parametrize("pred", (False, True))
def test_mismatched_branch_dynamic(self, pred: bool):
"""
@@ -14295,8 +14295,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertTrue(torch.all(result < 2560).item())

code_str = "\n".join(code)
if torch.version.hip:
triton_str = "tl.minimum"
else:
triton_str = "triton_helpers.minimum"
self.assertIn(
"triton_helpers.minimum",
triton_str,
code_str,
"Generated Triton code should use triton_helpers.minimum for clamping",
)
@@ -9986,6 +9986,20 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)

@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)

@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):
@@ -6,7 +6,7 @@ from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import TypeIs

import torch.utils._pytree as python_pytree
@@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C

import torch.utils._cxx_pytree as cxx_pytree
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401

if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
del __func
|
||||
del __name
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
|
||||
def tree_is_leaf(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> bool:
|
||||
if tree is None or (is_leaf is not None and is_leaf(tree)):
|
||||
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
|
||||
return True
|
||||
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
|
||||
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
|
||||
return True
|
||||
return False
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
|
||||
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
|
||||
def tree_iter(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> Iterable[Any]:
|
||||
stack = [tree]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if tree_is_leaf(node, is_leaf=is_leaf):
|
||||
if tree_is_leaf(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
):
|
||||
yield node
|
||||
continue
|
||||
|
||||
children, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
stack.extend(reversed(children))
|
||||
|
||||
__all__ += ["tree_iter"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> list[Any]:
|
||||
return list(tree_iter(tree, is_leaf=is_leaf))
|
||||
return list(
|
||||
tree_iter(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
)
|
||||
|
||||
__all__ += ["tree_leaves"]
|
||||
|
||||
@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
_metadata: Any
|
||||
_entries: tuple[Any, ...]
|
||||
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
|
||||
none_is_leaf: bool
|
||||
namespace: str
|
||||
|
||||
num_nodes: int = field(init=False)
|
||||
num_leaves: int = field(init=False)
|
||||
num_children: int = field(init=False)
|
||||
none_is_leaf: Literal[True] = field(init=False)
|
||||
namespace: Literal["torch"] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self._type is None:
|
||||
@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
object.__setattr__(self, "num_nodes", num_nodes)
|
||||
object.__setattr__(self, "num_leaves", num_leaves)
|
||||
object.__setattr__(self, "num_children", num_children)
|
||||
object.__setattr__(self, "none_is_leaf", True)
|
||||
object.__setattr__(self, "namespace", "torch")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def helper(treespec: PyTreeSpec) -> str:
|
||||
@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
]
|
||||
if (
|
||||
treespec.type in BUILTIN_TYPES
|
||||
or (treespec.type is type(None) and not self.none_is_leaf)
|
||||
or optree.is_namedtuple_class(treespec.type)
|
||||
or optree.is_structseq_class(treespec.type)
|
||||
):
|
||||
@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
f"[{', '.join(children_representations)}])"
|
||||
)
|
||||
|
||||
return (
|
||||
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
|
||||
)
|
||||
inner = [
|
||||
str(helper(self)),
|
||||
*(["NoneIsLeaf"] if self.none_is_leaf else []),
|
||||
f"namespace={self.namespace!r}",
|
||||
]
|
||||
return f"PyTreeSpec({', '.join(inner)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_leaves
|
||||
@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
|
||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=self.none_is_leaf,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
if len(children) != treespec.num_children:
|
||||
raise ValueError(
|
||||
@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
# node_type is treespec.type
|
||||
children, metadata, *_ = optree.tree_flatten_one_level(
|
||||
node,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=self.none_is_leaf,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
if (
|
||||
node_type
|
||||
@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
assert callable(self._unflatten_func)
|
||||
return self._unflatten_func(self._metadata, subtrees)
|
||||
|
||||
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
|
||||
|
||||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
||||
return isinstance(obj, PyTreeSpec)
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
cxx_pytree.tree_flatten,
|
||||
optree.tree_flatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_flatten(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[Any], PyTreeSpec]:
|
||||
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
|
||||
if tree_is_leaf(node, is_leaf=is_leaf):
|
||||
if tree_is_leaf(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
):
|
||||
leaves.append(node)
|
||||
return _LEAF_SPEC
|
||||
return PyTreeSpec(
|
||||
(),
|
||||
None,
|
||||
None,
|
||||
(),
|
||||
None,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
(
|
||||
children,
|
||||
@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
) = optree.tree_flatten_one_level(
|
||||
node,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=True,
|
||||
namespace="torch",
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
# Recursively flatten the children
|
||||
subspecs = tuple(helper(child, leaves) for child in children)
|
||||
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
|
||||
return PyTreeSpec(
|
||||
subspecs,
|
||||
type(node),
|
||||
metadata,
|
||||
entries,
|
||||
unflatten_func,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
leaves: list[Any] = []
|
||||
treespec = helper(tree, leaves)
|
||||
@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
__all__ += ["tree_flatten"]
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
cxx_pytree.tree_structure,
|
||||
optree.tree_structure,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_structure(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTreeSpec:
|
||||
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
|
||||
return tree_flatten( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)[1]
|
||||
|
||||
__all__ += ["tree_structure"]
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
cxx_pytree.tree_unflatten,
|
||||
optree.tree_unflatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
|
||||
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
||||
@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
|
||||
__all__ += ["tree_unflatten"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
|
||||
def tree_map(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
return treespec.unflatten(map(func, *flat_args))
|
||||
|
||||
__all__ += ["tree_map"]
|
||||
|
||||
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
|
||||
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
|
||||
def tree_map_(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return tree
|
||||
|
||||
__all__ += ["tree_map_"]
|
||||
|
||||
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr]
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
_none_unflatten,
|
||||
can_constant_fold_through=True,
|
||||
skip_signature_check=True,
|
||||
)
|
||||
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
|
||||
if len(list(children)) != 0:
|
||||
raise ValueError("Expected no children.")
|
||||
return None
|
||||
|
||||
@ -200,9 +200,10 @@ class SuperVariable(VariableTracker):
|
||||
and not (args or kwargs)
|
||||
):
|
||||
with do_not_convert_to_tracable_parameter():
|
||||
return variables.UserFunctionVariable(
|
||||
unpatched_nn_module_init, source=source
|
||||
).call_function(tx, [self.objvar] + args, kwargs)
|
||||
fn_vt = VariableTracker.build(
|
||||
tx, unpatched_nn_module_init, source=source
|
||||
)
|
||||
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
|
||||
else:
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported super().__init__() call",
|
||||
@ -230,9 +231,8 @@ class SuperVariable(VariableTracker):
|
||||
elif isinstance(inner_fn, staticmethod) and isinstance(
|
||||
inner_fn.__func__, types.FunctionType
|
||||
):
|
||||
return variables.UserFunctionVariable(
|
||||
inner_fn.__func__, source=source
|
||||
).call_function(tx, args, kwargs)
|
||||
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
|
||||
return fn_vt.call_function(tx, args, kwargs)
|
||||
elif isinstance(inner_fn, classmethod) and isinstance(
|
||||
inner_fn.__func__, types.FunctionType
|
||||
):
|
||||
@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
|
||||
tx, self.objvar.value_type, cls_source
|
||||
)
|
||||
|
||||
return variables.UserFunctionVariable(
|
||||
inner_fn.__func__, source=AttrSource(source, "__func__")
|
||||
).call_function(tx, [cls_variable, *args], kwargs)
|
||||
fn_vt = VariableTracker.build(
|
||||
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
|
||||
)
|
||||
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
|
||||
elif isinstance(inner_fn, types.FunctionType):
|
||||
return variables.UserFunctionVariable(
|
||||
inner_fn, source=source
|
||||
).call_function(tx, [self.objvar] + args, kwargs)
|
||||
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
|
||||
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
|
||||
elif isinstance(inner_fn, types.MethodType):
|
||||
return variables.UserMethodVariable(
|
||||
inner_fn.__func__, self.objvar, source=source
|
||||
@ -574,10 +574,8 @@ class ComptimeVariable(VariableTracker):
|
||||
from ..comptime import comptime
|
||||
|
||||
# To support the comptime.print_graph convenience accessors
|
||||
from .functions import UserFunctionVariable
|
||||
|
||||
return UserFunctionVariable(
|
||||
getattr(comptime, name), source=AttrSource(self.source, name)
|
||||
return VariableTracker.build(
|
||||
tx, getattr(comptime, name), source=AttrSource(self.source, name)
|
||||
)
|
||||
|
||||
def call_function(
|
||||
@ -771,9 +769,8 @@ class AutogradFunctionVariable(VariableTracker):
|
||||
sig = inspect.signature(fn)
|
||||
if len(args) - 1 == len(sig._parameters):
|
||||
args = args[1:] # Don't use context
|
||||
return variables.UserFunctionVariable(fn, source=source).call_function(
|
||||
tx, args, kwargs
|
||||
)
|
||||
fn_vt = VariableTracker.build(tx, fn, source=source)
|
||||
return fn_vt.call_function(tx, args, kwargs)
|
||||
elif isinstance(fn, types.MethodType):
|
||||
return variables.UserMethodVariable(
|
||||
fn.__func__,
|
||||
@ -799,9 +796,8 @@ class AutogradFunctionVariable(VariableTracker):
|
||||
assert isinstance(fn, types.FunctionType)
|
||||
|
||||
fn_source = AttrSource(self.source, "backward")
|
||||
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
|
||||
tx, args, kwargs
|
||||
)
|
||||
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
|
||||
return fn_vt.call_function(tx, args, kwargs)
|
||||
|
||||
def call_function(self, tx: "InstructionTranslator", args, kwargs):
|
||||
return AutogradFunctionVariable(self.fn_cls)
|
||||
@ -1026,10 +1022,12 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
|
||||
assert tx.one_graph or tx.error_on_graph_break, (
|
||||
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
|
||||
)
|
||||
return variables.UserFunctionVariable(
|
||||
fn_vt = VariableTracker.build(
|
||||
tx,
|
||||
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
|
||||
source=self.source,
|
||||
).call_function(
|
||||
)
|
||||
return fn_vt.call_function(
|
||||
tx,
|
||||
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
|
||||
kwargs,
|
||||
|
||||
@@ -1485,7 +1485,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
else:
self.writeline(f"{arg.inner} = {cexpr(arg.inner_expr)};")

def codegen_dynamic_scalar(self, node):
def _codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw")
@@ -1,6 +1,6 @@
import itertools
import logging
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union

import torch
import torch._inductor.config as config
@@ -8,7 +8,6 @@ from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
Buffer,
FixedLayout,
get_free_symbols,
get_symbolic_inputs,
gm_original_output_strides,
@@ -111,12 +110,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
bm_func([*sym_inputs, *args])
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))

# Use appropriate benchmarker based on layout device type
if self.layout.device.type == "cpu":
return benchmarker.benchmark_cpu(lambda: bm_func([*sym_inputs, *args]))
else:
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))

def hash_key(self) -> str:
return "-".join(
@@ -203,152 +197,3 @@ class SubgraphTemplate(KernelTemplate):
description=description,
make_fx_graph=make_fx_graph,
)

def generate_custom_op_choices(
|
||||
self,
|
||||
name: str,
|
||||
decompositions: list[Callable[..., Any]],
|
||||
input_nodes: list[Buffer],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
default_impl: Optional[Callable[..., Any]] = None,
|
||||
) -> list[SubgraphChoiceCaller]:
|
||||
"""
|
||||
Generate multiple SubgraphChoiceCaller instances for custom op autotuning.
|
||||
|
||||
This method extends SubgraphTemplate to support custom op decompositions,
|
||||
allowing multiple implementations to compete in autotuning.
|
||||
|
||||
Args:
|
||||
name: Base name for the choices
|
||||
decompositions: List of decomposition functions to compare
|
||||
input_nodes: Input nodes for the operation
|
||||
kwargs: Additional arguments for decomposition functions
|
||||
default_impl: Default implementation for layout inference
|
||||
|
||||
Returns:
|
||||
List of SubgraphChoiceCaller instances for autotuning
|
||||
"""
|
||||
if not decompositions:
|
||||
return []
|
||||
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# Infer layouts and ensure stride consistency for fair autotuning comparison
|
||||
layouts = [
|
||||
self._infer_custom_op_layout(input_nodes, [decomp], kwargs, default_impl)
|
||||
for decomp in decompositions
|
||||
]
|
||||
|
||||
self._validate_stride_consistency(name, decompositions, layouts)
|
||||
|
||||
# Assert single output layout - assumes custom ops have one output tensor
|
||||
assert len(layouts) > 0, f"No layouts inferred for custom op '{name}'"
|
||||
assert all(
|
||||
layout.device == layouts[0].device
|
||||
and layout.dtype == layouts[0].dtype
|
||||
and layout.size == layouts[0].size
|
||||
for layout in layouts
|
||||
), f"All decompositions for '{name}' must produce equivalent output layouts"
|
||||
|
||||
layout = layouts[0] # All layouts have equivalent stride/shape/dtype now
|
||||
|
||||
choices = []
|
||||
for decomp in decompositions:
|
||||
# Create make_fx_graph function for this decomposition
|
||||
def make_fx_graph(*args: Any, decomp: Callable[..., Any] = decomp) -> Any:
|
||||
import functools
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
# Ensure kwargs is not None for unpacking
|
||||
decomp_kwargs = kwargs if kwargs is not None else {}
|
||||
return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)
|
||||
|
||||
choice = self.generate(
|
||||
name=f"{name}_{decomp.__name__}",
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
make_fx_graph=make_fx_graph,
|
||||
description=f"CustomOp {decomp.__name__}",
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
return choices
|
||||
|
||||
def _validate_stride_consistency(
|
||||
self,
|
||||
op_name: str,
|
||||
decompositions: list[Callable[..., Any]],
|
||||
layouts: list[Layout],
|
||||
) -> None:
|
||||
"""Ensure all decompositions produce compatible strides for fair autotuning."""
|
||||
if not layouts:
|
||||
return
|
||||
|
||||
strides = [layout.stride for layout in layouts]
|
||||
reference = strides[0]
|
||||
for i, stride in enumerate(strides[1:]):
|
||||
if stride != reference:
|
||||
raise AssertionError(
|
||||
f"Stride mismatch in custom op '{op_name}' autotuning: "
|
||||
f"'{decompositions[i].__name__}' produces stride {stride}, "
|
||||
f"but '{decompositions[0].__name__}' produces {reference}. "
|
||||
f"All decompositions must have identical output strides."
|
||||
)
|
||||
|
||||
def _infer_custom_op_layout(
|
||||
self,
|
||||
input_nodes: list[Buffer],
|
||||
decompositions: list[Callable[..., Any]],
|
||||
kwargs: dict[str, Any],
|
||||
default_impl: Optional[Callable[..., Any]] = None,
|
||||
) -> Layout:
|
||||
"""Infer output layout for custom ops using the default implementation when available.
|
||||
|
||||
Note that the Subgraph assumes custom ops return exactly one tensor so far.
|
||||
TODO: Add support for multiple output custom ops.
|
||||
"""
|
||||
import functools
|
||||
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
# Assert kwargs contain only non-tensor arguments for functools.partial
|
||||
for key, value in kwargs.items():
|
||||
assert not isinstance(value, (torch.Tensor, Buffer)), (
|
||||
f"kwargs['{key}'] contains tensor {type(value)}. "
|
||||
f"Tensor arguments should be in input_nodes, not kwargs. "
|
||||
f"Only scalar/non-tensor parameters should be in kwargs."
|
||||
)
|
||||
|
||||
# Use default_impl if available, otherwise use first decomposition
|
||||
impl = default_impl if default_impl is not None else decompositions[0]
|
||||
|
||||
with V.fake_mode:
|
||||
example_inputs = []
|
||||
for inp in input_nodes:
|
||||
raw_shape = inp.get_size()
|
||||
concrete_shape = V.graph.sizevars.size_hints(
|
||||
raw_shape, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
fake_tensor = torch.empty(
|
||||
concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
|
||||
)
|
||||
example_inputs.append(fake_tensor)
|
||||
|
||||
fn = functools.partial(
|
||||
impl, **kwargs
|
||||
) # kwargs must be non-tensor for partial
|
||||
output = fn(*example_inputs)
|
||||
|
||||
# Assert single output
|
||||
assert isinstance(output, torch.Tensor), (
|
||||
f"Expected single tensor output, got {type(output)}. "
|
||||
f"Multi-output custom ops not yet supported in autotuning."
|
||||
)
|
||||
|
||||
return FixedLayout(
|
||||
device=output.device,
|
||||
dtype=output.dtype,
|
||||
size=output.shape,
|
||||
stride=output.stride(),
|
||||
)
|
||||
|
||||
@@ -1224,11 +1224,17 @@ class TritonOverrides(OpOverrides):

@staticmethod
def minimum(a, b):
return f"triton_helpers.minimum({a}, {b})"
if torch.version.hip:
return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
else:
return f"triton_helpers.minimum({a}, {b})"

@staticmethod
def maximum(a, b):
return f"triton_helpers.maximum({a}, {b})"
if torch.version.hip:
return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
else:
return f"triton_helpers.maximum({a}, {b})"

@staticmethod
def where(a, b, c):
@@ -1601,7 +1607,10 @@ class TritonOverrides(OpOverrides):
@staticmethod
@maybe_upcast_float32()
def rsqrt(x):
return f"libdevice.rsqrt({x})"
if torch.version.hip:
return f"tl.rsqrt({x})"
else:
return f"libdevice.rsqrt({x})"

@staticmethod
@maybe_upcast_float32()
@@ -4504,8 +4513,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
loop_end = (
"rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
)
num_stages = ", num_stages = 2" if torch.version.hip else ""
self.body.writeline(
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
)
with self.body.indent(offset=level + 1):
self.iteration_ranges_codegen_header(tree, self.body)
@@ -415,6 +415,19 @@ class CommentLine(WrapperLine):
return converter._generate_comment


@dataclasses.dataclass
class DynamicScalarLine(WrapperLine):
wrapper: PythonWrapperCodegen
node: ir.DynamicScalar

def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._codegen_dynamic_scalar(self.node)

@staticmethod
def codegen_fx(converter: FxConverter) -> FxConversionFunc:
return converter._generate_dynamic_scalar


@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
wrapper: PythonWrapperCodegen
@@ -2060,6 +2073,9 @@ class PythonWrapperCodegen(CodeGen):
self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))

def codegen_dynamic_scalar(self, node):
self.writeline(DynamicScalarLine(self, node))

def _codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
if len(node.keypath) == 0:
self.writeline(f"{node.sym} = {data}.item()")
@ -29,6 +29,7 @@ from torch._library.triton import wrap_triton
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
CallMethodKey,
|
||||
ConvertIntKey,
|
||||
DivideByKey,
|
||||
free_unbacked_symbols,
|
||||
)
|
||||
@ -54,6 +55,7 @@ from .wrapper import (
|
||||
CommBufferFreeLine,
|
||||
CommentLine,
|
||||
ConditionalLine,
|
||||
DynamicScalarLine,
|
||||
EnterDeviceContextManagerLine,
|
||||
EnterSubgraphLine,
|
||||
ExitDeviceContextManagerLine,
|
||||
@ -738,6 +740,39 @@ class FxConverter:
|
||||
assert isinstance(line, CommentLine)
|
||||
# We ignore comments in FX IR.
|
||||
|
||||
def _generate_dynamic_scalar(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, DynamicScalarLine)
|
||||
|
||||
ir_node = line.node
|
||||
(input_ir_node,) = ir_node.inputs
|
||||
assert isinstance(input_ir_node, ir.IRNode)
|
||||
input_fx_node = self._generate_buffer(input_ir_node)
|
||||
keypath = ir_node.keypath
|
||||
graph = self.gm.graph
|
||||
|
||||
def generate_item(x: Optional[torch.fx.Node]) -> torch.fx.Node:
|
||||
assert x is not None
|
||||
return graph.call_function(
|
||||
aten.item.default,
|
||||
args=(x,),
|
||||
)
|
||||
|
||||
if len(keypath) == 0:
|
||||
result_fx_node = generate_item(input_fx_node)
|
||||
elif len(keypath) == 1 and isinstance(keypath[0], ConvertIntKey):
|
||||
where_fx_node = graph.call_function(
|
||||
aten.where.Scalar,
|
||||
args=(input_fx_node, 1, 0),
|
||||
)
|
||||
result_fx_node = generate_item(where_fx_node)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported keypath: {keypath}")
|
||||
|
||||
result_symbol = ir_node.sym
|
||||
result_buffer = SymbolBuffer(result_symbol)
|
||||
self._record_allocation(result_buffer, result_fx_node)
|
||||
self._generate_size_proxy(result_fx_node, result_symbol)
|
||||
|
||||
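# Reference sketch, not part of the diff: the eager-mode equivalents of the two
# keypath cases that _generate_dynamic_scalar lowers to FX call_function nodes.
# Variable names here are illustrative only.
import torch

buf = torch.tensor(3)
u0 = torch.ops.aten.item.default(buf)             # keypath == (): plain .item() -> 3

flag = torch.tensor(True)
as_int = torch.ops.aten.where.Scalar(flag, 1, 0)  # ConvertIntKey: bool -> {0, 1}
u1 = torch.ops.aten.item.default(as_int)          # then .item() -> 1
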
def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
|
||||
assert isinstance(line, EnterDeviceContextManagerLine)
|
||||
# We ignore the device context in FX IR.
|
||||
|
||||
@ -1444,7 +1444,7 @@ class triton:
    # So far we see a fixed 8 spilled registers for kernels using sin/cos.
    # Raise the threshold to 16 to be safe.
    # We should revisit this once we understand more of the source of register spills.
    spill_threshold: int = 16
    spill_threshold: int = 32 if torch.version.hip else 16

    # Generate code containing the newer tl.make_block_ptr() API for loads/store
    use_block_ptr = False

@ -1,397 +0,0 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.subgraph import SubgraphTemplate
|
||||
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
|
||||
from torch._inductor.lowering import lowerings, validate_ir
|
||||
from torch._inductor.select_algorithm import (
|
||||
autotune_select_algorithm,
|
||||
ExternKernelChoice,
|
||||
)
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
|
||||
class CustomOpConfig:
|
||||
"""Config for custom op autotuning - similar to triton.Config.
|
||||
|
||||
Specifies decomposition function with parameter values.
|
||||
Each config creates exactly one variant (no Cartesian product).
|
||||
|
||||
Args:
|
||||
decomposition: Function to autotune
|
||||
**params: Parameters passed to the function
|
||||
|
||||
Examples:
|
||||
CustomOpConfig(attention_impl, head_dim=32, method='chunked')
|
||||
CustomOpConfig(fallback_impl)
|
||||
"""
|
||||
|
||||
def __init__(self, decomposition: Callable[..., Any], **params: Any):
|
||||
if not callable(decomposition):
|
||||
raise TypeError(
|
||||
f"decomposition must be callable, got {type(decomposition)}"
|
||||
)
|
||||
|
||||
self.decomposition = decomposition
|
||||
self.params = params
|
||||
|
||||
# Generate descriptive name
|
||||
if self.params:
|
||||
param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(self.params.items()))
|
||||
self.name = f"{decomposition.__name__}_{param_suffix}"
|
||||
else:
|
||||
self.name = decomposition.__name__
|
||||
|
||||
def create_variant(self) -> Callable[..., Any]:
|
||||
"""Create callable with parameters pre-applied using functools.partial."""
|
||||
if self.params:
|
||||
variant = functools.partial(self.decomposition, **self.params)
|
||||
variant.__name__ = self.name # type: ignore[attr-defined]
|
||||
return variant
|
||||
|
||||
return self.decomposition
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.params:
|
||||
params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
|
||||
return f"CustomOpConfig({self.decomposition.__name__}, {params_str})"
|
||||
return f"CustomOpConfig({self.decomposition.__name__})"
|
||||
|
||||
|
||||
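# Usage sketch (attention_impl is a hypothetical decomposition, not from the diff):
# a config pins parameter values and create_variant() bakes them into a partial.
def attention_impl(q, k, v, *, head_dim=64, method="tiled"):
    ...  # decomposition body elided


cfg = CustomOpConfig(attention_impl, head_dim=32, method="chunked")
variant = cfg.create_variant()
assert variant.__name__ == "attention_impl_head_dim_32_method_chunked"
# variant(q, k, v) now runs attention_impl(q, k, v, head_dim=32, method="chunked")
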
__all__ = [
|
||||
"autotune_custom_op",
|
||||
"register_custom_op_autotuning",
|
||||
"CustomOpConfig",
|
||||
]
|
||||
|
||||
|
||||
def _extract_tensor_inputs(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""Extract tensor inputs from mixed args/kwargs.
|
||||
Separates tensors (for autotuning input_nodes) from non-tensor parameters.
|
||||
Non-tensor kwargs are later functools.partial'd into decomposition functions.
|
||||
|
||||
Args:
|
||||
args: Positional arguments (mix of tensors and scalars)
|
||||
kwargs: Keyword arguments (mix of tensors and scalars)
|
||||
|
||||
Returns:
|
||||
Tuple of (tensor_inputs_list, non_tensor_kwargs)
|
||||
"""
|
||||
tensor_inputs = []
|
||||
non_tensor_kwargs = {}
|
||||
|
||||
# Process args and kwargs: separate tensor inputs and non tensor args
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(arg, (TensorBox, Buffer)):
|
||||
tensor_inputs.append(arg)
|
||||
else:
|
||||
# Add non-tensor positional args to kwargs with generated names
|
||||
non_tensor_kwargs[f"arg_{i}"] = arg
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, (TensorBox, Buffer)):
|
||||
tensor_inputs.append(value)
|
||||
else:
|
||||
non_tensor_kwargs[key] = value
|
||||
|
||||
return tensor_inputs, non_tensor_kwargs
|
||||
|
||||
|
||||
def _create_user_input_gen_fns(
|
||||
inputs: list[Any],
|
||||
arg_names: list[str],
|
||||
user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
|
||||
) -> dict[int, Callable[[Any], torch.Tensor]]:
|
||||
"""Convert user input generators from name-based to index-based format.
|
||||
Inductor autotune's input_gen_fns expects index of arg_names as key.
|
||||
|
||||
Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
|
||||
"""
|
||||
from torch._inductor import config
|
||||
|
||||
name_to_index = {name: i for i, name in enumerate(arg_names)}
|
||||
index_based_fns = {}
|
||||
|
||||
for name, gen_fn in user_input_gen_fns.items():
|
||||
if name in name_to_index:
|
||||
index_based_fns[name_to_index[name]] = gen_fn
|
||||
else:
|
||||
print(f"Warning: Unknown argument name '{name}' in input_gen_fns")
|
||||
|
||||
def create_internal_input_gen_fn(
|
||||
user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
|
||||
) -> Callable[[Any], torch.Tensor]:
|
||||
"""Create internal input generator that converts IR buffer to user's fake tensor."""
|
||||
|
||||
def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
|
||||
raw_shape = ir_buffer.get_size()
|
||||
concrete_shape = V.graph.sizevars.size_hints(
|
||||
raw_shape, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
|
||||
fake_tensor = torch.empty(
|
||||
concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
|
||||
)
|
||||
return user_function(fake_tensor)
|
||||
|
||||
return internal_input_gen_fn
|
||||
|
||||
return {
|
||||
i: create_internal_input_gen_fn(
|
||||
user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
|
||||
)
|
||||
for i, user_gen_fn in index_based_fns.items()
|
||||
if i < len(inputs)
|
||||
}
|
||||
|
||||
|
||||
def _create_fallback_choice(
|
||||
name: str,
|
||||
default_impl: Callable[..., Any],
|
||||
fake_output: torch.Tensor,
|
||||
kwargs: dict[str, Any],
|
||||
) -> ExternKernelChoice:
|
||||
"""Create fallback choice for default implementation."""
|
||||
|
||||
def fallback_wrapper(*args: Any) -> Any:
|
||||
return default_impl(*args, **kwargs)
|
||||
|
||||
return ExternKernelChoice(
|
||||
kernel=fallback_wrapper,
|
||||
name=f"{name}_fallback_default",
|
||||
has_out_variant=False,
|
||||
op_overload=default_impl,
|
||||
use_fallback_kernel=True,
|
||||
)
|
||||
|
||||
|
||||
def _create_parameter_variants(
|
||||
decompositions: list[Callable[..., Any]],
|
||||
tuning_knob: dict[str, list[Any]],
|
||||
) -> list[Any]: # Returns partial objects which are callable
|
||||
"""Create parameter variants for decompositions using tuning knob.
|
||||
|
||||
Args:
|
||||
decompositions: Base implementation functions
|
||||
tuning_knob: Parameter tuning dict with parameter names and value lists
|
||||
|
||||
Returns:
|
||||
List of variant functions with all parameter combinations
|
||||
"""
|
||||
# Validate parameter values
|
||||
for param_name, param_values in tuning_knob.items():
|
||||
if not param_values or not isinstance(param_values, (list, tuple)):
|
||||
raise TypeError(
|
||||
f"Parameter values for '{param_name}' must be a list or tuple, got {type(param_values)}"
|
||||
)
|
||||
|
||||
# Generate all combinations of parameter values using Cartesian product
|
||||
import itertools
|
||||
|
||||
param_names = list(tuning_knob.keys())
|
||||
param_values_lists = list(tuning_knob.values())
|
||||
param_combinations = list(itertools.product(*param_values_lists))
|
||||
|
||||
# Create variants for each decomposition with each parameter combination
|
||||
variants = []
|
||||
for decomp_fn in decompositions:
|
||||
for param_combo in param_combinations:
|
||||
# Create kwargs dict for this combination
|
||||
param_kwargs = dict(zip(param_names, param_combo))
|
||||
|
||||
# Create partial function with all parameters
|
||||
variant = functools.partial(decomp_fn, **param_kwargs)
|
||||
param_suffix = "_".join(
|
||||
f"{name}_{value}" for name, value in param_kwargs.items()
|
||||
)
|
||||
variant.__name__ = f"{decomp_fn.__name__}_{param_suffix}" # type: ignore[attr-defined]
|
||||
variants.append(variant)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
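# Expansion sketch (impl_chunked/impl_tiled and the knob values are hypothetical):
# two base decompositions crossed with a two-value knob yield four named partials.
def impl_chunked(x, *, block_size=32): ...
def impl_tiled(x, *, block_size=32): ...


variants = _create_parameter_variants(
    decompositions=[impl_chunked, impl_tiled],
    tuning_knob={"block_size": [32, 64]},
)
assert [v.__name__ for v in variants] == [
    "impl_chunked_block_size_32",
    "impl_chunked_block_size_64",
    "impl_tiled_block_size_32",
    "impl_tiled_block_size_64",
]
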
def autotune_custom_op(
|
||||
name: str,
|
||||
decompositions: list[Callable[..., Any]],
|
||||
inputs: list[Any],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
default_impl: Optional[Callable[..., Any]] = None,
|
||||
user_input_gen_fns: Optional[
|
||||
dict[str, Callable[[torch.Tensor], torch.Tensor]]
|
||||
] = None,
|
||||
) -> Union[TensorBox, Any]:
|
||||
"""Autotune custom operations by comparing multiple decomposition implementations.
|
||||
|
||||
Currently supports SINGLE OUTPUT custom ops only.
|
||||
TODO: Add support for multiple output custom ops (tuple/list returns).
|
||||
|
||||
This function generates multiple implementation choices for a custom operation and
|
||||
uses Inductor's autotuning system to select the best performing variant at runtime.
|
||||
|
||||
Args:
|
||||
name: Unique identifier for the autotuning operation
|
||||
decompositions: List of alternative implementation functions to benchmark
|
||||
inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
|
||||
kwargs: Non-tensor parameters to pass to decomposition functions
|
||||
default_impl: Original custom op implementation used as fallback
|
||||
user_input_gen_fns: Optional custom input generators for benchmarking.
|
||||
Maps input indices to functions that take fake tensors
|
||||
and return real tensors for performance measurement.
|
||||
|
||||
Returns:
|
||||
IR node representing the optimized operation result
|
||||
|
||||
Raises:
|
||||
TypeError: If decompositions is not a list/tuple
|
||||
RuntimeError: If no inputs or no valid choices generated
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if not isinstance(decompositions, (list, tuple)):
|
||||
raise TypeError(
|
||||
f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")
|
||||
|
||||
template = SubgraphTemplate(name=name)
|
||||
choices = template.generate_custom_op_choices(
|
||||
name=name,
|
||||
decompositions=list(decompositions),
|
||||
input_nodes=list(inputs),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Add default implementation as fallback
|
||||
if default_impl and hasattr(default_impl, "_op"):
|
||||
fallback_name = f"{name}_fallback_default"
|
||||
from torch._inductor.select_algorithm import extern_kernels
|
||||
|
||||
# Skip if extern_kernel already registered to avoid duplicate registration error
|
||||
if not hasattr(extern_kernels, fallback_name):
|
||||
with V.fake_mode:
|
||||
fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
|
||||
fake_output = default_impl(*fake_inputs, **kwargs)
|
||||
|
||||
fallback_choice = _create_fallback_choice(
|
||||
name, default_impl, fake_output, kwargs
|
||||
)
|
||||
fallback_choice.maybe_append_choice(
|
||||
choices=choices,
|
||||
input_nodes=list(inputs),
|
||||
layout=FixedLayout(
|
||||
device=fake_output.device,
|
||||
dtype=fake_output.dtype,
|
||||
size=fake_output.shape,
|
||||
stride=fake_output.stride(),
|
||||
),
|
||||
)
|
||||
|
||||
if not choices:
|
||||
raise RuntimeError(f"No valid choices generated for {name}")
|
||||
|
||||
# Convert user input generation functions to internal format
|
||||
input_gen_fns = {}
|
||||
if user_input_gen_fns:
|
||||
import inspect
|
||||
|
||||
arg_names = (
|
||||
list(inspect.signature(decompositions[0]).parameters.keys())
|
||||
if decompositions
|
||||
else []
|
||||
)
|
||||
input_gen_fns = _create_user_input_gen_fns(
|
||||
inputs, arg_names, user_input_gen_fns
|
||||
)
|
||||
|
||||
return autotune_select_algorithm(
|
||||
name=name,
|
||||
choices=choices,
|
||||
input_nodes=list(inputs),
|
||||
layout=choices[0].layout,
|
||||
input_gen_fns=input_gen_fns,
|
||||
)
|
||||
|
||||
|
||||
def register_custom_op_autotuning(
|
||||
custom_op: torch._ops.OpOverload,
|
||||
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
|
||||
name: Optional[str] = None,
|
||||
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
|
||||
) -> None:
|
||||
"""Register custom op for autotuning with explicit configs.
|
||||
|
||||
Uses config-based API where each config specifies a decomposition function
|
||||
with its parameter values.
|
||||
|
||||
Args:
|
||||
custom_op: Custom operation to register
|
||||
configs: List of CustomOpConfig objects or callable functions
|
||||
name: Operation name (default: "{op_name}_autotuned")
|
||||
input_gen_fns: Custom input generators for benchmarking
|
||||
|
||||
Examples:
|
||||
register_custom_op_autotuning(
|
||||
torch.ops.mylib.attention.default,
|
||||
configs=[
|
||||
CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
|
||||
CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
|
||||
CustomOpConfig(fallback_impl), # No params
|
||||
],
|
||||
input_gen_fns={
|
||||
"query": lambda fake: torch.randn_like(fake, device='cuda'),
|
||||
"key": lambda fake: torch.randn_like(fake, device='cuda'),
|
||||
"value": lambda fake: torch.randn_like(fake, device='cuda'),
|
||||
}
|
||||
)
|
||||
"""
|
||||
if not isinstance(configs, (list, tuple)):
|
||||
raise TypeError(f"configs must be a list or tuple, got {type(configs)}")
|
||||
|
||||
if not configs:
|
||||
raise ValueError("At least one config must be provided")
|
||||
|
||||
# Convert configs to decomposition functions
|
||||
final_decompositions = []
|
||||
for config in configs:
|
||||
if isinstance(config, CustomOpConfig):
|
||||
# CustomOpConfig object
|
||||
final_decompositions.append(config.create_variant())
|
||||
elif callable(config):
|
||||
# Direct callable function
|
||||
final_decompositions.append(config)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Each config must be a CustomOpConfig object or callable function, "
|
||||
f"got {type(config)}"
|
||||
)
|
||||
|
||||
if name is None:
|
||||
name = f"{custom_op._name}_autotuned"
|
||||
|
||||
@functools.wraps(custom_op)
|
||||
def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Inductor lowering function that replaces custom op calls with autotuned versions."""
|
||||
# Extract tensor inputs and non-tensor parameters
|
||||
tensor_inputs, non_tensor_kwargs = _extract_tensor_inputs(args, kwargs)
|
||||
|
||||
result = autotune_custom_op(
|
||||
name=name,
|
||||
decompositions=final_decompositions,
|
||||
inputs=tensor_inputs,
|
||||
kwargs=non_tensor_kwargs,
|
||||
default_impl=custom_op,
|
||||
user_input_gen_fns=input_gen_fns,
|
||||
)
|
||||
|
||||
validate_ir(result)
|
||||
return result
|
||||
|
||||
lowerings[custom_op] = autotuning_lowering
|
||||
@ -860,7 +860,7 @@ class CachingAutotuner(KernelInterface):
        # for some (complicated) custom Triton kernels, a register-spilling
        # config may yield the best latency.
        if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
            "spill_threshold", 16
            "spill_threshold", 32 if torch.version.hip else 16
        ):
            log.debug(
                "Skip config %s because of register spilling: %d",
@ -2393,6 +2393,7 @@ def triton_config_reduction(
|
||||
num_stages=1,
|
||||
num_warps=None,
|
||||
register_intensive=False,
|
||||
waves_per_eu=None,
|
||||
dynamic_scale_rblock=True,
|
||||
reduction_hint=None,
|
||||
) -> Config:
|
||||
@ -2446,13 +2447,19 @@ def triton_config_reduction(
|
||||
cfg = _get_config({"x": x, **rnumels})
|
||||
check_max_block(cfg)
|
||||
check_config(cfg, xnumel=size_hints["x"])
|
||||
return InductorConfig(
|
||||
config = InductorConfig(
|
||||
cfg,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
dynamic_scale_rblock=dynamic_scale_rblock,
|
||||
)
|
||||
|
||||
if torch.version.hip:
|
||||
if waves_per_eu is not None:
|
||||
config.kwargs["waves_per_eu"] = waves_per_eu
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _get_config(numels: dict[str, int]) -> dict[str, int]:
|
||||
"""
|
||||
@ -2463,7 +2470,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
|
||||
|
||||
|
||||
def triton_config_tiled_reduction(
|
||||
size_hints, x, y, r, num_stages=1, register_intensive=False
|
||||
size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
|
||||
):
|
||||
"""
|
||||
Construct a tile reduction triton config with some adjustment
|
||||
@ -2500,7 +2507,11 @@ def triton_config_tiled_reduction(
|
||||
)
|
||||
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
|
||||
check_max_block(cfg)
|
||||
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
||||
config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
||||
if torch.version.hip:
|
||||
if waves_per_eu is not None:
|
||||
config.kwargs["waves_per_eu"] = waves_per_eu
|
||||
return config
|
||||
|
||||
|
||||
def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
|
||||
@ -2748,6 +2759,11 @@ def _reduction_configs(
|
||||
# Convert reductions to 1D, to simplify heuristics.
|
||||
rnumel = get_total_reduction_numel(size_hints)
|
||||
|
||||
# Is max autotune enabled
|
||||
max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
|
||||
"max_autotune_pointwise"
|
||||
)
|
||||
|
||||
register_intensive = False
|
||||
MAX_R0_BLOCK = 2048
|
||||
loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
|
||||
@ -2790,6 +2806,7 @@ def _reduction_configs(
|
||||
num_stages=1,
|
||||
register_intensive=False,
|
||||
dynamic_scale_rblock=True,
|
||||
waves_per_eu=None,
|
||||
):
|
||||
# For 3D case with tiling scores, create an adapted version
|
||||
if "y" in size_hints:
|
||||
@ -2802,6 +2819,7 @@ def _reduction_configs(
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
register_intensive=register_intensive,
|
||||
waves_per_eu=waves_per_eu,
|
||||
)
|
||||
else:
|
||||
# For other cases, use the original function
|
||||
@ -2812,6 +2830,7 @@ def _reduction_configs(
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
register_intensive=register_intensive,
|
||||
waves_per_eu=waves_per_eu,
|
||||
dynamic_scale_rblock=dynamic_scale_rblock,
|
||||
reduction_hint=reduction_hint,
|
||||
)
|
||||
@ -2893,12 +2912,12 @@ def _reduction_configs(
|
||||
)
|
||||
configs.append(c)
|
||||
|
||||
result_configs = []
|
||||
|
||||
# For 3d tiling, default to more autotuning initially
|
||||
if "y" in size_hints:
|
||||
pass
|
||||
elif inductor_meta.get("max_autotune") or inductor_meta.get(
|
||||
"max_autotune_pointwise"
|
||||
):
|
||||
elif max_autotune_enabled:
|
||||
pass # skip all these cases
|
||||
elif reduction_hint == ReductionHint.INNER:
|
||||
return configs + [contiguous_config]
|
||||
@ -2907,7 +2926,10 @@ def _reduction_configs(
|
||||
elif reduction_hint == ReductionHint.OUTER_TINY:
|
||||
return configs + [tiny_config]
|
||||
|
||||
return configs + [
|
||||
# We continue here under the following conditions:
|
||||
# - max_autotune_enabled is True
|
||||
# - max_autotune_enabled is False and reduction_hint is NOT one of the above cases
|
||||
result_configs = configs + [
|
||||
contiguous_config,
|
||||
outer_config,
|
||||
tiny_config,
|
||||
@ -2919,6 +2941,16 @@ def _reduction_configs(
|
||||
make_config(64, 4, num_warps=8),
|
||||
]
|
||||
|
||||
if torch.version.hip:
|
||||
result_configs.extend(
|
||||
[
|
||||
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
|
||||
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
|
||||
]
|
||||
)
|
||||
|
||||
return result_configs
|
||||
|
||||
|
||||
def match_target_block_product(
|
||||
size_hints, tiling_scores, target_block_product, min_block_size=1
|
||||
@ -2975,6 +3007,7 @@ def adapt_config_for_tiling(
|
||||
num_stages=1,
|
||||
register_intensive=False,
|
||||
persistent_reduction=False,
|
||||
waves_per_eu=None,
|
||||
) -> Config:
|
||||
"""
|
||||
Create an adapted configuration based on tiling scores,
|
||||
@ -2993,6 +3026,7 @@ def adapt_config_for_tiling(
|
||||
block_sizes["r0_"],
|
||||
num_stages=num_stages,
|
||||
register_intensive=register_intensive,
|
||||
waves_per_eu=waves_per_eu,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -2919,6 +2919,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
)
|
||||
|
||||
timings = do_autotuning(choices, precompile_fn)
|
||||
|
||||
# if timings is empty, we really have no choice but to return a semi-random
|
||||
# choice. returning the first `ExternKernelCaller` is probably the safest bet
|
||||
# in this case, since it will generally be the ATen kernel. if there are no
|
||||
@ -3523,7 +3524,6 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes])
|
||||
if config.autotune_num_choices_displayed == 0:
|
||||
return
|
||||
|
||||
# when autotune_num_choices_displayed is None, [:None] means all
|
||||
n = config.autotune_num_choices_displayed
|
||||
top_k = sorted(timings, key=timings.__getitem__)[:n]
|
||||
|
||||
@ -264,7 +264,8 @@ class OutputComparisonLogger(OutputLogger):
|
||||
# fmt: on
|
||||
if not self.enabled:
|
||||
return x
|
||||
assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported"
|
||||
if not isinstance(x, torch.Tensor):
|
||||
raise AssertionError("non-tensor inputs not yet supported")
|
||||
if self.save_activations:
|
||||
# save the activation, for debugging
|
||||
self.stats.append(x.detach())
|
||||
@ -595,9 +596,8 @@ def _extract_logger_info_one_model(
|
||||
key = mod.ref_name
|
||||
if key not in results:
|
||||
results[key] = {}
|
||||
assert mod.model_name not in results[key], (
|
||||
f"{mod.model_name} is already present in results"
|
||||
)
|
||||
if mod.model_name in results[key]:
|
||||
raise AssertionError(f"{mod.model_name} is already present in results")
|
||||
if mod.results_type not in results[key]:
|
||||
results[key][mod.results_type] = {}
|
||||
if mod.model_name not in results[key][mod.results_type]:
|
||||
@ -809,12 +809,10 @@ def extend_logger_results_with_comparison(
|
||||
"""
|
||||
for results_type_to_results in results.values():
|
||||
for model_name_to_results in results_type_to_results.values():
|
||||
assert model_name_1 in model_name_to_results, (
|
||||
f"{model_name_1} not found in results"
|
||||
)
|
||||
assert model_name_2 in model_name_to_results, (
|
||||
f"{model_name_2} not found in results"
|
||||
)
|
||||
if model_name_1 not in model_name_to_results:
|
||||
raise AssertionError(f"{model_name_1} not found in results")
|
||||
if model_name_2 not in model_name_to_results:
|
||||
raise AssertionError(f"{model_name_2} not found in results")
|
||||
|
||||
results_1 = model_name_to_results[model_name_1]
|
||||
results_2 = model_name_to_results[model_name_2]
|
||||
@ -832,7 +830,8 @@ def extend_logger_results_with_comparison(
|
||||
):
|
||||
result_1 = cur_result_1
|
||||
break
|
||||
assert result_1 is not None
|
||||
if result_1 is None:
|
||||
raise AssertionError("Expected result_1 to be non-None")
|
||||
|
||||
values_1 = result_1["values"]
|
||||
values_2 = result_2["values"]
|
||||
|
||||
@ -150,7 +150,8 @@ class _NSGraphMatchableSubgraphsIterator:
|
||||
if node.op == "call_function":
|
||||
return node.target not in self.non_matchable_functions
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
if not isinstance(node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node.target)}")
|
||||
target_mod = getattr_from_fqn(self.gm, node.target)
|
||||
return not any(
|
||||
isinstance(target_mod, t) # type: ignore[arg-type]
|
||||
@ -228,16 +229,19 @@ def _get_subgraph_relationship_type(
|
||||
else:
|
||||
return SubgraphTypeRelationship.NOT_RELATED
|
||||
elif node_a.op == "call_module":
|
||||
assert (
|
||||
subgraph_a.base_op_node == subgraph_a.start_node
|
||||
and subgraph_b.base_op_node == subgraph_b.start_node
|
||||
), (
|
||||
"Matching call_module patterns where base_op_node != start_node is not supported yet"
|
||||
)
|
||||
if (
|
||||
subgraph_a.base_op_node != subgraph_a.start_node
|
||||
or subgraph_b.base_op_node != subgraph_b.start_node
|
||||
):
|
||||
raise AssertionError(
|
||||
"Matching call_module patterns where base_op_node != start_node is not supported yet"
|
||||
)
|
||||
# for call_module, we need to look up the modules to do the type check
|
||||
assert isinstance(node_a.target, str)
|
||||
if not isinstance(node_a.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node_a.target)}")
|
||||
mod_a = getattr_from_fqn(gm_a, node_a.target)
|
||||
assert isinstance(node_b.target, str)
|
||||
if not isinstance(node_b.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node_b.target)}")
|
||||
mod_b = getattr_from_fqn(gm_b, node_b.target)
|
||||
|
||||
key = (type(mod_a), type(mod_b))
|
||||
@ -312,7 +316,8 @@ def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetT
|
||||
if node.op in ("call_function", "call_method"):
|
||||
return node.target
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
if not isinstance(node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node.target)}")
|
||||
mod = getattr_from_fqn(gm, node.target)
|
||||
return type(mod)
|
||||
return None
|
||||
@ -452,9 +457,10 @@ of subgraphs, and each pair of subgraphs is related to each other."""
|
||||
key_name_b = _get_name_for_subgraph(
|
||||
cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
|
||||
)
|
||||
assert key_name_a == key_name_b, (
|
||||
f"Subgraph names {key_name_a} and {key_name_b} do not match"
|
||||
)
|
||||
if key_name_a != key_name_b:
|
||||
raise AssertionError(
|
||||
f"Subgraph names {key_name_a} and {key_name_b} do not match"
|
||||
)
|
||||
results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
|
||||
continue
|
||||
elif cur_subgraph_a is None and cur_subgraph_b is None:
|
||||
|
||||
@ -32,7 +32,8 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
|
||||
# an observer, get the fqn of the node being observed.
|
||||
node_to_use_for_fqn = node
|
||||
if node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
if not isinstance(node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node.target)}")
|
||||
module = getattr_from_fqn(gm, node.target)
|
||||
if _is_activation_post_process(module):
|
||||
node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
|
||||
@ -348,7 +349,8 @@ def _insert_dtype_cast_after_node(
|
||||
new_dtype_cast_name,
|
||||
)
|
||||
else:
|
||||
assert dtype_cast_mod_cls
|
||||
if not dtype_cast_mod_cls:
|
||||
raise AssertionError("Expected dtype_cast_mod_cls to be not None")
|
||||
dtype_cast_mod = dtype_cast_mod_cls()
|
||||
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
|
||||
return graph_c.create_node(
|
||||
@ -373,7 +375,8 @@ def _insert_dtype_cast_after_node(
|
||||
)
|
||||
results.append(new_dtype_cast_node)
|
||||
else:
|
||||
assert dtype_cast_mod_cls
|
||||
if not dtype_cast_mod_cls:
|
||||
raise AssertionError("Expected dtype_cast_mod_cls to be not None")
|
||||
dtype_cast_mod = dtype_cast_mod_cls()
|
||||
setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
|
||||
new_dtype_cast_node = graph_c.create_node(
|
||||
@ -412,10 +415,8 @@ def _copy_node_from_a_to_c(
|
||||
)
|
||||
return node_a_copy
|
||||
elif node_a.op == "call_method":
|
||||
assert node_a.target in (
|
||||
"dequantize",
|
||||
"to",
|
||||
), f"target {node_a.target} is not implemented"
|
||||
if node_a.target not in ("dequantize", "to"):
|
||||
raise AssertionError(f"target {node_a.target} is not implemented")
|
||||
if node_a.target == "dequantize":
|
||||
arg_copy = _copy_node_from_a_to_c(
|
||||
get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
|
||||
@ -535,7 +536,8 @@ def _insert_copy_of_subgraph_a_after_input_node_c(
|
||||
"""
|
||||
TODO(before land): real docblock
|
||||
"""
|
||||
assert isinstance(input_node_c, (Node, list))
|
||||
if not isinstance(input_node_c, (Node, list)):
|
||||
raise AssertionError(f"Expected Node or list, got {type(input_node_c)}")
|
||||
|
||||
# create a sequential list of the subgraphs' nodes from start to end,
|
||||
# because we need to add the nodes to graph C in non-reverse order
|
||||
@ -621,7 +623,8 @@ def _insert_copy_of_node_a_after_input_node_c(
|
||||
if isinstance(input_node_c, Node):
|
||||
graph_c = input_node_c.graph
|
||||
else:
|
||||
assert isinstance(input_node_c, list)
|
||||
if not isinstance(input_node_c, list):
|
||||
raise AssertionError(f"Expected list, got {type(input_node_c)}")
|
||||
graph_c = input_node_c[0].graph
|
||||
|
||||
norm_args_kwargs = node_a.normalized_arguments(
|
||||
@ -645,9 +648,10 @@ def _insert_copy_of_node_a_after_input_node_c(
|
||||
return arg
|
||||
elif isinstance(kwarg_val, (list, tuple)):
|
||||
for el in kwarg_val:
|
||||
assert not isinstance(el, Node), (
|
||||
"handling of Node inside list is not implemented"
|
||||
)
|
||||
if isinstance(el, Node):
|
||||
raise AssertionError(
|
||||
"handling of Node inside list is not implemented"
|
||||
)
|
||||
return arg
|
||||
else:
|
||||
raise AssertionError(
|
||||
@ -684,7 +688,8 @@ def _insert_copy_of_node_a_after_input_node_c(
|
||||
# if target is a module, we point to the module from gm_b
|
||||
new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
|
||||
# fetch the corresponding module from gm_a
|
||||
assert isinstance(node_a.target, str)
|
||||
if not isinstance(node_a.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node_a.target)}")
|
||||
mod_a = getattr_from_fqn(gm_a, node_a.target)
|
||||
setattr(gm_b, new_mod_copy_name, mod_a)
|
||||
node_a_shadows_c = graph_c.create_node(
|
||||
@ -696,7 +701,8 @@ def _insert_copy_of_node_a_after_input_node_c(
|
||||
)
|
||||
return node_a_shadows_c
|
||||
else:
|
||||
assert node_a.op in ("call_function", "call_method")
|
||||
if node_a.op not in ("call_function", "call_method"):
|
||||
raise AssertionError(f"Unexpected op: {node_a.op}")
|
||||
node_a_shadows_c = graph_c.create_node(
|
||||
node_a.op,
|
||||
node_a.target,
|
||||
@ -791,7 +797,8 @@ def create_a_shadows_b(
|
||||
ref_node_type_b,
|
||||
) = start_node_b_to_matched_subgraph_a_and_name[node_b]
|
||||
else:
|
||||
assert node_b_is_end_node
|
||||
if not node_b_is_end_node:
|
||||
raise AssertionError("Expected node_b_is_end_node to be True")
|
||||
(
|
||||
subgraph_a,
|
||||
ref_name,
|
||||
@ -1001,7 +1008,10 @@ def create_a_shadows_b(
|
||||
)
|
||||
input_logger: Union[Node, list[Node]] = dtype_cast_node
|
||||
else:
|
||||
assert isinstance(dtype_cast_node, list)
|
||||
if not isinstance(dtype_cast_node, list):
|
||||
raise AssertionError(
|
||||
f"Expected list, got {type(dtype_cast_node)}"
|
||||
)
|
||||
new_loggers = []
|
||||
for dtype_cast_idx, dtype_cast_node_inner in enumerate(
|
||||
dtype_cast_node
|
||||
@ -1083,7 +1093,10 @@ def create_a_shadows_b(
|
||||
input_logger_mod.ref_node_name = cur_node.name
|
||||
else:
|
||||
# pyrefly: ignore # unbound-name
|
||||
assert isinstance(input_logger, list)
|
||||
if not isinstance(input_logger, list):
|
||||
raise AssertionError(
|
||||
f"Expected list, got {type(input_logger)}"
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
for input_logger_inner in input_logger:
|
||||
input_logger_mod = getattr(gm_b, input_logger_inner.name)
|
||||
|
||||
@ -144,9 +144,11 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
|
||||
seen_nodes.add(node_or_tuple)
|
||||
|
||||
else:
|
||||
assert isinstance(node_or_tuple, tuple)
|
||||
if not isinstance(node_or_tuple, tuple):
|
||||
raise AssertionError(f"Expected tuple, got {type(node_or_tuple)}")
|
||||
for node in node_or_tuple:
|
||||
assert isinstance(node, Node)
|
||||
if not isinstance(node, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(node)}")
|
||||
if node in seen_nodes:
|
||||
was_seen = True
|
||||
seen_nodes.add(node)
|
||||
@ -160,7 +162,10 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
|
||||
if len(cur_match[1]) == 1:
|
||||
list_of_nodes = cur_match[1]
|
||||
else:
|
||||
assert len(cur_match[1]) == 2
|
||||
if len(cur_match[1]) != 2:
|
||||
raise ValueError(
|
||||
f"Expected cur_match[1] to have length 2, got {len(cur_match[1])}"
|
||||
)
|
||||
# either (a, b), or ((a, b), c) or (c, (a, b))
|
||||
# cannot make any assumptions on order, not clear what the
|
||||
# _find_matches function is doing to populate this
|
||||
@ -181,13 +186,12 @@ def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Nod
|
||||
last_node = n
|
||||
else:
|
||||
mid_node = n
|
||||
assert (
|
||||
first_node is not None
|
||||
and mid_node is not None
|
||||
and last_node is not None
|
||||
)
|
||||
assert mid_node.args[0] is first_node
|
||||
assert last_node.args[0] is mid_node
|
||||
if first_node is None or mid_node is None or last_node is None:
|
||||
raise AssertionError("Expected all nodes to be non-None")
|
||||
if mid_node.args[0] is not first_node:
|
||||
raise AssertionError("Expected mid_node.args[0] to be first_node")
|
||||
if last_node.args[0] is not mid_node:
|
||||
raise AssertionError("Expected last_node.args[0] to be mid_node")
|
||||
return [last_node, mid_node, first_node]
|
||||
|
||||
if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
|
||||
@ -377,7 +381,10 @@ def create_submodule_from_subgraph(
|
||||
# the current implementation is simplistic and cannot handle
|
||||
# ops with two or more arguments which need to be passed from
|
||||
# the previous op, so we assert them out
|
||||
assert cur_node_orig.target not in BINARY_FUNCTIONS
|
||||
if cur_node_orig.target in BINARY_FUNCTIONS:
|
||||
raise AssertionError(
|
||||
f"Unexpected binary function target: {cur_node_orig.target}"
|
||||
)
|
||||
|
||||
# at this point in the code, cur_node_copy is pointing to the copy
|
||||
# of the previous node
|
||||
@ -435,9 +442,10 @@ def create_submodule_from_subgraph(
|
||||
break
|
||||
|
||||
# go to next node
|
||||
assert len(cur_node_orig.users.keys()) == 1, (
|
||||
f"{cur_node_orig} has more than 1 users, not supported yet"
|
||||
)
|
||||
if len(cur_node_orig.users.keys()) != 1:
|
||||
raise AssertionError(
|
||||
f"{cur_node_orig} has more than 1 users, not supported yet"
|
||||
)
|
||||
cur_node_orig = next(iter(cur_node_orig.users.keys()))
|
||||
cur_iteration += 1
|
||||
if cur_iteration > iteration_limit:
|
||||
@ -494,7 +502,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
|
||||
)
|
||||
|
||||
attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
|
||||
assert not hasattr(mt, attr_name)
|
||||
if hasattr(mt, attr_name):
|
||||
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
|
||||
setattr(mt, attr_name, logger_mod_orig)
|
||||
with mt.graph.inserting_after(last_node):
|
||||
new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
|
||||
@ -537,9 +546,10 @@ def create_one_transformed_and_logged_copy_of_subgraph(
|
||||
"prepare_custom_config",
|
||||
"qconfig_mapping",
|
||||
]:
|
||||
assert kwarg_name not in custom_prepare_kwargs, (
|
||||
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
|
||||
)
|
||||
if kwarg_name in custom_prepare_kwargs:
|
||||
raise AssertionError(
|
||||
f"cannot specify {kwarg_name} in custom_prepare_kwargs"
|
||||
)
|
||||
prepare_kwargs: dict[str, Any] = {
|
||||
"example_inputs": example_inputs,
|
||||
"qconfig_mapping": qconfig_mapping,
|
||||
@ -551,7 +561,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
|
||||
|
||||
# attach the wrapper to the model
|
||||
attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
|
||||
assert not hasattr(mt, attr_name)
|
||||
if hasattr(mt, attr_name):
|
||||
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
|
||||
setattr(mt, attr_name, orig_mod_copy_wrapped)
|
||||
|
||||
# add a call to the wrapper module from the parent graph
|
||||
@ -600,7 +611,8 @@ def create_one_transformed_and_logged_copy_of_subgraph(
|
||||
)
|
||||
|
||||
attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
|
||||
assert not hasattr(mt, attr_name)
|
||||
if hasattr(mt, attr_name):
|
||||
raise AssertionError(f"Unexpected attribute '{attr_name}' found in {mt}")
|
||||
setattr(mt, attr_name, logger_mod_orig)
|
||||
with mt.graph.inserting_after(new_node):
|
||||
logger = mt.graph.call_module(
|
||||
@ -824,7 +836,8 @@ def create_add_loggers_graph(
|
||||
):
|
||||
new_shadow_mod = maybe_shadow_mod
|
||||
break
|
||||
assert new_shadow_mod is not None
|
||||
if new_shadow_mod is None:
|
||||
raise AssertionError("Expected new_shadow_mod to be non-None")
|
||||
orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
|
||||
orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod
|
||||
|
||||
@ -850,7 +863,10 @@ def create_add_loggers_graph(
|
||||
fqn,
|
||||
)
|
||||
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
|
||||
assert not hasattr(model, attr_name)
|
||||
if hasattr(model, attr_name):
|
||||
raise AssertionError(
|
||||
f"Unexpected attribute '{attr_name}' found in {model}"
|
||||
)
|
||||
setattr(model, attr_name, logger_mod_orig)
|
||||
insertion_point = last_node
|
||||
with model.graph.inserting_after(insertion_point):
|
||||
@ -887,9 +903,15 @@ def create_add_loggers_graph(
|
||||
# since now only linear subgraphs are supported, all nodes
|
||||
# except the last one must have only one user
|
||||
if cur_node_orig != last_node:
|
||||
assert len(cur_node_orig.users.keys()) == 1
|
||||
if len(cur_node_orig.users.keys()) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected exactly 1 user, but got {len(cur_node_orig.users)}"
|
||||
)
|
||||
cur_node_orig = next(iter(cur_node_orig.users.keys()))
|
||||
assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
|
||||
if cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX):
|
||||
raise AssertionError(
|
||||
"cur_node_orig should not start with SHADOW_NODE_NAME_PREFIX"
|
||||
)
|
||||
insertion_point = cur_node_copy
|
||||
|
||||
# add a comparison logger after last_node's copy
|
||||
@ -905,7 +927,10 @@ def create_add_loggers_graph(
|
||||
fqn,
|
||||
)
|
||||
attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
|
||||
assert not hasattr(model, attr_name)
|
||||
if hasattr(model, attr_name):
|
||||
raise AssertionError(
|
||||
f"Unexpected attribute '{attr_name}' found in {model}"
|
||||
)
|
||||
setattr(model, attr_name, logger_mod_orig)
|
||||
with model.graph.inserting_after(insertion_point):
|
||||
logger = model.graph.call_module(
|
||||
@ -979,7 +1004,8 @@ def create_add_loggers_graph(
|
||||
return prev_shadow_output
|
||||
|
||||
cur_shadow_input = orig_first_node_to_shadow_in_node[first_node]
|
||||
assert cur_shadow_input is not None
|
||||
if cur_shadow_input is None:
|
||||
raise AssertionError("Expected cur_shadow_input to be non-None")
|
||||
cur_shadow_input.args = tree_map(
|
||||
maybe_remap_node_to_shadow, cur_shadow_input.args
|
||||
)
|
||||
@ -1019,7 +1045,8 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
|
||||
# we have `w2_0`, and are navigating this subgraph
|
||||
# to get `_input_scale_1` and `_input_zero_point_1`
|
||||
|
||||
assert len(shadow_n.users) == 1
|
||||
if len(shadow_n.users) != 1:
|
||||
raise AssertionError(f"Expected exactly 1 user, got {len(shadow_n.users)}")
|
||||
quant_node = next(iter(shadow_n.users.keys()))
|
||||
new_args: Any = None
|
||||
if quant_node.target == torch.quantize_per_channel:
|
||||
@ -1028,7 +1055,10 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
|
||||
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
|
||||
new_args = (scale_val, zp_val, axis, dtype)
|
||||
else:
|
||||
assert quant_node.target == torch.quantize_per_tensor
|
||||
if quant_node.target != torch.quantize_per_tensor:
|
||||
raise AssertionError(
|
||||
f"Expected torch.quantize_per_tensor, but got {quant_node.target}"
|
||||
)
|
||||
_weight, scale_node, zp_node, dtype = quant_node.args
|
||||
scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
|
||||
zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
|
||||
|
||||
@ -167,7 +167,8 @@ def end_node_matches_reversed_fusion(
|
||||
elif cur_node.op == "call_module":
|
||||
fusion_el_is_mod = isinstance(cur_fusion_el, type)
|
||||
if fusion_el_is_mod:
|
||||
assert isinstance(cur_node.target, str)
|
||||
if not isinstance(cur_node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(cur_node.target)}")
|
||||
target_mod = getattr_from_fqn(gm, cur_node.target)
|
||||
if not isinstance(cur_fusion_el, type):
|
||||
return False
|
||||
@ -190,7 +191,10 @@ def end_node_matches_reversed_fusion(
|
||||
if cur_node.target != cur_fusion_el:
|
||||
return False
|
||||
else:
|
||||
assert isinstance(cur_fusion_el, tuple)
|
||||
if not isinstance(cur_fusion_el, tuple):
|
||||
raise AssertionError(
|
||||
f"Expected tuple, got {type(cur_fusion_el)}"
|
||||
)
|
||||
if cur_node.target != cur_fusion_el[0]:
|
||||
return False
|
||||
elif len(cur_node.args) < 2:
|
||||
|
||||
@ -61,7 +61,8 @@ def get_node_first_input_and_output_type(
|
||||
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
|
||||
elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
|
||||
first_arg = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(first_arg, Node)
|
||||
if not isinstance(first_arg, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(first_arg)}")
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
@ -73,8 +74,11 @@ def get_node_first_input_and_output_type(
|
||||
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
|
||||
|
||||
elif node.op == "call_module":
|
||||
assert node.op == "call_module"
|
||||
assert isinstance(node.target, str)
|
||||
if node.op != "call_module":
|
||||
raise AssertionError(f"Expected call_module, got '{node.op}'")
|
||||
if not isinstance(node.target, str):
|
||||
raise AssertionError(f"Expected str, but got {type(node.target)}")
|
||||
|
||||
mod = getattr_from_fqn(gm, node.target)
|
||||
is_known_fp32_or_int8_input_module = any(
|
||||
isinstance(mod, target_type) # type: ignore[arg-type]
|
||||
@ -87,7 +91,8 @@ def get_node_first_input_and_output_type(
|
||||
# A logger or observer's input and output type is the output
|
||||
# type of the preceding node.
|
||||
first_arg = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(first_arg, Node)
|
||||
if not isinstance(first_arg, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(first_arg)}")
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
@ -116,7 +121,8 @@ def get_node_first_input_and_output_type(
|
||||
# So, we look up the output type of the previous node and return that
|
||||
# as the input type of this node instance.
|
||||
prev_node = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(prev_node, Node)
|
||||
if not isinstance(prev_node, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(prev_node)}")
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
@ -131,7 +137,8 @@ def get_node_first_input_and_output_type(
|
||||
# as the input type of this node instance. We also look up the target
|
||||
# of to and return the correct output type.
|
||||
prev_node = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(prev_node, Node)
|
||||
if not isinstance(prev_node, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(prev_node)}")
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
@ -140,15 +147,17 @@ def get_node_first_input_and_output_type(
|
||||
)
|
||||
|
||||
cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
|
||||
assert cur_node_dtype_target is torch.float16, (
|
||||
f"{cur_node_dtype_target} handling needs to be added"
|
||||
)
|
||||
if cur_node_dtype_target is not torch.float16:
|
||||
raise AssertionError(
|
||||
f"{cur_node_dtype_target} handling needs to be added"
|
||||
)
|
||||
|
||||
return (prev_node_output_type, NodeInputOrOutputType.FP16)
|
||||
|
||||
elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
|
||||
first_arg = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(first_arg, Node)
|
||||
if not isinstance(first_arg, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(first_arg)}")
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
@ -181,8 +190,14 @@ def get_node_input_qparams(
|
||||
def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
|
||||
scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
|
||||
zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
|
||||
assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
|
||||
assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
|
||||
if not isinstance(scale_node, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(scale_node)}")
|
||||
if not isinstance(scale_node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(scale_node.target)}")
|
||||
if not isinstance(zp_node, Node):
|
||||
raise AssertionError(f"Expected Node, got {type(zp_node)}")
|
||||
if not isinstance(zp_node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(zp_node.target)}")
|
||||
scale_obj = getattr_from_fqn(gm, scale_node.target)
|
||||
zp_obj = getattr_from_fqn(gm, zp_node.target)
|
||||
return (scale_obj, zp_obj)
|
||||
@ -200,7 +215,8 @@ def get_node_input_qparams(
|
||||
|
||||
elif prev_node.op == "call_module":
|
||||
# get type of the module
|
||||
assert isinstance(prev_node.target, str)
|
||||
if not isinstance(prev_node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(prev_node.target)}")
|
||||
module_obj = getattr_from_fqn(gm, prev_node.target)
|
||||
if isinstance(
|
||||
module_obj,
|
||||
@ -259,15 +275,24 @@ def return_first_non_observer_node(
|
||||
if node.op == "call_module":
|
||||
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
|
||||
if _is_activation_post_process(node_obj):
|
||||
assert len(node.args) == 1
|
||||
assert isinstance(node.args[0], Node)
|
||||
if len(node.args) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected node.args to have length 1, got {len(node.args)}"
|
||||
)
|
||||
if not isinstance(node.args[0], Node):
|
||||
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
|
||||
node = node.args[0]
|
||||
# code duplication intended, not worth refactoring
|
||||
assert isinstance(node.target, str)
|
||||
if not isinstance(node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node.target)}")
|
||||
node_obj = getattr_from_fqn(gm, node.target)
|
||||
if _is_activation_post_process(node_obj):
|
||||
assert len(node.args) == 1
|
||||
assert isinstance(node.args[0], Node)
|
||||
if len(node.args) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected node.args to have length 1, got {len(node.args)}"
|
||||
)
|
||||
if not isinstance(node.args[0], Node):
|
||||
raise AssertionError(f"Expected Node, got {type(node.args[0])}")
|
||||
node = node.args[0]
|
||||
return node
|
||||
|
||||
@ -331,7 +356,8 @@ def get_target_type_str(node: Node, gm: GraphModule) -> str:
|
||||
if node.op in ("call_function", "call_method"):
|
||||
target_type = torch.typename(node.target)
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
if not isinstance(node.target, str):
|
||||
raise AssertionError(f"Expected str, got {type(node.target)}")
|
||||
target_mod = getattr_from_fqn(gm, node.target)
|
||||
target_type = torch.typename(target_mod)
|
||||
return target_type
|
||||
@ -365,7 +391,8 @@ def rekey_logger_info_on_node_name_of_model(
|
||||
for model_name_to_results in result_type_to_results.values():
|
||||
for cur_model_name, list_of_results in model_name_to_results.items():
|
||||
if cur_model_name == model_name:
|
||||
assert len(list_of_results)
|
||||
if len(list_of_results) == 0:
|
||||
raise AssertionError("Expected list_of_results to be non-empty")
|
||||
new_layer_name = list_of_results[0]["ref_node_name"]
|
||||
else:
|
||||
continue
|
||||
@ -519,14 +546,20 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
|
||||
)
|
||||
if norm_args_and_kwargs is not None:
|
||||
norm_args, norm_kwargs = norm_args_and_kwargs
|
||||
assert len(norm_args) + len(norm_kwargs) > idx
|
||||
if len(norm_args) + len(norm_kwargs) <= idx:
|
||||
raise AssertionError(
|
||||
f"Index {idx} out of range: total = {len(norm_args) + len(norm_kwargs)}"
|
||||
)
|
||||
if idx < len(norm_args):
|
||||
return norm_args[idx]
|
||||
else:
|
||||
# note: in Python 3.7+ dicts are ordered
|
||||
return list(norm_kwargs.values())[idx]
|
||||
else:
|
||||
assert len(node.args) + len(node.kwargs) > idx
|
||||
if len(node.args) + len(node.kwargs) <= idx:
|
||||
raise AssertionError(
|
||||
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
|
||||
)
|
||||
if idx < len(node.args):
|
||||
return node.args[idx] # type: ignore[return-value]
|
||||
else:
|
||||
@ -536,7 +569,10 @@ def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
|
||||
# this RuntimeError happens when node argument normalization
|
||||
# requires typehints to proceed, such as for torch.add where
|
||||
# either the first, second or both arguments could be tensors
|
||||
assert len(node.args) + len(node.kwargs) > idx
|
||||
if len(node.args) + len(node.kwargs) <= idx:
|
||||
raise AssertionError(
|
||||
f"Index {idx} out of range: total = {len(node.args) + len(node.kwargs)}"
|
||||
) from None
|
||||
if idx < len(node.args):
|
||||
return node.args[idx] # type: ignore[return-value]
|
||||
else:
|
||||
|
||||
@ -77,7 +77,8 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
|
||||
res.append(param_value)
|
||||
return res
|
||||
else:
|
||||
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
|
        if not isinstance(mod, nnqd.LSTM):
            raise AssertionError(f"type {type(mod)} not handled yet")
        res = []
        for weight_value in mod._all_weight_values:
            res.append(
@ -92,10 +93,13 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    weight_arg_node = node.args[1]
    assert isinstance(weight_arg_node, Node)
    if not isinstance(weight_arg_node, Node):
        raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
    weight_node = return_first_non_observer_node(weight_arg_node, gm)
    assert isinstance(weight_node, Node)
    assert weight_node.op == "get_attr"
    if not isinstance(weight_node, Node):
        raise AssertionError(f"Expected Node, got {type(weight_node)}")
    if weight_node.op != "get_attr":
        raise AssertionError(f"Expected get_attr, got {weight_node.op}")
    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
    return weight.detach()

@ -103,8 +107,10 @@ def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # qconv state is arg 1
    qconv_state_node = node.args[1]
    assert isinstance(qconv_state_node, Node)
    assert qconv_state_node.op == "get_attr"
    if not isinstance(qconv_state_node, Node):
        raise AssertionError(f"Expected Node, got {type(qconv_state_node)}")
    if qconv_state_node.op != "get_attr":
        raise AssertionError(f"Expected get_attr, got {qconv_state_node.op}")
    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
    return qconv_state_obj.weight()

@ -115,34 +121,44 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # weight -> obs -> linear
    # weight -> to(torch.float16) -> dequantize -> linear
    linear_second_arg = node.args[1]
    assert isinstance(linear_second_arg, Node)
    if not isinstance(linear_second_arg, Node):
        raise AssertionError(f"Expected Node, got {type(linear_second_arg)}")

    if linear_second_arg.op == "call_module":
        # weight -> obs -> linear
        weight_arg_node = node.args[1]
        assert isinstance(weight_arg_node, Node)
        if not isinstance(weight_arg_node, Node):
            raise AssertionError(f"Expected Node, got {type(weight_arg_node)}")
        weight_node = weight_arg_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        if not isinstance(weight_node, Node):
            raise AssertionError(f"Expected Node, got {type(weight_node)}")
        if weight_node.op != "get_attr":
            raise AssertionError(f"Expected get_attr, got {weight_node.op}")
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        return weight.detach()
    elif linear_second_arg.op == "call_method":
        # weight -> to(torch.float16) -> dequantize -> linear
        assert linear_second_arg.op == "call_method"
        if linear_second_arg.op != "call_method":
            raise AssertionError(f"Expected call_method, got {linear_second_arg.op}")
        dequant_node = node.args[1]
        assert isinstance(dequant_node, Node)
        if not isinstance(dequant_node, Node):
            raise AssertionError(f"Expected Node, got {type(dequant_node)}")
        to_fp16_node = dequant_node.args[0]
        assert isinstance(to_fp16_node, Node)
        if not isinstance(to_fp16_node, Node):
            raise AssertionError(f"Expected Node, got {type(to_fp16_node)}")
        # extract the dtype, so we can cast to it before returning
        target_dtype = to_fp16_node.args[1]
        weight_node = to_fp16_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        if not isinstance(weight_node, Node):
            raise AssertionError(f"Expected Node, got {type(weight_node)}")
        if weight_node.op != "get_attr":
            raise AssertionError(f"Expected get_attr, got {weight_node.op}")
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        # return the weight with fp16 cast
        return weight.detach().to(target_dtype)
    else:
        assert linear_second_arg.op == "get_attr"
        if linear_second_arg.op != "get_attr":
            raise AssertionError(f"Expected get_attr, got {linear_second_arg.op}")
        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
        return weight.detach()

@ -150,8 +166,10 @@ def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # packed weight is arg 1
    packed_weight_node = node.args[1]
    assert isinstance(packed_weight_node, Node)
    assert packed_weight_node.op == "get_attr"
    if not isinstance(packed_weight_node, Node):
        raise AssertionError(f"Expected Node, got {type(packed_weight_node)}")
    if packed_weight_node.op != "get_attr":
        raise AssertionError(f"Expected get_attr, got {packed_weight_node.op}")
    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    (weight, _bias), _name = packed_weight.__getstate__()
@ -264,7 +282,8 @@ def extract_weight_from_node(

    elif node.op == "call_module":
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node.target, str)
        if not isinstance(node.target, str):
            raise AssertionError(f"Expected str, got {type(node.target)}")
        mod = getattr_from_fqn(gm, node.target)
        module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
        for target_mod_type, weight_extraction_fn in module_mapping.items():
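
The hunks above swap bare `assert` statements for explicit `raise AssertionError(...)` checks, so the validation still runs under `python -O`. Below is a minimal, self-contained sketch of the same pattern applied to an FX graph; the `Toy` module and the `first_get_attr_weight` helper are illustrative only and are not part of this diff.

import torch
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace


def first_get_attr_weight(gm: GraphModule) -> torch.Tensor:
    # Illustrative helper: find the first get_attr node and fetch its tensor,
    # validating with explicit raises instead of `assert` (which `-O` strips).
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            if not isinstance(node.target, str):
                raise AssertionError(f"Expected str, got {type(node.target)}")
            return getattr(gm, node.target).detach()
    raise AssertionError("no get_attr node found in graph")


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3, 3))

    def forward(self, x):
        return x @ self.w


gm = symbolic_trace(Toy())
print(first_get_attr_weight(gm).shape)  # torch.Size([3, 3])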

@ -3259,7 +3259,7 @@ struct to_ir {
      case TK_IN:
        return aten::__contains__;
      default:
        throw std::runtime_error("unknown kind " + std::to_string(kind));
        TORCH_CHECK(false, "unknown kind ", kind);
    }
  }

@ -3306,7 +3306,7 @@ struct to_ir {
      case TK_RSHIFT:
        return "__rshift__";
      default:
        throw std::runtime_error("unknown kind " + std::to_string(kind));
        TORCH_CHECK(false, "unknown kind ", kind);
    }
  }

@ -4120,8 +4120,7 @@ struct to_ir {
    } else if (kind == aten::ge) {
      return aten::le;
    }
    throw std::runtime_error(
        "reverseComparision: unsupported NodeKind. File a bug");
    TORCH_CHECK(false, "reverseComparision: unsupported NodeKind. File a bug");
  }

  // any expression that can produce a SugaredValue is handled here

@ -94,7 +94,7 @@ C10_EXPORT std::string kindToString(int kind) {
    TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      throw std::runtime_error("Unknown kind: " + std::to_string(kind));
      TORCH_CHECK(false, "Unknown kind: ", kind);
  }
}

@ -167,12 +167,12 @@ Value* TracingState::getValue(const IValue& var) {
    // Didn't find it. Bake in a constant
    if (ten.requires_grad()) {
      pauseTracing();
      std::ostringstream oss;
      oss << "Cannot insert a Tensor that requires grad as a constant. "
          << "Consider making it a parameter or input, or detaching the gradient\n"
          << "Tensor:\n"
          << ten;
      throw std::runtime_error(oss.str());
      TORCH_CHECK(
          false,
          "Cannot insert a Tensor that requires grad as a constant. ",
          "Consider making it a parameter or input, or detaching the gradient\n",
          "Tensor:\n",
          ten);
    }

    Value* constant = graph->insertConstant(ten);
@ -208,15 +208,19 @@ Value* TracingState::getValue(const IValue& var) {
      }
    }

    std::ostringstream oss;
    if (var.isFuture()) {
      oss << "Tried to trace Future or Object that the tracer was not aware of.";
      TORCH_CHECK(
          false,
          "Tried to trace Future or Object that the tracer was not aware of.");
    } else {
      oss << "Tried to trace " << var
          << " but it is not part of the active trace. Modules that are called during a trace"
          << " must be registered as submodules of the thing being traced.";
      TORCH_CHECK(
          false,
          "Tried to trace ",
          var,
          " but it is not part of the active trace. Modules that are called during a trace",
          " must be registered as submodules of the thing being traced.");
    }
    throw std::runtime_error(oss.str());

  } else {
    // If the values are non-tensors, we try to create constants
    // and bake those constants into the traced graph
@ -225,11 +229,12 @@ Value* TracingState::getValue(const IValue& var) {
      recordSourceLocation(constant.value()->node());
      return *constant;
    }
    std::ostringstream os;
    os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
       << "The below value could not be materialized as a constant:\n"
       << var;
    throw std::runtime_error(os.str());
    TORCH_CHECK(
        false,
        "Tracer cannot get value trace for type ",
        var.tagKind(),
        ". The below value could not be materialized as a constant:\n",
        var);
  }
}
bool TracingState::hasValue(const IValue& var) const {
@ -252,15 +257,14 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {

    auto& value_map = getTracingState()->env_stack.back();
    auto it = value_map.find(iv);
    if (it == value_map.end()) {
      std::ostringstream os;
      os << "output " << i << " (" << var
         << ") of traced region did not have observable "
         << "data dependence with trace inputs; this probably indicates your "
            "program "
         << "cannot be understood by the tracer.";
      throw std::runtime_error(os.str());
    }
    TORCH_CHECK(
        it != value_map.end(),
        "output ",
        i,
        " (",
        var,
        ") of traced region did not have observable data dependence with trace inputs; ",
        "this probably indicates your program cannot be understood by the tracer.");
    return it->second;
  } else if (iv.isTensorList()) {
    if (tracing_mode_strict) {
@ -281,11 +285,10 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
    graph->insertNode(tuple_node);
    return tuple_node->output();
  } else if (iv.isGenericDict()) {
    if (tracing_mode_strict) {
      throw std::runtime_error(
          "Encountering a dict at the output of the tracer" +
          std::string(STRICT_TRACER_MSG));
    }
    TORCH_CHECK(
        !tracing_mode_strict,
        "Encountering a dict at the output of the tracer",
        STRICT_TRACER_MSG);
    auto dict = iv.toGenericDict();
    TypePtr key_type = dict.keyType();
    TypePtr value_type = dict.valueType();
@ -304,15 +307,15 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
        }
      }
    }

    if (!key_type_valid || !value_type_valid) {
      std::ostringstream os;
      os << "output " << i << " (" << dict << ") of traced region "
         << "cannot be understood by the tracer, only outputs matching"
         << "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
         << "can be a dictionary output of a traced function";
      throw std::runtime_error(os.str());
    }
    TORCH_CHECK(
        key_type_valid && value_type_valid,
        "output ",
        i,
        " (",
        dict,
        ") of traced region cannot be understood by the tracer, only outputs matching ",
        "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] ",
        "can be a dictionary output of a traced function");
    std::vector<Value*> keys;
    std::vector<Value*> values;
    for (const auto& entry : dict) {
@ -598,10 +601,11 @@ void TracingState::setValue(const IValue& v, Value* value) {
      setValue(entry.value(), static_value);
    }
  } else {
    std::ostringstream os;
    os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
       << "Supported types are tensor, tensor list, and tuple of tensors.";
    throw std::runtime_error(os.str());
    TORCH_CHECK(
        false,
        "Tracer cannot set value trace for type ",
        v.tagKind(),
        ". Supported types are tensor, tensor list, and tuple of tensors.");
  }
}

@ -801,11 +805,10 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
    recordSourceLocation(info[i]->node());
  }
  for (jit::Value* v : info) {
    if (*v->type() != *jit::IntType::get()) {
      throw std::runtime_error(
          "Type mismatch in setposattr for IntArrayRef. Check that your program "
          "is valid without tracing, and please file a bug report if it is.");
    }
    TORCH_CHECK(
        *v->type() == *jit::IntType::get(),
        "Type mismatch in setposattr for IntArrayRef. Check that your program "
        "is valid without tracing, and please file a bug report if it is.");
  }
  n->addInput(
      g->insertNode(g->createList(jit::IntType::get(), info))->output());
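
The tracer hunks above collapse the std::ostringstream + throw std::runtime_error error paths into single TORCH_CHECK(cond, ...) calls, which concatenate their message arguments and raise c10::Error (surfaced to Python as a RuntimeError). A rough sketch of how the strict-mode dict-output check exercised above is visible from Python; the function `f` is illustrative and the behavior is inferred from the error strings in these hunks.

import torch


def f(x):
    # Returning a dict hits the dict-output check in TracingState::getOutput.
    return {"doubled": x * 2, "halved": x / 2}


example = torch.randn(3)

# Default (strict) tracing refuses dict outputs.
try:
    torch.jit.trace(f, example)
except Exception as e:
    print("strict trace failed:", type(e).__name__)

# Relaxing strictness allows str-keyed dicts of tensors as outputs.
traced = torch.jit.trace(f, example, strict=False)
print(traced(example)["doubled"].shape)  # torch.Size([3])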

@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>

#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/lexer.h>
@ -37,10 +38,10 @@ struct Tree : c10::intrusive_ptr_target {
    return true;
  }
  virtual const SourceRange& range() const {
    throw std::runtime_error("is an Atom");
    TORCH_CHECK(false, "is an Atom");
  }
  virtual const std::string& stringValue() const {
    throw std::runtime_error("stringValue can only be called on TK_STRING");
    TORCH_CHECK(false, "stringValue can only be called on TK_STRING");
  }
  virtual const TreeList& trees() const {
    static const TreeList empty_trees = {};
@ -79,13 +80,16 @@ struct Tree : c10::intrusive_ptr_target {
      int lineno,
      size_t expected_subtrees,
      bool allow_more) const {
    if (kind() != k) {
      std::stringstream ss;
      ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
         << "' but found '" << kindToString(kind()) << "'\n";
      range().highlight(ss);
      throw std::runtime_error(ss.str());
    }
    TORCH_CHECK(
        kind() == k,
        filename,
        ":",
        lineno,
        ": expecting kind '",
        kindToString(k),
        "' but found '",
        kindToString(kind()),
        "'\n");
    if (trees().size() < expected_subtrees ||
        (!allow_more && trees().size() != expected_subtrees)) {
      std::stringstream ss;
@ -93,7 +97,7 @@ struct Tree : c10::intrusive_ptr_target {
         << expected_subtrees << " subtrees, but found only " << trees().size()
         << "\n";
      range().highlight(ss);
      throw std::runtime_error(ss.str());
      TORCH_CHECK(false, ss.str());
    }
  }
  ~Tree() override = default;

@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
        The reconstructed pytree, containing the ``leaves`` placed in the structure described by
        ``treespec``.
    """
    if not _is_pytreespec_instance(treespec):
        raise TypeError(
            f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
            f"PyTreeSpec but got item of type {type(treespec)}."
        )
    return optree.tree_unflatten(treespec, leaves)  # type: ignore[arg-type]
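
For the pytree hunk above: tree_unflatten(leaves, treespec) rebuilds a container from flattened leaves and a previously captured treespec (the hunk drops the explicit isinstance guard before delegating to optree). A small round-trip sketch, shown with the plain-Python torch.utils._pytree module for illustration; the diff itself touches the optree-backed _cxx_pytree variant.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Flatten a nested container into (leaves, spec), transform the leaves,
# then rebuild the original structure around the new leaves.
nested = {"a": torch.zeros(2), "b": [torch.ones(1), (torch.full((3,), 2.0),)]}
leaves, spec = tree_flatten(nested)
rebuilt = tree_unflatten([leaf + 1 for leaf in leaves], spec)
print(rebuilt["a"], rebuilt["b"][1][0])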