Compare commits

...

14 Commits

Author SHA1 Message Date
f47cadf75d [BE][Typing][Dynamo] Type torch/_dynamo/variables/lists.py (#167156)
Provides type coverage to torch/_dynamo/variables/lists.py

Coverage report:
`mypy torch/_dynamo/variables/lists.py --linecount-report /tmp/coverage_log`

Comparing before and after, we go from 0 lines and 0 funcs covered to 1759 lines and 102 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167156
Approved by: https://github.com/Skylion007, https://github.com/rtimpe
2025-11-07 00:15:40 +00:00
2923b02c6e [DTensor] add explicit mode (ExplicitRedistributionContext) (#166593)
usage:

```
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Shard(0)])
with ExplicitRedistributionContext():
    with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
        # Shard(0) @ Shard(0) requires a redistribution
        torch.matmul(dx, dA)
```
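
A follow-up sketch, taken from the test added in this PR, showing how to satisfy the context by redistributing explicitly (same names as the snippet above):

```
with ExplicitRedistributionContext():
    # explicitly replicate dA instead of relying on implicit redistribution
    dA_repl = dA.redistribute(device_mesh, [Replicate()])
    torch.matmul(dx, dA_repl)  # succeeds, no implicit transfer needed
```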

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166593
Approved by: https://github.com/ezyang
2025-11-07 00:04:19 +00:00
4b9ba0fb26 [user-streams] Add requires cuda to all test cases (#167195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167195
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167175, #167176, #167180
2025-11-06 23:13:47 +00:00
106d34c80a [user-streams] add requires cuda decorator (#167180)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167180
Approved by: https://github.com/donigian, https://github.com/Lucaskabela, https://github.com/Skylion007
ghstack dependencies: #167175, #167176
2025-11-06 23:13:47 +00:00
0b06109412 [user-streams] Fix bug in object bytecode construction (#167176)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167176
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167175
2025-11-06 23:13:47 +00:00
2073af5790 [user-streams] Refactor user object index in streams (#167175)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167175
Approved by: https://github.com/Lucaskabela
2025-11-06 23:13:47 +00:00
9b4ac45d2f Revert "[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)"
This reverts commit eefa16342c9f322b56c7c0cd6d309c3ed8f0b882.

Reverted https://github.com/pytorch/pytorch/pull/166165 on behalf of https://github.com/jeanschmidt due to Breaking internal tests D86216934 ([comment](https://github.com/pytorch/pytorch/pull/166165#issuecomment-3499645688))
2025-11-06 22:34:48 +00:00
a45a17f65e Fix boxcox to return same result for same input in one batch (#166986)
Summary:
The SIMD path uses the SLEEF version of pow, which is slightly different from std::pow. The fix is to use the same vectorized code (with partial load and store) for the trailing data as well, to ensure consistency between results.
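
As an illustration only (a plain Python sketch, not the C++ kernel or the actual test plan), here is the formula and the invariant the fix restores: a value repeated within one batch must produce identical outputs whether it falls in the SIMD body or the scalar tail.

```
import torch

def box_cox_nonzero_lambda(data, lambda1, lambda2, eps=1e-6):
    # reference formula matching the kernel: (max(x + lambda2, eps) ** lambda1 - 1) / lambda1
    m = torch.clamp(data + lambda2, min=eps)
    return (m.pow(lambda1) - 1.0) / lambda1

# 17 elements leave a remainder for an 8-wide SIMD loop; after the fix the
# trailing elements go through the same vectorized path, so repeats agree.
batch = torch.full((17,), 2.5)
out = box_cox_nonzero_lambda(batch, lambda1=0.3, lambda2=0.1)
assert torch.all(out == out[0])
```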

Deploy:
Need to make a hotfix in waas to monitor release signals, since this diff can cause testing failures in veloski and waas release correctness tests.

Test Plan: Sandcastle.

Differential Revision: D86218207

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166986
Approved by: https://github.com/swolchok
2025-11-06 22:33:26 +00:00
c5593e75b3 Fix flaky memory profiler test (#167168)
Fixes #167037

Do not check the exact number of frames.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167168
Approved by: https://github.com/angelayi
2025-11-06 21:39:44 +00:00
c90a976370 Update pythoncapi_compat.h (#167138)
Update to commit 44c8e14bbbb5d5135ae90957036a61397e4df577.

Should slightly simplify https://github.com/pytorch/pytorch/pull/166342
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167138
Approved by: https://github.com/albanD
2025-11-06 21:31:58 +00:00
d144382dc9 Move enrich_profiler_metadata config import out of gm.recompile() (#167114)
Fixes T243967987

Move `enrich_profiler_metadata` from `torch._dynamo.config` to `torch.fx.experimental._config`.

We cannot import anything inside recompile(); doing so caused a perf regression internally. We move the config so we can import it at the top of `graph_module.py` without causing any circular import.

We also cannot delete the old config right now because some internal unit tests rely on copies of the old `graph_module.py`. But I think we should be able to delete the old config soon after this PR lands.
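
A sketch of the resulting usage, mirroring the test changes in this PR (the decorator form is what the updated tests use; treat the exact call pattern as illustrative):

```
import torch
import torch.fx.experimental._config  # ensure the submodule is imported

# the flag now lives in torch.fx.experimental._config (this PR);
# torch._dynamo.config.enrich_profiler_metadata remains only for backward compatibility
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def compile_and_run(mod, x):
    compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
    return compiled(x)
```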

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167114
Approved by: https://github.com/angelayi
2025-11-06 21:21:40 +00:00
78827c5e00 Distributed Autotuning (#163369)
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently, when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks, so each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less work total
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Results:

In testing using llama3 8B on torchtitan, max-autotune time was reduced from 52s to 26s and exhaustive-autotune time from 2009s to 613s.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_AUTOTUNE.
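
A hedged sketch of enabling it: the environment variable is the one named above, and the inductor config flags are the ones the new tests patch; treat the exact spellings as assumptions about your revision.

```
import os
os.environ["TORCHINDUCTOR_DISTRIBUTED_AUTOTUNE"] = "1"   # env var named in this PR

import torch
import torch._inductor.config as inductor_config

# the new tests patch these flags directly instead of using the env var
inductor_config.max_autotune_gemm = True
inductor_config.distributed_max_autotune_gemm = True

@torch.compile()
def f(a, b):
    return torch.relu(a @ b)

# run under torchrun with a process group initialized; each rank then
# autotunes roughly 1/world_size of the GEMMs and shares the results
```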

Co-authored-by: @PaulZhang12

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163369
Approved by: https://github.com/PaulZhang12
2025-11-06 21:10:21 +00:00
ab1e734cd7 [ez] avoid log spam when random data is generated (#166919)
It's annoying to see a full screen of this warning when running fx_graph_runnable files saved in tlparse.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166919
Approved by: https://github.com/eellison
2025-11-06 21:05:20 +00:00
888958ad6c Prevent torch._check causing graph breaks (#164676)
Handle `torch._check` in `TorchInGraphFunctionVariable.call_function`. It takes two arguments: a predicate (bool) and a message (callable). If the predicate is a constant, `torch._check` is evaluated eagerly: when the predicate is true, compilation simply proceeds; when it is false, `torch._check` raises an exception.

If the predicate is not constant, we manually emit a proxy. I tried to build as_proxy() inside NestedUserFunctionVariable but couldn't, which is why the proxy is created here. I also try to extract the message: if it's a plain function I retrieve it, otherwise I set it to None. It may be possible to extract the message when it is a closure, but it's not clear how.
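
For reference, the pattern the new tests exercise (taken from the test diff below): a compiled function calling torch._check with a dynamic predicate and a closure-free message.

```
import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # non-constant predicate (dim 0 is dynamic) with a message that captures no closure,
    # so Dynamo emits a proxy instead of graph-breaking
    torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
    return x + 1

x = torch.randn(4)
torch._dynamo.maybe_mark_dynamic(x, 0)
y = f(x)   # compiles; with shape 3 the same call raises RuntimeError at runtime
```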

Fixes #163668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164676
Approved by: https://github.com/williamwen42, https://github.com/mlazos

Co-authored-by: William Wen <william.wen42@gmail.com>
2025-11-06 21:00:48 +00:00
36 changed files with 2635 additions and 312 deletions

View File

@ -73,6 +73,19 @@ void box_cox_zero_lambda(
}
}
template <typename T>
at::vec::Vectorized<T> box_cox_nonzero_lambda_impl(
at::vec::Vectorized<T> data,
at::vec::Vectorized<T> lambda1,
at::vec::Vectorized<T> lambda2,
at::vec::Vectorized<T> k_eps) {
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
auto pow = max.pow(lambda1);
return at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
}
template <typename T>
void box_cox_nonzero_lambda(
int64_t D,
@ -88,21 +101,18 @@ void box_cox_nonzero_lambda(
auto k_eps_vec = Vec(k_eps);
for(; j + VLEN < D; j += VLEN) {
auto data = Vec::loadu(data_ptr + j);
auto lambda2 = Vec::loadu(lambda2_ptr + j);
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps_vec);
auto lambda1 = Vec::loadu(lambda1_ptr + j);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
auto pow = max.pow(lambda1);
auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
auto lambda2 = Vec::loadu(lambda2_ptr + j);
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
res.store(out + j);
}
for ( ;j < D; ++j) {
auto sum = data_ptr[j] + lambda2_ptr[j];
auto max = std::max(sum, k_eps);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]);
auto pow = std::pow(max, lambda1_ptr[j]);
out[j] = pow * lambda_over_1 - lambda_over_1;
if (j < D) {
auto remaining = D - j;
auto data = Vec::loadu(data_ptr + j, remaining);
auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining);
auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining);
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
res.store(out + j, remaining);
}
}
#else

View File

@ -1,11 +1,18 @@
# Owner(s): ["oncall: distributed"]
import itertools
from contextlib import nullcontext
from typing import Any
import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
local_tensor_mode,
LocalTensor,
LocalTensorMode,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._utils import (
_compute_local_shape_and_global_offset,
@ -14,6 +21,7 @@ from torch.distributed.tensor._utils import (
compute_global_tensor_shape,
compute_local_shape_and_global_offset,
compute_local_tensor_info,
ExplicitRedistributionContext,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import (
@ -851,5 +859,93 @@ class Test2DStridedLocalShard(DTensorTestBase):
self.assertEqual(global_tensor, dtensor_2d.full_tensor())
class LocalTensorTestBase(TestCase):
def assertEqual(self, lhs, rhs, **kwargs):
mode = local_tensor_mode()
with nullcontext() if mode is None else mode.disable():
if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
super().assertEqual(lhs._ranks, rhs._ranks)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r],
rhs._local_tensors[r],
lambda m: f"rank {r}: {m}",
)
elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
)
else:
return super().assertEqual(lhs, rhs, **kwargs)
@property
def world_size(self):
raise NotImplementedError("override world-size in your subclass")
def build_device_mesh(self) -> DeviceMesh:
return init_device_mesh("cpu", (self.world_size,))
def setUp(self):
super().setUp()
torch.distributed.init_process_group(
# TODO: test other ranks too
"fake",
rank=0,
world_size=self.world_size,
)
def tearDown(self):
super().tearDown()
try:
dist.destroy_process_group()
except AssertionError:
pass
class TestExplicitRedistribute(LocalTensorTestBase):
@property
def world_size(self):
return 4
def test_explicit_matmul(self):
with LocalTensorMode(self.world_size):
device_mesh = self.build_device_mesh()
dim = 128
x = torch.randn(8, dim, requires_grad=True)
A = torch.randn(dim, dim, requires_grad=True)
# Prepare DTensors
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Shard(0)])
# implicit redistribute works as usual by default
with CommDebugMode() as comm_mode:
torch.matmul(dx, dA)
self.assertEqual(comm_mode.get_total_counts(), 1)
# explicit redistribute works too
with ExplicitRedistributionContext():
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
torch.matmul(dx, dA)
# explicit redistribute allows manual redistribute
with ExplicitRedistributionContext():
dA_repl = dA.redistribute(device_mesh, [Replicate()])
torch.matmul(dx, dA_repl)
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Replicate()])
with ExplicitRedistributionContext():
dY = torch.matmul(dx, dA_repl)
loss = dY.sum()
# we now see the error during backwards
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
loss.backward()
if __name__ == "__main__":
run_tests()

View File

@ -2,6 +2,7 @@
import contextlib
import copy
import functools
import logging
import random
import unittest
from contextlib import contextmanager
@ -51,6 +52,9 @@ from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda_and_triton
log = logging.getLogger(__name__)
def reset_rng_state():
torch.manual_seed(1337)
random.seed(1337)
@ -1200,6 +1204,116 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
@patch.object(torch._inductor.config, "max_autotune_gemm", True)
@patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
def test_multiproc_autotune(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(a, b, c):
res = (
torch.sum((a @ b) + 1.0)
+ torch.sum(torch.relu(b @ c))
+ torch.sum(c @ a)
)
return res
a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)
try:
f(a, b, c)
except Exception:
log.exception("Caught exception running f")
raise
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
print(f"Result from {self.rank} is {f(a, b, c)}")
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
@patch.object(torch._inductor.config, "max_autotune_gemm", True)
@patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
def test_multiproc_autotune_dynamic_shapes(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(a, b, c):
res = (
torch.sum((a @ b) + 1.0)
+ torch.sum(torch.relu(b @ c))
+ torch.sum(c @ a)
)
return res
a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)
# Mark tensors as dynamic on dimension 0
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(a, 1)
torch._dynamo.mark_dynamic(b, 0)
torch._dynamo.mark_dynamic(b, 1)
torch._dynamo.mark_dynamic(c, 0)
torch._dynamo.mark_dynamic(c, 1)
try:
f(a, b, c)
except Exception:
log.exception("Caught exception running f")
raise
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
print(f"Result from {self.rank} is {f(a, b, c)}")
# Store the initial compilation count
initial_compile_count = len(metrics)
# # Test with different sizes to ensure dynamic shapes work without recompilation
a2 = torch.randn(512, 512, device=self.rank, dtype=torch.bfloat16)
b2 = torch.randn(512, 2048, device=self.rank, dtype=torch.bfloat16)
c2 = torch.randn(2048, 512, device=self.rank, dtype=torch.bfloat16)
try:
result2 = f(a2, b2, c2)
print(f"Result2 from {self.rank} is {result2}")
except Exception:
log.exception("Caught exception running f with different sizes")
raise
# Verify no recompilation occurred
metrics_after = torch._dynamo.utils.get_compilation_metrics()
final_compile_count = len(metrics_after)
self.assertEqual(
initial_compile_count,
final_compile_count,
"Expected no recompilation with dynamic shapes",
)
# Verify all ranks have the same compilation count
res_after = [None] * self.world_size
torch.distributed.all_gather_object(res_after, final_compile_count)
for r in res_after[1:]:
self.assertEqual(res_after[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_get_pg_attr(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):

View File

@ -1428,6 +1428,170 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3])))
def test_check_compiles_when_predicate_true_and_message_has_no_closure(self):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
return x + 1
x = torch.randn(4)
torch._dynamo.maybe_mark_dynamic(x, 0)
y = f(x)
self.assertEqual(y.shape, x.shape)
def test_check_compiles_when_predicate_true_constant_and_message_has_no_closure(
self,
):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
return x + 1
x = torch.randn(4)
y = f(x)
self.assertEqual(y.shape, x.shape)
def test_check_compiles_when_predicate_true_constant_and_message_None(self):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3)
return x + 1
x = torch.randn(4)
y = f(x)
self.assertEqual(y.shape, x.shape)
def test_check_compiles_when_predicate_true_and_message_None(self):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3)
return x + 1
x = torch.randn(4)
torch._dynamo.maybe_mark_dynamic(x, 0)
y = f(x)
self.assertEqual(y.shape, x.shape)
def test_check_compiles_when_predicate_true_and_message_has_global(self):
global GLOBAL_INT
GLOBAL_INT = 1
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3")
return x + 1
x = torch.randn(4)
torch._dynamo.maybe_mark_dynamic(x, 0)
y = f(x)
self.assertEqual(y.shape, x.shape)
def test_check_raises_at_runtime_when_predicate_false_and_message_has_global(self):
global GLOBAL_INT
GLOBAL_INT = 1
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3")
return x + 1
x = torch.randn(3)
torch._dynamo.maybe_mark_dynamic(x, 0)
with self.assertRaisesRegex(
RuntimeError, f"{GLOBAL_INT} is not greater than 3"
):
f(x)
def test_check_raises_at_runtime_when_predicate_false_and_message_None(self):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3)
return x + 1
x = torch.randn(3)
torch._dynamo.maybe_mark_dynamic(x, 0)
with self.assertRaisesRegex(RuntimeError, None):
f(x)
def test_check_raises_at_runtime_when_predicate_false_constant_and_message_None(
self,
):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3)
return x + 1
x = torch.randn(3)
with self.assertRaisesRegex(RuntimeError, None):
f(x)
def test_check_raises_at_runtime_when_predicate_false_and_message_has_no_closure(
self,
):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
return x + 1
x = torch.randn(3)
torch._dynamo.maybe_mark_dynamic(x, 0)
with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"):
f(x)
def test_check_raises_at_runtime_when_predicate_false_constant_and_message_has_no_closure(
self,
):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
return x + 1
x = torch.randn(3)
with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"):
f(x)
def test_check_assert_error_at_runtime_when_predicate_false_and_message_has_closure(
self,
):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3")
return x + 1
x = torch.randn(3)
torch._dynamo.maybe_mark_dynamic(x, 0)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()"
):
f(x)
def test_check_assert_error_at_runtime_when_predicate_true_and_message_has_closure(
self,
):
@torch.compile(backend="eager", fullgraph=True)
def f(x):
torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3")
return x + 1
x = torch.randn(4)
torch._dynamo.maybe_mark_dynamic(x, 0)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()"
):
f(x)
def test_assert(self):
@torch.compile
def fn1(x):

View File

@ -74,13 +74,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
@ -196,6 +196,7 @@ class <lambda>(torch.nn.Module):
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)
@requires_cuda
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:
@ -229,13 +230,13 @@ class <lambda>(torch.nn.Module):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
@ -249,6 +250,7 @@ class <lambda>(torch.nn.Module):
def test_nested_stream_enter_exit_graph_break(self):
pass
@requires_cuda
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
@ -289,6 +291,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
@ -331,6 +334,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()
@ -380,6 +384,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()

View File

@ -500,13 +500,8 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]
# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
softmax_load_str in forward_wrapper,
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)

View File

@ -15310,7 +15310,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_addmm_native_layer_norm",
"triton_poi_fused_native_layer_norm_relu",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@ -15323,7 +15323,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_Linear_ReLU",
"triton_poi_fused_LayerNorm_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

View File

@ -7508,6 +7508,8 @@ class TestFXMemoryProfiler(TestCase):
device = "cuda"
mod = MLPModule(device)
with tempfile.TemporaryDirectory() as tmpdir:
# reset cache to start fresh
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
@ -7518,10 +7520,7 @@ class TestFXMemoryProfiler(TestCase):
torch.cuda.empty_cache()
fx_frames = self.collect_frames(augmented_snapshot)
if TEST_WITH_ROCM:
self.assertGreater(len(fx_frames), 0)
else:
self.assertEqual(len(fx_frames), 12)
self.assertGreater(len(fx_frames), 2)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name

View File

@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)

View File

@ -739,11 +739,8 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -47,7 +47,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
from . import config
from .utils import clone_inputs, get_debug_dir
from .utils import clone_inputs, get_debug_dir, warn_once
if TYPE_CHECKING:
@ -617,7 +617,7 @@ class InputReader:
# way would be very mysterious! Would have been better
# not to store device in the serialized format...
return storage
log.warning("could not load %s, generating random data instead", storage_hash)
warn_once(f"could not load {storage_hash}, generating random data instead")
shape = (nbytes // dtype_hint.itemsize,)
stride = _stride_or_default(None, shape=shape)
return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

View File

@ -2937,5 +2937,18 @@
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0288": [
{
"Gb_type": "Can't extract message from torch._check()",
"Context": "str(message_vt)",
"Explanation": "The second argument of torch._check() must be a functiondefined within the torch.compile regionthat does not reference a non-local variable.",
"Hints": [
"Make sure the message function is defined in the torch.compile region.",
"Remove any closure variables, e.g. ",
"remove references to closure variable `x` in `lambda: f'{x} failed check'`",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@ -1542,7 +1542,7 @@ class OutputGraph(OutputGraphCommon):
)
)
tmp_vars = []
for constructor in reversed(index_to_bytecode_constructor.values()):
for constructor in index_to_bytecode_constructor.values():
constructor(codegen)
var_name = (
self.new_var()

View File

@ -3228,7 +3228,7 @@ class InstructionTranslatorBase(
def BUILD_SLICE(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
self.push(SliceVariable(items, tx=self))
self.push(SliceVariable(items, tx=self)) # type: ignore[arg-type]
def BUILD_LIST(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
@ -3607,7 +3607,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, ListVariable)
assert obj.is_mutable()
obj.call_method(self, "extend", [v], {})
obj.call_method(self, "extend", [v], {}) # type: ignore[arg-type]
def LIST_TO_TUPLE(self, inst: Instruction) -> None:
self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
@ -3673,7 +3673,7 @@ class InstructionTranslatorBase(
def MATCH_KEYS(self, inst: Instruction) -> None:
tos = self.stack[-1]
assert isinstance(tos, TupleVariable)
keys = tos.unpack_var_sequence(self)
keys = tos.unpack_var_sequence(self) # type: ignore[arg-type]
tos1 = self.stack[-2]
assert isinstance(tos1, ConstDictVariable)

View File

@ -180,6 +180,7 @@ manual_torch_name_rule_map: dict[
"torch.compiler.is_exporting": TorchInGraphFunctionVariable,
"torch._C._to_dlpack": SkipFunctionVariable,
"torch.to_dlpack": SkipFunctionVariable,
"torch._check": TorchInGraphFunctionVariable,
# We graph break on RNG state setters or getters like
# `torch.get_rng_state` or `torch.set_rng_state`. These functions
# are not aten operations and therefore they are completely ignored
@ -2343,7 +2344,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._check_type",
"torch._check_value",
"torch._check_with",
"torch._check",
"torch._compile._disable_dynamo",
"torch._functorch.apis.chunk_vmap",
"torch._functorch.batch_norm_replacement.batch_norm_without_running_stats",

View File

@ -1061,9 +1061,7 @@ class VariableBuilder:
)
set_example_value(stream_proxy.node, value)
var = StreamVariable(
stream_proxy,
value,
source=self.source,
stream_proxy, value, source=self.source, user_object_index=index
)
return self.tx.output.side_effects.track_object_existing(value, var)
elif isinstance(value, (torch._C._SDPAParams)):
@ -3006,14 +3004,16 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
return SymNodeVariable(proxy, example_value, **options)
elif (
isinstance(example_value, torch.Stream)
and proxy.node.target
in (get_external_object_by_index, torch.accelerator.current_stream)
and proxy.node.target == get_external_object_by_index
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
return StreamVariable(proxy, example_value, **options)
index = None
if proxy.node.target == get_external_object_by_index:
index = proxy.node.args[0]
return StreamVariable(proxy, example_value, index, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)

View File

@ -1513,7 +1513,7 @@ class WithExitFunctionVariable(VariableTracker):
# Note here we reconstruct the context manager rather than the
# exit function. The handler generated by BlockStackEntry
# will re-enter the context in the resume function.
self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined]
self.ctx.reconstruct_type(codegen) # type: ignore[union-attr]
if codegen.tx.output.partial_convert:
if sys.version_info >= (3, 11):
codegen.append_output(create_instruction("PUSH_NULL"))
@ -1522,10 +1522,10 @@ class WithExitFunctionVariable(VariableTracker):
# We rely on classes subtyping `GenericContextWrappingVariable`
# to implement these fns and have these attributes
codegen.extend_output(
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type]
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr]
)
codegen.extend_output(
create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type]
create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr]
)
codegen.append_output(create_setup_with(self.target))
codegen.append_output(create_instruction("POP_TOP"))

View File

@ -82,7 +82,8 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.product(*seqs, repeat=r)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif (
self.value is itertools.combinations
@ -98,7 +99,8 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.combinations(iterable, r):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif self.value is itertools.groupby:
if any(kw != "key" for kw in kwargs.keys()):
@ -181,7 +183,8 @@ class ItertoolsVariable(VariableTracker):
from_exc=e,
)
return variables.ListIteratorVariable(
result, mutation_type=ValueMutationNew()
result, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif self.value is itertools.repeat:
if len(args) < 2:
@ -212,7 +215,8 @@ class ItertoolsVariable(VariableTracker):
)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
else:
return super().call_function(tx, args, kwargs)

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
Variable tracking implementations for list-like data structures in Dynamo.
@ -20,7 +18,7 @@ import collections
import inspect
import operator
import sys
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, Sequence, TYPE_CHECKING
import torch
import torch.fx
@ -60,11 +58,11 @@ if TYPE_CHECKING:
class BaseListVariable(VariableTracker):
@staticmethod
def cls_for_instance(obj):
def cls_for_instance(obj: Any) -> type["BaseListVariable"]:
return BaseListVariable.cls_for(type(obj))
@staticmethod
def cls_for(obj):
def cls_for(obj: Any) -> type:
return {
iter: ListIteratorVariable,
list: ListVariable,
@ -80,34 +78,38 @@ class BaseListVariable(VariableTracker):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items)
self.items: list[VariableTracker] = items
def _as_proxy(self):
def _as_proxy(self) -> list[Any]:
return [x.as_proxy() for x in self.items]
def modified(self, items, **kwargs):
def modified(
self, items: list[VariableTracker], **kwargs: Any
) -> "BaseListVariable":
return type(self)(items, **kwargs)
@property
def value(self):
def value(self) -> Any:
return self.as_python_constant()
def debug_repr_helper(self, prefix, suffix):
def debug_repr_helper(self, prefix: str, suffix: str) -> str:
return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix
def as_python_constant(self):
def as_python_constant(self) -> Any:
return self.python_type()([x.as_python_constant() for x in self.items])
def as_proxy(self):
def as_proxy(self) -> Any:
assert self.python_type() is not SizeVariable
return self.python_type()(self._as_proxy())
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@ -134,16 +136,16 @@ class BaseListVariable(VariableTracker):
IndexError, tx, args=["list index out of range"]
)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return list(self.items)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
from .tensor import TensorVariable
@ -224,15 +226,15 @@ class BaseListVariable(VariableTracker):
if type(self) is not type(args[0]):
tp_name = self.python_type_name()
other = args[0].python_type_name()
msg = ConstantVariable.create(
msg_vt = ConstantVariable.create(
f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[msg_vt])
if name == "__add__":
return type(self)(self.items + args[0].items, source=self.source)
return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined]
else:
self.items += args[0].items
self.items += args[0].items # type: ignore[attr-defined]
return self
elif name in ("__mul__", "__imul__"):
if kwargs or len(args) != 1:
@ -244,10 +246,10 @@ class BaseListVariable(VariableTracker):
)
if not (args[0].is_python_constant() and args[0].python_type() is int):
msg = ConstantVariable.create(
msg_vt = ConstantVariable.create(
f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[msg_vt])
val = args[0].as_python_constant()
@ -301,7 +303,7 @@ class BaseListVariable(VariableTracker):
class RangeVariable(BaseListVariable):
def __init__(self, items, **kwargs) -> None:
def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None:
items_to_map = items
start = variables.ConstantVariable.create(0)
stop = None
@ -316,7 +318,7 @@ class RangeVariable(BaseListVariable):
else:
raise AssertionError
def maybe_as_int(x):
def maybe_as_int(x: VariableTracker) -> VariableTracker:
return (
ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
)
@ -329,22 +331,22 @@ class RangeVariable(BaseListVariable):
assert stop is not None
super().__init__([start, stop, step], **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("range(", ")")
def python_type(self):
def python_type(self) -> type:
return range
def start(self):
def start(self) -> Any:
return self.items[0].as_python_constant()
def stop(self):
def stop(self) -> Any:
return self.items[1].as_python_constant()
def step(self):
def step(self) -> Any:
return self.items[2].as_python_constant()
def range_length(self):
def range_length(self) -> int:
lo = self.start()
hi = self.stop()
step = self.step()
@ -357,7 +359,7 @@ class RangeVariable(BaseListVariable):
else:
return 0
def _get_slice_indices(self, length, slice):
def _get_slice_indices(self, length: int, slice: slice) -> list[int]:
step_is_negative = 0
if slice.step is None:
@ -406,7 +408,7 @@ class RangeVariable(BaseListVariable):
return [start, stop, step]
def apply_index(self, index):
def apply_index(self, index: int) -> VariableTracker:
length = self.range_length()
if index < 0:
index = length + index
@ -421,12 +423,12 @@ class RangeVariable(BaseListVariable):
return variables.ConstantVariable.create(self.start() + (index * self.step()))
def apply_slice(self, slice):
def apply_slice(self, slice: slice) -> "RangeVariable":
(slice_start, slice_stop, slice_step) = self._get_slice_indices(
self.range_length(), slice
)
def compute_item(index):
def compute_item(index: int) -> int:
return self.start() + (index * self.step())
sub_step = self.step() * slice_step
@ -442,10 +444,12 @@ class RangeVariable(BaseListVariable):
)
return result
def as_python_constant(self):
def as_python_constant(self) -> range:
return range(*[x.as_python_constant() for x in self.items])
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
# implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
index = arg.as_python_constant()
@ -457,28 +461,30 @@ class RangeVariable(BaseListVariable):
msg = ConstantVariable("range indices must be integers or slices")
raise_observed_exception(TypeError, tx, args=[msg])
def as_proxy(self):
def as_proxy(self) -> range:
return self.python_type()(*self._as_proxy())
def unpack_var_sequence(self, tx=None):
def unpack_var_sequence(
self, tx: Optional["InstructionTranslator"] = None
) -> list[VariableTracker]:
return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]
def reconstruct(self, codegen: "PyCodegen") -> None:
assert "range" not in codegen.tx.f_globals
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range))
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
)
codegen.foreach(self.items)
codegen.extend_output(create_call_function(3, False))
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is range:
return variables.ConstantVariable.create(name in range.__dict__)
return super().call_obj_hasattr(tx, name)
def range_equals(self, other: "RangeVariable"):
def range_equals(self, other: "RangeVariable") -> bool:
r0, r1 = self, other
if (
self.range_length() != r1.range_length()
@ -487,12 +493,12 @@ class RangeVariable(BaseListVariable):
):
return False
if len(r0) == 1:
if self.range_length() == 1:
return True
return r0.step() == r1.step()
def range_count(self, x: VariableTracker):
def range_count(self, x: VariableTracker) -> int:
# Based on CPython
# https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
x = x.as_python_constant()
@ -511,7 +517,13 @@ class RangeVariable(BaseListVariable):
return int(re)
return 0
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__iter__":
if not all(var.is_python_constant() for var in self.items):
# Can't represent a `range_iterator` without well defined bounds
@ -545,7 +557,10 @@ class RangeVariable(BaseListVariable):
if pt is not range:
return ConstantVariable.create(NotImplemented)
cmp = self.range_equals(other)
if isinstance(other, RangeVariable):
cmp = self.range_equals(other)
else:
cmp = False
# Two ranges are equal if they produce the same sequence of values
if name == "__eq__":
@ -554,7 +569,7 @@ class RangeVariable(BaseListVariable):
return ConstantVariable(not cmp)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
fields = ["start", "stop", "step"]
if name in fields:
return self.items[fields.index(name)]
@ -568,11 +583,11 @@ class CommonListMethodsVariable(BaseListVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .tensor import SymNodeVariable
if name == "append" and self.is_mutable():
@ -676,9 +691,9 @@ class CommonListMethodsVariable(BaseListVariable):
self.items[key.evaluate_expr()] = value
elif isinstance(key, SliceVariable):
if key.is_python_constant():
self.items[key.as_python_constant()] = list(value.items)
self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined]
else:
items = slice(
items_slice = slice(
*[
(
s.evaluate_expr()
@ -688,7 +703,7 @@ class CommonListMethodsVariable(BaseListVariable):
for s in key.items
]
)
self.items[items] = list(value.items)
self.items[items_slice] = list(value.items) # type: ignore[attr-defined]
else:
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
@ -733,8 +748,8 @@ class CommonListMethodsVariable(BaseListVariable):
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
items = list(self.items)
return self.modified(items, mutation_type=ValueMutationNew())
items_lst: list[VariableTracker] = list(self.items)
return self.modified(items_lst, mutation_type=ValueMutationNew())
elif name == "reverse" and self.is_mutable():
if args or kwargs:
raise_args_mismatch(
@ -763,13 +778,13 @@ class CommonListMethodsVariable(BaseListVariable):
class ListVariable(CommonListMethodsVariable):
def python_type(self):
def python_type(self) -> type:
return list
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("[", "]")
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -778,11 +793,11 @@ class ListVariable(CommonListMethodsVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .tensor import SymNodeVariable
if name == "__setitem__" and self.is_mutable():
@ -805,14 +820,14 @@ class ListVariable(CommonListMethodsVariable):
msg = ConstantVariable.create("can only assign an iterable")
raise_observed_exception(TypeError, tx, args=[msg])
key = key.as_python_constant()
if key.step == 0:
key_as_const = key.as_python_constant()
if key_as_const.step == 0:
msg = ConstantVariable.create("slice step cannot be zero")
raise_observed_exception(ValueError, tx, args=[msg])
value = value.force_unpack_var_sequence(tx)
value_unpack = value.force_unpack_var_sequence(tx)
try:
self.items[key] = value
self.items[key_as_const] = value_unpack
except Exception as exc:
raise_observed_exception(
type(exc),
@ -859,7 +874,7 @@ class ListVariable(CommonListMethodsVariable):
assert first_non_constant_key is not None
try:
python_type = first_non_constant_key.python_type()
python_type = str(first_non_constant_key.python_type())
except NotImplementedError:
python_type = "unknown"
@ -904,7 +919,7 @@ class ListVariable(CommonListMethodsVariable):
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@ -916,14 +931,19 @@ class ListVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is not list:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr([], name))
class DequeVariable(CommonListMethodsVariable):
def __init__(self, items, maxlen=None, **kwargs) -> None:
def __init__(
self,
items: list[VariableTracker],
maxlen: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
if maxlen is None:
maxlen = ConstantVariable.create(None)
assert maxlen.is_python_constant(), (
@ -935,17 +955,17 @@ class DequeVariable(CommonListMethodsVariable):
items = items[-maxlen.as_python_constant() :]
super().__init__(items, **kwargs)
def python_type(self):
def python_type(self) -> type:
return collections.deque
def debug_repr(self):
def debug_repr(self) -> str:
if self.maxlen.as_python_constant() is None:
return self.debug_repr_helper(
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
)
return self.debug_repr_helper("deque([", "])")
def as_python_constant(self):
def as_python_constant(self) -> collections.deque[Any]:
return self.python_type()(
[x.as_python_constant() for x in self.items],
maxlen=self.maxlen.as_python_constant(),
@ -954,7 +974,7 @@ class DequeVariable(CommonListMethodsVariable):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_python_module(collections.deque)
codegen.create_load_python_module(collections.deque) # type: ignore[arg-type]
)
)
codegen.foreach(self.items)
@ -962,18 +982,18 @@ class DequeVariable(CommonListMethodsVariable):
codegen(self.maxlen)
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "maxlen":
return self.maxlen
return super().var_getattr(tx, name)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if (
name == "__setitem__"
and self.is_mutable()
@ -1068,20 +1088,20 @@ class DequeVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is collections.deque:
return variables.ConstantVariable.create(name in collections.deque.__dict__)
return super().call_obj_hasattr(tx, name)
class TupleVariable(BaseListVariable):
def python_type(self):
def python_type(self) -> type[tuple]: # type: ignore[type-arg]
return tuple
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("(", ")")
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -1090,14 +1110,14 @@ class TupleVariable(BaseListVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@ -1109,7 +1129,7 @@ class TupleVariable(BaseListVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is not tuple:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr((), name))
@ -1127,18 +1147,18 @@ class SizeVariable(TupleVariable):
self,
items: list[VariableTracker],
proxy: Optional[torch.fx.Proxy] = None,
**kwargs,
**kwargs: Any,
) -> None:
self.proxy = proxy
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("torch.Size([", "])")
def python_type(self):
def python_type(self) -> type:
return torch.Size
def as_proxy(self):
def as_proxy(self) -> Any:
if self.proxy is not None:
return self.proxy
@ -1193,10 +1213,10 @@ class SizeVariable(TupleVariable):
] + create_call_function(1, False)
codegen.extend_output(build_torch_size)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return list(self.items)
def numel(self, tx):
def numel(self, tx: "InstructionTranslator") -> VariableTracker:
from .builtin import BuiltinVariable
from .tensor import SymNodeVariable
@ -1226,11 +1246,11 @@ class SizeVariable(TupleVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
if kwargs or len(args) != 1:
raise_args_mismatch(
@ -1253,7 +1273,9 @@ class SizeVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
def get_item_dyn(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@ -1269,7 +1291,7 @@ class SizeVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
return variables.ConstantVariable.create(hasattr(torch.Size, name))
@ -1280,33 +1302,39 @@ class NamedTupleVariable(TupleVariable):
*TupleVariable._nonvar_fields,
}
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
def __init__(
self,
items: list[VariableTracker],
tuple_cls: type,
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
def is_namedtuple(self):
def is_namedtuple(self) -> bool:
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
getattr(self.tuple_cls, "_make", None)
)
def is_structseq(self):
def is_structseq(self) -> bool:
return not self.is_namedtuple()
def fields(self):
def fields(self) -> tuple[str, ...]:
return namedtuple_fields(self.tuple_cls)
def debug_repr(self):
def debug_repr(self) -> str:
if self.is_structseq():
# StructSequenceType(iterable)
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
# NamedTupleType(*iterable)
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
def python_type(self):
def python_type(self) -> type:
return self.tuple_cls
def as_python_constant(self):
def as_python_constant(self) -> Any:
if self.is_structseq():
# StructSequenceType(iterable)
result = self.python_type()([x.as_python_constant() for x in self.items])
@ -1328,7 +1356,7 @@ class NamedTupleVariable(TupleVariable):
return result
def as_proxy(self):
def as_proxy(self) -> Any:
assert self.python_type() is not SizeVariable
if self.is_structseq():
# StructSequenceType(iterable)
@ -1342,7 +1370,10 @@ class NamedTupleVariable(TupleVariable):
# StructSequenceType(iterable)
# NamedTupleType(*iterable)
# NamedTupleType._make(iterable)
create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
if self.is_structseq():
create_fn = self.tuple_cls
else:
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_const_unchecked(create_fn)
@ -1384,8 +1415,8 @@ class NamedTupleVariable(TupleVariable):
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
@ -1446,7 +1477,9 @@ class NamedTupleVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
if isinstance(arg, SliceVariable):
# slicing a namedtuple produces a tuple
return TupleVariable(
@ -1455,8 +1488,8 @@ class NamedTupleVariable(TupleVariable):
)
return super().getitem_const(tx, arg)
def var_getattr(self, tx: "InstructionTranslator", name):
def check_and_create_method():
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def check_and_create_method() -> Optional[VariableTracker]:
method = inspect.getattr_static(self.tuple_cls, name, None)
if isinstance(method, classmethod):
# We need the unbounded cls method to avoid the inline __self__
@ -1489,8 +1522,8 @@ class NamedTupleVariable(TupleVariable):
return super().var_getattr(tx, name)
if name == "_fields":
source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=source)
result_source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=result_source)
if name in self.dynamic_attributes:
return self.dynamic_attributes[name]
@ -1505,14 +1538,19 @@ class NamedTupleVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
return variables.ConstantVariable.create(
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
)
class SliceVariable(VariableTracker):
def __init__(self, items, tx=None, **kwargs) -> None:
def __init__(
self,
items: Sequence[VariableTracker],
tx: Optional["InstructionTranslator"] = None,
**kwargs: Any,
) -> None:
items_to_map = items
start, stop, step = [variables.ConstantVariable.create(None)] * 3
@ -1547,23 +1585,23 @@ class SliceVariable(VariableTracker):
super().__init__(**kwargs)
def debug_repr(self):
return self.debug_repr_helper("slice(", ")")
def debug_repr(self) -> str:
return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")"
def as_proxy(self):
def as_proxy(self) -> slice:
return slice(*[x.as_proxy() for x in self.items])
def python_type(self):
def python_type(self) -> type:
return slice
def as_python_constant(self):
def as_python_constant(self) -> slice:
return slice(*[guard_if_dyn(x) for x in self.items])
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach(self.items)
codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name in cmp_name_to_op_mapping:
return variables.GetAttrVariable(self, name)
fields = ["start", "stop", "step"]
@ -1584,7 +1622,9 @@ class ListIteratorVariable(IteratorVariable):
*IteratorVariable._nonvar_fields,
}
def __init__(self, items, index: int = 0, **kwargs) -> None:
def __init__(
self, items: list[VariableTracker], index: int = 0, **kwargs: Any
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
# Removing this check as it slows things down too much
@ -1598,7 +1638,7 @@ class ListIteratorVariable(IteratorVariable):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
def next_variable(self, tx):
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
assert self.is_mutable()
old_index = self.index
if old_index >= len(self.items) or self.is_exhausted:
@ -1609,27 +1649,31 @@ class ListIteratorVariable(IteratorVariable):
self.index += 1
return self.items[old_index]
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
return variables.ConstantVariable.create(hasattr(iter([]), name))
def python_type(self):
def python_type(self) -> type:
return type(iter([]))
def as_python_constant(self):
def as_python_constant(self) -> Any:
if self.index > 0:
raise NotImplementedError
return iter([x.as_python_constant() for x in self.items])
def has_unpack_var_sequence(self, tx):
def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
return True
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
if self.is_exhausted:
return []
self.is_exhausted = True
return list(self.items[self.index :])
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
def force_unpack_var_sequence(
self, tx: "InstructionTranslator"
) -> list[VariableTracker]:
return self.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None:
@ -1656,27 +1700,37 @@ class RangeIteratorVariable(IteratorVariable):
"iter_obj",
}
def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
def __init__(
self, start: int, stop: int, step: int, len_: int, **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.start = start
self.stop = stop
self.step = step
self.len = len_
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__next__":
return self.next_variable(tx)
elif name == "__iter__":
return self
return super().call_method(tx, name, args, kwargs)
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
if self.python_type() is range_iterator:
ri = iter(range(0))
return ConstantVariable(hasattr(ri, name))
return super().call_obj_hasattr(tx, name)
def next_variable(self, tx):
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
if self.len <= 0:
raise_observed_exception(StopIteration, tx)
@ -1685,12 +1739,12 @@ class RangeIteratorVariable(IteratorVariable):
self.start += self.step
return ConstantVariable.create(current)
def python_type(self):
def python_type(self) -> type:
return range_iterator
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range))
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
)
codegen.append_output(codegen.create_load_const(self.start))
codegen.append_output(codegen.create_load_const(self.stop))

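A quick aside on the `StructSequenceType(iterable)` / `NamedTupleType(*iterable)` / `NamedTupleType._make(iterable)` comments in the NamedTupleVariable hunk above: struct sequences (C-level named tuples such as `time.struct_time`) are built from a single iterable, while Python namedtuples are built positionally or via `_make`. A minimal illustration, not part of the diff:

```
# Illustrative only; not part of this PR.
import time
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])
p1 = Point(*[1, 2])       # NamedTupleType(*iterable)
p2 = Point._make([1, 2])  # NamedTupleType._make(iterable)

# struct sequences take one iterable containing all fields
st = time.struct_time((2024, 1, 1, 0, 0, 0, 0, 1, 0))
```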
View File

@ -204,11 +204,11 @@ class StreamVariable(StreamContextVariable):
self,
proxy: Proxy,
value: torch.Stream,
user_object_index: Optional[int] = None,
**kwargs: Any,
) -> None:
# Index into the user object table
# used to pass arbitrary objects to the graph
user_object_index = kwargs.pop("user_obj_index", None)
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
@ -300,7 +300,7 @@ class StreamVariable(StreamContextVariable):
codegen.append_output(codegen.create_load_const(self.user_object_index))
codegen.extend_output(create_call_function(1, False))
else:
# TODO mlazos: evaluate if we still need this
# This will support the legacy behavior
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))

View File

@ -78,7 +78,7 @@ from .ctx_manager import (
)
from .dicts import ConstDictVariable
from .distributed import DistributedVariable, ProcessGroupVariable
from .functions import bind_args_cached
from .functions import bind_args_cached, NestedUserFunctionVariable
from .lists import ListVariable, TupleVariable
from .torch_function import (
can_dispatch_torch_function,
@ -1318,6 +1318,86 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
return ConstantVariable.create(None)
@register(torch._check)
def handle_check(self, tx: "InstructionTranslator", *args, **kwargs):
predicate_vt = None
message_vt = None
if args:
predicate_vt = args[0]
rest_args = args[1:]
else:
rest_args = ()
if predicate_vt is None and "cond" in kwargs:
predicate_vt = kwargs.pop("cond")
if rest_args:
message_vt = rest_args[0]
elif "message" in kwargs:
message_vt = kwargs.pop("message")
if predicate_vt is None:
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
(),
{},
),
)
message_eager = None
message_graph_proxy = None
if message_vt is not None:
if (
not isinstance(message_vt, NestedUserFunctionVariable)
or message_vt.has_closure()
):
unimplemented_v2(
gb_type="Can't extract message from torch._check()",
context=str(message_vt),
explanation=(
"The second argument of torch._check() must be a function"
"defined within the torch.compile region"
"that does not reference a non-local variable."
),
hints=[
"Make sure the message function is defined in the torch.compile region.",
"Remove any closure variables, e.g. "
"remove references to closure variable `x` in `lambda: f'{x} failed check'`",
*graph_break_hints.SUPPORTABLE,
],
)
message_eager = message_vt.get_function()
message_graph_proxy = tx.output.register_static_attr_and_return_proxy(
"_check_message", message_eager
)
if predicate_vt.is_python_constant():
self.value(predicate_vt.as_python_constant(), message_eager)
return ConstantVariable.create(None)
predicate_proxy = predicate_vt.as_proxy()
proxy_args: tuple[Any, ...]
if message_graph_proxy is None:
proxy_args = (predicate_proxy,)
else:
proxy_args = (predicate_proxy, message_graph_proxy)
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
proxy_args,
{},
),
)
return handlers
def call_function(

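A hedged usage sketch of the `torch._check()` handling added above: the optional message must be a function defined inside the compiled region with no closure variables, otherwise Dynamo graph-breaks with the hints shown in the diff. Nothing below is taken from the PR's tests:

```
# Minimal sketch; not from this PR.
import torch

@torch.compile(fullgraph=True)
def f(x):
    # A constant predicate is checked eagerly; a non-constant one is traced
    # into the graph together with the message function.
    torch._check(x.shape[0] > 1, lambda: "expected at least two rows")
    return x.sum()

print(f(torch.ones(4, 4)))
```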
View File

@ -838,7 +838,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
proxy=tx.output.create_proxy(
"call_function", get_external_object_by_index, (ind,), {}
),
user_obj_index=ind,
)
else:
tensor_variable = wrap_fx_proxy(

View File

@ -104,7 +104,7 @@ from .._dynamo.exc import ShortenTraceback, SkipFrame
from ..fx._lazy_graph_module import _use_lazy_graph_module
from ..fx.graph import _PyTreeCodeGen
from ..utils._triton import has_triton
from . import config, metrics
from . import config, distributed_autotune, metrics
from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration
from .debug import DebugContext
from .decomposition import select_decomp_table
@ -1431,7 +1431,11 @@ class _InProcessFxCompile(FxCompile):
# We are going to start code generating runtime asserts, so make sure
# you don't start adding new ones in the lowering process
graph.freeze_runtime_asserts()
with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
with (
V.set_graph_handler(graph),
V.set_extern_kernel_nodes([]),
distributed_autotune.graph_context(),
):
graph.run(*example_inputs)
output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
if graph.graph_outputs is not None:

View File

@ -447,6 +447,14 @@ use_experimental_benchmarker: bool = Config(
justknob="pytorch/inductor:use_experimental_benchmarker",
)
# Enable distributed autotuning. When this is enabled, autotuning is distributed
# across the ranks in the same process group: instead of every rank autotuning
# every kernel, each rank autotunes only 1/world_size of the kernels and then
# shares the results with the other ranks.
distributed_max_autotune_gemm = (
os.environ.get("TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM") == "1"
)
# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

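A hedged sketch of how the new `distributed_max_autotune_gemm` knob might be exercised. It assumes a multi-rank `torchrun` launch with NCCL available; none of this is taken from the PR's tests:

```
# Launch with e.g.: torchrun --nproc-per-node=8 this_script.py
import os
os.environ["TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM"] = "1"  # set before importing torch

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

@torch.compile(mode="max-autotune")
def mm(a, b):
    return a @ b

a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")
# Each rank autotunes roughly 1/world_size of the GEMM kernels; the results
# are exchanged between ranks before scheduling/fusion.
out = mm(a, b)
dist.destroy_process_group()
```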
View File

@ -0,0 +1,386 @@
from __future__ import annotations
import contextlib
import dataclasses
from typing import Any, TYPE_CHECKING, Union
from unittest.mock import patch
import sympy
import torch._logging
import torch.distributed as dist
import torch.fx
from torch.utils._ordered_set import OrderedSet
from . import config, select_algorithm
from .ir import (
Buffer,
ChoiceCaller,
Layout,
MultiTemplateBuffer,
OperationBuffer,
ShapeAsConstantBuffer,
StorageBox,
TensorBox,
)
from .kernel_inputs import KernelInputs, MMKernelInputs
from .scheduler import SchedulerNode
from .virtualized import NullHandler, V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
_DISTRIBUTED_AUTOTUNE_KEY = "distributed_autotune"
_AUTOTUNE_PG: dist.ProcessGroup | None = None
@dataclasses.dataclass
class _DistributedAutotuneState:
"""
State used to track autotuning during a graph_context()
"""
# This is the next operator index. Used to figure out which rank should do
# the autotuning.
autotuned_index: int = 0
# For debugging - used to make sure that we autotune the same number of
# local operators that we expected to.
autotuned_local_count: int = 0
@dataclasses.dataclass
class _DistributedAutotuneInfo:
index: int
local: bool
def get_autotune_pg() -> dist.ProcessGroup | None:
if dist.is_available() and dist.is_initialized():
global _AUTOTUNE_PG
if _AUTOTUNE_PG is None:
_AUTOTUNE_PG = dist.distributed_c10d._new_group_with_tag(
pg_tag="pt2_distributed_autotune_pg"
)
return _AUTOTUNE_PG
return None
def schedule(scheduler: torch._inductor.scheduler.Scheduler) -> None:
"""
Finish the distributed autotuning by propagating the autotuning results
between the ranks and then replacing the placeholders with the real Buffers.
"""
assert config.distributed_max_autotune_gemm
autotune_results = _autotune_local_nodes(scheduler)
choices_by_index = _sync(autotune_results)
_autotune_remote_nodes(scheduler, choices_by_index)
@contextlib.contextmanager
def graph_context() -> Generator[None, None, None]:
"""
Wrapped around the processing of a graph; sets up the state used to figure
out which ranks tune which shapes.
"""
assert not isinstance(
V.get_distributed_autotune_state(check_poisoned=False), # type: ignore[call-arg]
_DistributedAutotuneState,
)
V.set_distributed_autotune_state(_DistributedAutotuneState())
try:
yield
finally:
V.set_distributed_autotune_state(NullHandler())
def maybe_autotune_remote(
name: str, choices: list[ChoiceCaller], inputs: list[Buffer], layout: Layout
) -> TensorBox | ShapeAsConstantBuffer | None:
"""
Used by an op (like `mm`) to determine if the op should be autotuned
locally (returns None) or remotely (returns a placeholder Buffer).
"""
if not config.distributed_max_autotune_gemm:
return None
if not (autotune_pg := get_autotune_pg()):
return None
if len(choices) <= 1:
return None
state = V.distributed_autotune_state
index = state.autotuned_index
state.autotuned_index += 1
local = index % autotune_pg.size() == autotune_pg.rank()
V.current_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] = _DistributedAutotuneInfo(
index, local
)
if local:
state.autotuned_local_count += 1
return None
return torch._inductor.ir.TensorBox.create(
_DistributedAutotuneBuffer(name, inputs, layout)
)
class _DistributedAutotuneBuffer(MultiTemplateBuffer):
"""
A MultiTemplateBuffer which represents a kernel being autotuned on a
different rank. When `schedule` is called this will be replaced by the
"real" buffer.
"""
# Name of the kernel being autotuned.
_kernel_name: str
def __init__(
self,
kernel_name: str,
inputs: list[Buffer],
layout: Layout,
) -> None:
super().__init__(
layout,
inputs,
choice_timings_fn=self._dummy_choice_timings,
unfiltered_choices=[],
allowed_prologue_inps=OrderedSet({}),
)
self._kernel_name = kernel_name
def _dummy_choice_timings(
self, _hint_override: int | None
) -> dict[ChoiceCaller, float]:
# This should never get called. It means that a remote autotune was
# scheduled but never filled in.
raise NotImplementedError
def autotune(self, ser_choice: _SerializedChoice) -> TensorBox:
"""
Given a _SerializedChoice (autotune results from another rank)
compute the final TensorBox.
"""
from .select_algorithm import autotune_select_algorithm
with patch.object(V.graph, "scheduler", None):
kernel_inputs = MMKernelInputs([*self.original_inputs])
assert isinstance(self.layout, Layout)
choice = ser_choice.get_choice(self.layout, kernel_inputs)
buffer = autotune_select_algorithm(
self._kernel_name,
[choice],
kernel_inputs.nodes(),
self.layout,
)
assert isinstance(buffer, TensorBox)
return buffer
# Can we make this async?
def _sync(autotune_results: list[_SerializedChoice]) -> Sequence[_SerializedChoice]:
"""
Perform the all_gather to collect the autotune results from all the ranks.
"""
autotune_pg = get_autotune_pg()
assert autotune_pg
# Perform allgather
all_states: list[list[_SerializedChoice]] = [None] * autotune_pg.size() # type: ignore[list-item]
torch.distributed.all_gather_object(all_states, autotune_results, group=autotune_pg)
node_count = sum(len(x) for x in all_states)
# It's faster to briefly lie about the type than to unzip the results and append.
choices_by_index: list[_SerializedChoice] = [None] * node_count # type: ignore[list-item]
check_count = 0
for i, other_results in enumerate(all_states):
for choice in other_results:
assert isinstance(choice, _SerializedChoice)
assert choices_by_index[choice.index] is None
choices_by_index[choice.index] = choice
check_count += 1
assert node_count == check_count, f"count mismatch: {node_count} != {check_count}"
return choices_by_index
class _SerializedChoice:
"""
This is a serializer for the autotune choice. KernelTemplateChoice can't
be serialized directly (the template and inputs prevent this), so we
serialize it in parts and reconstruct it later.
"""
def __init__(self, index: int, choice: ChoiceCaller) -> None:
self.index = index
self.template_uid = _SerializedChoice._template_uid_from_choice(choice)
self.kwargs = self._compute_kwargs(choice.description)
def get_choice(self, layout: Layout, inputs: KernelInputs) -> ChoiceCaller | None:
"""
Deserialize the ChoiceCaller and return it.
"""
template = self._template_from_uid()
kwargs = {**self.kwargs}
if "BLOCK_K" in kwargs:
# TODO: Do we really need to externally compute this value? If it's
# needed I'm surprised it's not just part of the original template
# description.
# This needs the actual 'k' to figure out the value.
k = inputs.nodes()[0].get_size()[1]
kwargs["EVEN_K"] = sympy.gcd(k, kwargs["BLOCK_K"]) == kwargs["BLOCK_K"]
extra_kwargs: dict[str, Any] = {}
from .kernel_template_choice import (
DictKernelTemplateParams,
KernelTemplateChoice,
)
params = DictKernelTemplateParams(kwargs)
ktc = KernelTemplateChoice(template, params, extra_kwargs, layout, inputs)
return ktc.choice
@staticmethod
def _compute_kwargs(description: str) -> dict[str, Union[int, str, bool]]:
"""
Given a template description turn it into input kwargs.
"""
if not description:
return {}
# TODO: It seems like it would be better if the template could provide
# this directly instead of having to parse a string.
kwargs: dict[str, Union[int, str, bool]] = {}
for cfg in description.split(","):
key, val = cfg.split("=", 1)
key, val = key.strip(), val.strip()
if val == "True":
kwargs[key] = True
elif val == "False":
kwargs[key] = False
elif val.isdigit():
kwargs[key] = int(val)
else:
assert val.startswith("'") and val.endswith("'")
kwargs[key] = val[1:-1]
return kwargs
@staticmethod
def _template_uid_from_choice(choice: ChoiceCaller) -> str:
"""
Given a ChoiceCaller figure out which template represents it. This
is reversed by _template_from_uid().
"""
# We need a better way to do this - right now we need to add each
# supported template directly.
if isinstance(choice, select_algorithm.ExternKernelCaller):
if choice.choice.name == "mm":
return "torch._inductor.kernel.mm.aten_mm"
else:
raise RuntimeError(f"TODO: kernel {choice.choice.name!r}")
elif isinstance(choice, select_algorithm.TritonTemplateCaller):
return "torch._inductor.kernel.mm.mm_template"
else:
raise RuntimeError(f"TODO: {type(choice)}")
def _template_from_uid(self) -> Any:
"""
See _template_uid_from_choice().
"""
parts = self.template_uid.split(".")
obj = globals()[parts[0]]
for k in parts[1:]:
obj = getattr(obj, k)
return obj
def _autotune_local_nodes(
scheduler: torch._inductor.scheduler.Scheduler,
) -> list[_SerializedChoice]:
"""
Go through the nodes in the scheduler and autotune the kernels which
should be autotuned by this rank.
"""
autotune_results: list[_SerializedChoice] = []
for node in scheduler.nodes:
if not isinstance(node, SchedulerNode):
continue
if (inner_node := node.node) is None:
continue
if isinstance(inner_node, _DistributedAutotuneBuffer):
# This is marked for remote autotuning.
continue
if not isinstance(inner_node, MultiTemplateBuffer):
continue
if (origin_node := inner_node.origin_node) is None:
continue
if (meta := origin_node.meta) is None:
continue
info = meta.get(_DISTRIBUTED_AUTOTUNE_KEY)
if info is None:
continue
assert info.local
# We force autotuning here
# Still takes advantage of async precompile
# We need all the configs before fusion
min_choice, _ = inner_node.get_min_choice()
choice = _SerializedChoice(info.index, min_choice)
autotune_results.append(choice)
state = V.distributed_autotune_state
assert len(autotune_results) == state.autotuned_local_count, (
f"incorrect local autotuned nodes found ({len(autotune_results)} != {state.autotuned_local_count})"
)
return autotune_results
def _autotune_remote_nodes(
scheduler: torch._inductor.scheduler.Scheduler,
choices_by_index: Sequence[_SerializedChoice],
) -> None:
"""
Go through the nodes in the scheduler and fill in the autotuning results for
the nodes that were autotuned on remote ranks.
"""
for i, node in enumerate(scheduler.nodes):
if isinstance(node, SchedulerNode) and isinstance(
(dist_node := node.node), _DistributedAutotuneBuffer
):
assert dist_node.origin_node is not None
info = dist_node.origin_node.meta[_DISTRIBUTED_AUTOTUNE_KEY]
out_tensorbox = dist_node.autotune(choices_by_index[info.index])
out_storage = out_tensorbox.data
assert isinstance(out_storage, StorageBox)
out_buffer = out_storage.data
assert isinstance(out_buffer, OperationBuffer)
assert out_buffer.layout == dist_node.layout
scheduler._replace_node(out_buffer, dist_node, i, node)

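For reference, a small illustration of the `key=value` description format that `_SerializedChoice._compute_kwargs` above parses back into template kwargs. The module path `torch._inductor.distributed_autotune` is inferred from the `from . import config, distributed_autotune, metrics` import in the inductor diff, and the config keys are made up:

```
# Illustrative only; config keys are hypothetical.
from torch._inductor.distributed_autotune import _SerializedChoice

desc = "BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, ALLOW_TF32=True, ACC_TYPE='tl.float32'"
kwargs = _SerializedChoice._compute_kwargs(desc)
# -> {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32,
#     'ALLOW_TF32': True, 'ACC_TYPE': 'tl.float32'}
```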
View File

@ -52,8 +52,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
@ -1511,10 +1511,8 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
@register_graph_pattern(

View File

@ -19,7 +19,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.functional import ScalingType # type: ignore[attr-defined]
from torch.torch_version import TorchVersion
from .. import config as inductor_config
from .. import config as inductor_config, distributed_autotune
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
@ -1315,6 +1315,11 @@ def tuned_mm(mat1, mat2, out_dtype=None, *, layout=None):
# The future will be awaited at scheduling time in select_algorithm.py
best_config_future = gen_best_config(mat1, mat2)
if box := distributed_autotune.maybe_autotune_remote(
name, choices, kernel_inputs.nodes(), layout
):
return box
return autotune_select_algorithm(
name,
choices,

View File

@ -449,7 +449,6 @@ class SchedulerDonatedBuffer(SchedulerBuffer):
class BaseSchedulerNode:
ancestors: OrderedSet[str]
debug_device_str: Callable[[BaseSchedulerNode], list[str]]
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
last_usage: OrderedSet[str]
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
@ -461,21 +460,26 @@ class BaseSchedulerNode:
max_order: int
mpi_node: MemoryPlanningInfoForNode
mutation_renames: dict[str, str]
node: Optional[ir.Operation]
node: Optional[ir.Operation] = None
outputs: list[SchedulerBuffer]
outputs_by_name: dict[str, SchedulerBuffer]
override_estimated_runtime: Optional[float] = None
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
written: bool = False
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler = scheduler
self.debug_device_str = lambda *args, **kwargs: []
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)
def _init_from_node(self, node: ir.Operation) -> None:
self.node = node
self.ancestors = OrderedSet()
self.last_usage = OrderedSet() # buffers that won't be used after this kernel
self.last_usage = OrderedSet[
str
]() # buffers that won't be used after this kernel
self.written = False
self.outputs = [
SchedulerBuffer(
@ -2643,6 +2647,12 @@ class Scheduler:
if config._pre_fusion_custom_pass is not None:
self.nodes = config._pre_fusion_custom_pass(self.nodes)
if config.distributed_max_autotune_gemm:
from . import distributed_autotune
distributed_autotune.schedule(self)
self.compute_ancestors()
self.nodes = self.fuse_nodes(self.nodes)
if config._post_fusion_custom_pass is not None:
self.nodes = config._post_fusion_custom_pass(self.nodes)
@ -3515,6 +3525,7 @@ class Scheduler:
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.ancestors = node.ancestors
new_scheduler_node.last_usage = node.last_usage
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:

View File

@ -549,70 +549,6 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()
def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
on them, and then aggregate these booleans with the `use_aggregate_type`
logical connective.
Uses in view ops will follow the view's uses.
"""
def get_use_aggregate_fn(
use_aggregate_type: LogicalConnective,
) -> Callable[[Iterator[Any]], bool]:
match use_aggregate_type:
case LogicalConnective.AND:
return all
case LogicalConnective.OR:
return any
case _:
return any
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
def has_uses_impl(use: Node) -> bool:
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload)
or use.target is operator.getitem
):
return False
target = cast(torch._ops.OpOverload, use.target)
# Process getitem and view
if target is operator.getitem or is_view(target):
return use_aggregate_fn(has_uses_impl(user) for user in use.users)
return use_selector_fn(target)
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Is there a use with given tags?
"""
return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
)
def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:

View File

@ -86,6 +86,8 @@ if TYPE_CHECKING:
from torch._inductor.loop_body import InterpreterShim
from torch._subclasses import FakeTensorMode
from .distributed_autotune import _DistributedAutotuneState
threadlocal = local()
T = TypeVar("T")
@ -201,6 +203,9 @@ _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHand
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
"local_buffer_context", NullHandler
)
_distributed_autotune_state: Virtualized[_DistributedAutotuneState] = Virtualized(
"distributed_autotune_state", NullHandler
)
def _choices_default():
@ -370,6 +375,12 @@ class _V:
set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
set_choices_handler: Callable[[Any], Any] = _choices._set_handler
set_distributed_autotune_state: Callable[[Any], Any] = (
_distributed_autotune_state._set_handler
)
get_distributed_autotune_state: Callable[[], Any] = (
_distributed_autotune_state._get_handler
)
@property
def ops(self) -> OpsHandler[Any]:
@ -429,5 +440,9 @@ class _V:
def choices(self) -> InductorChoices:
return _choices._get_handler()
@property
def distributed_autotune_state(self):
return _distributed_autotune_state._get_handler()
V = _V()

View File

@ -33,14 +33,6 @@ static inline int PyCode_GetNFreevars(PyCodeObject* code) {
#endif
}
// Provided by CPython but getting the header for them is very hard
#if IS_PYTHON_3_11_PLUS
// NOLINTNEXTLINE(readability-redundant-declaration)
PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self);
#else
extern void _PyWeakref_ClearRef(PyWeakReference* self);
#endif
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large

View File

@ -20,7 +20,10 @@ from torch.distributed.tensor._tp_conv import (
convolution_backward_handler,
convolution_handler,
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor._utils import (
ExplicitRedistributionContext,
try_find_mesh_from_args,
)
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._debug_mode import get_active_debug_mode
from torch.utils._python_dispatch import return_and_correct_aliasing
@ -199,6 +202,10 @@ class OpDispatcher:
if participating:
# computation that happens in the current rank of the mesh, normal case
if output_sharding.needs_redistribute:
if ExplicitRedistributionContext.is_active():
raise RuntimeError(
f"Implicit redistribution occurred while ExplicitRedistributionContext was active for {op_info.schema}"
)
# If sharding propagation decision needs redistribute, perform redistribute
# on args first, which could potentially modify args (i.e. allgather certain arg)
assert output_sharding.redistribute_schema is not None

View File

@ -18,6 +18,33 @@ from torch.distributed.tensor.placement_types import (
from torch.utils._typing_utils import not_none
class ExplicitRedistributionContext:
"""
Within this context manager, DTensor will refuse to perform implicit redistribution and will
raise an error instead. Manual calls to ``redistribute()`` are required wherever a redistribution
must occur, to avoid erroring. This can be used to ensure that the user is aware of every redistribution.
Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backward pass
may contain implicit redistribution calls that are not visible to the user and are difficult to replace with
manual calls. Redistribution during backward can be made explicit by writing ``autograd.Function``s that are
no-ops during forward and perform a manual redistribution during backward.
"""
_explicit_redistribute_mode = False
@classmethod
def is_active(cls) -> bool:
return cls._explicit_redistribute_mode
def __enter__(self):
self.prev = ExplicitRedistributionContext._explicit_redistribute_mode
ExplicitRedistributionContext._explicit_redistribute_mode = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
ExplicitRedistributionContext._explicit_redistribute_mode = self.prev
def _explicit_order_placements(
mesh_shape: ShapeType, placements: Sequence[Placement]
) -> Sequence[tuple[int, Placement]]:

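The ExplicitRedistributionContext docstring above suggests making backward-pass redistributions explicit via `autograd.Function`s that are no-ops in forward. A rough, untested sketch of that pattern (the placement handling is an assumption, not code from this PR):

```
# Hedged sketch; not from this PR.
import torch
from torch.distributed.tensor import DTensor


class RedistributeInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dt, placements):
        ctx.placements = placements
        return dt  # no-op during forward

    @staticmethod
    def backward(ctx, grad):
        if isinstance(grad, DTensor):
            # Redistribute explicitly so no implicit redistribution is needed
            # (which would raise under ExplicitRedistributionContext).
            grad = grad.redistribute(grad.device_mesh, ctx.placements)
        return grad, None
```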
View File

@ -2,6 +2,8 @@ import os
import sys
from typing import Optional
from torch.utils._config_module import Config, install_config_module
# [@compile_ignored: debug] Fails hard instead of graph breaking on guard-on-data-dependent errors.
no_data_dependent_graph_break = (
@ -100,7 +102,11 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
skip_dtype_check_in_meta_registrations = False
from torch.utils._config_module import install_config_module
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
install_config_module(sys.modules[__name__])

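A hedged note on flipping the relocated `enrich_profiler_metadata` config: both paths below are assumptions pieced together from the diff (the env var name, spelling included, is copied verbatim from the config definition above):

```
# Option 1: environment variable, read when the config module is imported.
import os
os.environ["TORCH_ENRICH_RPOFILER_STACK_TRACE"] = "1"  # set before importing torch

# Option 2: toggle at runtime via the fx experimental config module.
from torch.fx.experimental import _config as fx_experimental_config
fx_experimental_config.enrich_profiler_metadata = True
```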
View File

@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
_BoxedCodeGen,
_custom_builtins,
@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
# Do not import anything inside recompile; it can slow down the
# function and cause a perf regression. Import outside of the method instead.
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
root_module="self",
record_func=fx_experimental_config.enrich_profiler_metadata,
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
if dynamo_config.enrich_profiler_metadata:
if fx_experimental_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):