Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 11:15:20 +08:00)

Compare commits: csl/manual...trunk/f47c (14 commits)

Commit SHAs:
f47cadf75d
2923b02c6e
4b9ba0fb26
106d34c80a
0b06109412
2073af5790
9b4ac45d2f
a45a17f65e
c5593e75b3
c90a976370
d144382dc9
78827c5e00
ab1e734cd7
888958ad6c
@ -73,6 +73,19 @@ void box_cox_zero_lambda(
  }
}

template <typename T>
at::vec::Vectorized<T> box_cox_nonzero_lambda_impl(
    at::vec::Vectorized<T> data,
    at::vec::Vectorized<T> lambda1,
    at::vec::Vectorized<T> lambda2,
    at::vec::Vectorized<T> k_eps) {
  auto sum = data + lambda2;
  auto max = at::vec::max(sum, k_eps);
  auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
  auto pow = max.pow(lambda1);
  return at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
}

template <typename T>
void box_cox_nonzero_lambda(
    int64_t D,
@ -88,21 +101,18 @@ void box_cox_nonzero_lambda(
  auto k_eps_vec = Vec(k_eps);
  for(; j + VLEN < D; j += VLEN) {
    auto data = Vec::loadu(data_ptr + j);
    auto lambda2 = Vec::loadu(lambda2_ptr + j);
    auto sum = data + lambda2;
    auto max = at::vec::max(sum, k_eps_vec);
    auto lambda1 = Vec::loadu(lambda1_ptr + j);
    auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
    auto pow = max.pow(lambda1);
    auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
    auto lambda2 = Vec::loadu(lambda2_ptr + j);
    auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
    res.store(out + j);
  }
  for ( ;j < D; ++j) {
    auto sum = data_ptr[j] + lambda2_ptr[j];
    auto max = std::max(sum, k_eps);
    auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]);
    auto pow = std::pow(max, lambda1_ptr[j]);
    out[j] = pow * lambda_over_1 - lambda_over_1;
  if (j < D) {
    auto remaining = D - j;
    auto data = Vec::loadu(data_ptr + j, remaining);
    auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining);
    auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining);
    auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
    res.store(out + j, remaining);
  }
}
#else
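For reference, the helper factored out above evaluates the usual Box-Cox transform for a non-zero lambda, ((x + lambda2)^lambda1 - 1) / lambda1, rewritten as pow * (1/lambda1) - (1/lambda1) so the division becomes a reciprocal plus a fused multiply-subtract. A minimal scalar sketch in Python (illustrative only, not part of the patch; the epsilon default is an assumption):

import math

def box_cox_nonzero_lambda_ref(x, lambda1, lambda2, k_eps=1e-6):
    # Shift the input, clamp away from zero, then apply ((x + lambda2) ** lambda1 - 1) / lambda1.
    shifted = max(x + lambda2, k_eps)
    recip = 1.0 / lambda1              # the kernel uses fast_recieprocal(lambda1)
    powed = math.pow(shifted, lambda1)
    # Same algebraic form the vectorized code gets from fmsub(pow, recip, recip).
    return powed * recip - recip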
@ -1,11 +1,18 @@
# Owner(s): ["oncall: distributed"]

import itertools
from contextlib import nullcontext
from typing import Any

import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
    local_tensor_mode,
    LocalTensor,
    LocalTensorMode,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._utils import (
    _compute_local_shape_and_global_offset,
@ -14,6 +21,7 @@ from torch.distributed.tensor._utils import (
    compute_global_tensor_shape,
    compute_local_shape_and_global_offset,
    compute_local_tensor_info,
    ExplicitRedistributionContext,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import (
@ -851,5 +859,93 @@ class Test2DStridedLocalShard(DTensorTestBase):
        self.assertEqual(global_tensor, dtensor_2d.full_tensor())


class LocalTensorTestBase(TestCase):
    def assertEqual(self, lhs, rhs, **kwargs):
        mode = local_tensor_mode()
        with nullcontext() if mode is None else mode.disable():
            if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
                assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
                super().assertEqual(lhs._ranks, rhs._ranks)
                for r in lhs._ranks:
                    super().assertEqual(
                        lhs._local_tensors[r],
                        rhs._local_tensors[r],
                        lambda m: f"rank {r}: {m}",
                    )
            elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
                lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
                for r in lhs._ranks:
                    super().assertEqual(
                        lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
                    )
            else:
                return super().assertEqual(lhs, rhs, **kwargs)

    @property
    def world_size(self):
        raise NotImplementedError("override world-size in your subclass")

    def build_device_mesh(self) -> DeviceMesh:
        return init_device_mesh("cpu", (self.world_size,))

    def setUp(self):
        super().setUp()
        torch.distributed.init_process_group(
            # TODO: test other ranks too
            "fake",
            rank=0,
            world_size=self.world_size,
        )

    def tearDown(self):
        super().tearDown()
        try:
            dist.destroy_process_group()
        except AssertionError:
            pass


class TestExplicitRedistribute(LocalTensorTestBase):
    @property
    def world_size(self):
        return 4

    def test_explicit_matmul(self):
        with LocalTensorMode(self.world_size):
            device_mesh = self.build_device_mesh()
            dim = 128
            x = torch.randn(8, dim, requires_grad=True)
            A = torch.randn(dim, dim, requires_grad=True)

            # Prepare DTensors
            dx = distribute_tensor(x, device_mesh, [Shard(0)])
            dA = distribute_tensor(A, device_mesh, [Shard(0)])

            # implicit redistribute works as usual by default
            with CommDebugMode() as comm_mode:
                torch.matmul(dx, dA)
            self.assertEqual(comm_mode.get_total_counts(), 1)

            # explicit redistribute works too
            with ExplicitRedistributionContext():
                with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
                    torch.matmul(dx, dA)

            # explicit redistribute allows manual redistribute
            with ExplicitRedistributionContext():
                dA_repl = dA.redistribute(device_mesh, [Replicate()])
                torch.matmul(dx, dA_repl)

            dx = distribute_tensor(x, device_mesh, [Shard(0)])
            dA = distribute_tensor(A, device_mesh, [Replicate()])
            with ExplicitRedistributionContext():
                dY = torch.matmul(dx, dA_repl)
                loss = dY.sum()

                # we now see the error during backwards
                with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
                    loss.backward()


if __name__ == "__main__":
    run_tests()
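LocalTensorTestBase is intended to be subclassed; a hypothetical minimal subclass (the test name and tensor values below are illustrative, not from the patch) only needs to pin world_size and run its body under LocalTensorMode:

class MyLocalTensorTest(LocalTensorTestBase):
    @property
    def world_size(self):
        # The base class deliberately leaves this abstract.
        return 2

    def test_replicated_add(self):
        with LocalTensorMode(self.world_size):
            mesh = self.build_device_mesh()
            dt = distribute_tensor(torch.ones(4, 4), mesh, [Replicate()])
            # The overridden assertEqual compares LocalTensors rank by rank.
            self.assertEqual((dt + dt).full_tensor(), torch.full((4, 4), 2.0))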
@ -2,6 +2,7 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
import random
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
@ -51,6 +52,9 @@ from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def reset_rng_state():
|
||||
torch.manual_seed(1337)
|
||||
random.seed(1337)
|
||||
@ -1200,6 +1204,116 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
for r in res[1:]:
|
||||
self.assertEqual(res[0], r)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
|
||||
@patch.object(torch._inductor.config, "max_autotune_gemm", True)
|
||||
@patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
|
||||
def test_multiproc_autotune(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
torch._dynamo.utils.clear_compilation_metrics()
|
||||
|
||||
@torch.compile()
|
||||
def f(a, b, c):
|
||||
res = (
|
||||
torch.sum((a @ b) + 1.0)
|
||||
+ torch.sum(torch.relu(b @ c))
|
||||
+ torch.sum(c @ a)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
|
||||
b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
|
||||
c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)
|
||||
|
||||
try:
|
||||
f(a, b, c)
|
||||
except Exception:
|
||||
log.exception("Caught exception running f")
|
||||
raise
|
||||
|
||||
metrics = torch._dynamo.utils.get_compilation_metrics()
|
||||
res = [None] * self.world_size
|
||||
torch.distributed.all_gather_object(res, len(metrics))
|
||||
for r in res[1:]:
|
||||
self.assertEqual(res[0], r)
|
||||
|
||||
print(f"Result from {self.rank} is {f(a, b, c)}")
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
|
||||
@patch.object(torch._inductor.config, "max_autotune_gemm", True)
|
||||
@patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
|
||||
def test_multiproc_autotune_dynamic_shapes(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
torch._dynamo.utils.clear_compilation_metrics()
|
||||
|
||||
@torch.compile()
|
||||
def f(a, b, c):
|
||||
res = (
|
||||
torch.sum((a @ b) + 1.0)
|
||||
+ torch.sum(torch.relu(b @ c))
|
||||
+ torch.sum(c @ a)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
|
||||
b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
|
||||
c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)
|
||||
|
||||
# Mark tensors as dynamic on dimension 0
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(a, 1)
|
||||
torch._dynamo.mark_dynamic(b, 0)
|
||||
torch._dynamo.mark_dynamic(b, 1)
|
||||
torch._dynamo.mark_dynamic(c, 0)
|
||||
torch._dynamo.mark_dynamic(c, 1)
|
||||
|
||||
try:
|
||||
f(a, b, c)
|
||||
except Exception:
|
||||
log.exception("Caught exception running f")
|
||||
raise
|
||||
|
||||
metrics = torch._dynamo.utils.get_compilation_metrics()
|
||||
res = [None] * self.world_size
|
||||
torch.distributed.all_gather_object(res, len(metrics))
|
||||
for r in res[1:]:
|
||||
self.assertEqual(res[0], r)
|
||||
|
||||
print(f"Result from {self.rank} is {f(a, b, c)}")
|
||||
|
||||
# Store the initial compilation count
|
||||
initial_compile_count = len(metrics)
|
||||
|
||||
# # Test with different sizes to ensure dynamic shapes work without recompilation
|
||||
a2 = torch.randn(512, 512, device=self.rank, dtype=torch.bfloat16)
|
||||
b2 = torch.randn(512, 2048, device=self.rank, dtype=torch.bfloat16)
|
||||
c2 = torch.randn(2048, 512, device=self.rank, dtype=torch.bfloat16)
|
||||
|
||||
try:
|
||||
result2 = f(a2, b2, c2)
|
||||
print(f"Result2 from {self.rank} is {result2}")
|
||||
except Exception:
|
||||
log.exception("Caught exception running f with different sizes")
|
||||
raise
|
||||
|
||||
# Verify no recompilation occurred
|
||||
metrics_after = torch._dynamo.utils.get_compilation_metrics()
|
||||
final_compile_count = len(metrics_after)
|
||||
self.assertEqual(
|
||||
initial_compile_count,
|
||||
final_compile_count,
|
||||
"Expected no recompilation with dynamic shapes",
|
||||
)
|
||||
|
||||
# Verify all ranks have the same compilation count
|
||||
res_after = [None] * self.world_size
|
||||
torch.distributed.all_gather_object(res_after, final_compile_count)
|
||||
for r in res_after[1:]:
|
||||
self.assertEqual(res_after[0], r)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_get_pg_attr(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
|
||||
@ -1428,6 +1428,170 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
||||
|
||||
self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3])))
|
||||
|
||||
def test_check_compiles_when_predicate_true_and_message_has_no_closure(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(4)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
y = f(x)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
def test_check_compiles_when_predicate_true_constant_and_message_has_no_closure(
|
||||
self,
|
||||
):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(4)
|
||||
|
||||
y = f(x)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
def test_check_compiles_when_predicate_true_constant_and_message_None(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3)
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(4)
|
||||
|
||||
y = f(x)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
def test_check_compiles_when_predicate_true_and_message_None(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3)
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(4)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
y = f(x)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
def test_check_compiles_when_predicate_true_and_message_has_global(self):
|
||||
global GLOBAL_INT
|
||||
GLOBAL_INT = 1
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(4)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
y = f(x)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
def test_check_raises_at_runtime_when_predicate_false_and_message_has_global(self):
|
||||
global GLOBAL_INT
|
||||
GLOBAL_INT = 1
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, f"{GLOBAL_INT} is not greater than 3"
|
||||
):
|
||||
f(x)
|
||||
|
||||
def test_check_raises_at_runtime_when_predicate_false_and_message_None(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3)
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, None):
|
||||
f(x)
|
||||
|
||||
def test_check_raises_at_runtime_when_predicate_false_constant_and_message_None(
|
||||
self,
|
||||
):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3)
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, None):
|
||||
f(x)
|
||||
|
||||
def test_check_raises_at_runtime_when_predicate_false_and_message_has_no_closure(
|
||||
self,
|
||||
):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"):
|
||||
f(x)
|
||||
|
||||
def test_check_raises_at_runtime_when_predicate_false_constant_and_message_has_no_closure(
|
||||
self,
|
||||
):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"):
|
||||
f(x)
|
||||
|
||||
def test_check_assert_error_at_runtime_when_predicate_false_and_message_has_closure(
|
||||
self,
|
||||
):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(3)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()"
|
||||
):
|
||||
f(x)
|
||||
|
||||
def test_check_assert_error_at_runtime_when_predicate_true_and_message_has_closure(
|
||||
self,
|
||||
):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def f(x):
|
||||
torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3")
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(4)
|
||||
torch._dynamo.maybe_mark_dynamic(x, 0)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()"
|
||||
):
|
||||
f(x)
|
||||
|
||||
def test_assert(self):
|
||||
@torch.compile
|
||||
def fn1(x):
|
||||
|
||||
@ -74,13 +74,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': None}
|
||||
# Annotation: {'stream': 0}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
# Annotation: {'stream': 1}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
# Annotation: {'stream': 1}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
|
||||
return (add_3,)
|
||||
@ -196,6 +196,7 @@ class <lambda>(torch.nn.Module):
|
||||
s_exp = fn(*inp)
|
||||
self.assertEqual(s_act, s_exp)
|
||||
|
||||
@requires_cuda
|
||||
def test_nested_stream_enter_exit(self):
|
||||
def fn(x, y, s0, s1, s2):
|
||||
with s1:
|
||||
@ -229,13 +230,13 @@ class <lambda>(torch.nn.Module):
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': None}
|
||||
# Annotation: {'stream': 1}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
# Annotation: {'stream': 2}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
# Annotation: {'stream': 1}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
|
||||
return (add_1, add_2)
|
||||
""",
|
||||
@ -249,6 +250,7 @@ class <lambda>(torch.nn.Module):
|
||||
def test_nested_stream_enter_exit_graph_break(self):
|
||||
pass
|
||||
|
||||
@requires_cuda
|
||||
def test_local_stream_enter_exit(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
@ -289,6 +291,7 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_local_stream_nested_enter_exit(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
@ -331,6 +334,7 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_with_mutation(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
@ -380,6 +384,7 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_backward(self) -> None:
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
|
||||
@ -500,13 +500,8 @@ class PaddingTest(TestCaseBase):
        forward_wrapper = wrapper_codes[0]

        # make sure the load for softmax is aligned
        if bias:
            # addmm -> mm + bias and bias is fused with softmax
            softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
        else:
            softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
        self.assertTrue(
            softmax_load_str in forward_wrapper,
            "tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
            f"forward_wrapper: {forward_wrapper}",
        )
@ -15310,7 +15310,7 @@ if RUN_GPU:
            ),
            (
                fn3,
                "triton_poi_fused_addmm_native_layer_norm",
                "triton_poi_fused_native_layer_norm_relu",
                (torch.randn(4, 4, device=GPU_TYPE),),
            ),
        ]
@ -15323,7 +15323,7 @@ if RUN_GPU:
            ),
            (
                fn3,
                "triton_poi_fused_LayerNorm_Linear_ReLU",
                "triton_poi_fused_LayerNorm_ReLU",
                (torch.randn(4, 4, device=GPU_TYPE),),
            ),
        ]
@ -7508,6 +7508,8 @@ class TestFXMemoryProfiler(TestCase):
        device = "cuda"
        mod = MLPModule(device)
        with tempfile.TemporaryDirectory() as tmpdir:
            # reset cache to start fresh
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history()
            compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
            result = compiled(torch.randn(10, 10, device=device))
@ -7518,10 +7520,7 @@ class TestFXMemoryProfiler(TestCase):
            torch.cuda.empty_cache()

            fx_frames = self.collect_frames(augmented_snapshot)
            if TEST_WITH_ROCM:
                self.assertGreater(len(fx_frames), 0)
            else:
                self.assertEqual(len(fx_frames), 12)
            self.assertGreater(len(fx_frames), 2)

            for frame in fx_frames:
                # Every FX frame should have both node_op and node_name
@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    def test_profiler_stack_trace_augmentation(self):
        """
        Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    def test_profiler_multiple_modules(self):
        """
        Test that multiple compiled modules under the same profiler session
@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    def test_profiler_nested_graph_modules(self):
        """
        Test that nested graph modules (e.g., graph modules calling subgraphs)
@ -739,11 +739,8 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None

# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
    default=False,
    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False

if TYPE_CHECKING:
    from torch.utils._config_typing import * # noqa: F401, F403
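Since the Dynamo copy of the flag is now a plain deprecated bool, callers patch the relocated config instead, exactly as the updated profiler tests above do with their decorators. A small hedged sketch of equivalent context-manager usage (assuming the config module's patch helper also works as a context manager):

import torch
from torch.fx.experimental import _config as fx_config

# Preferred location of the flag after this change.
with fx_config.patch("enrich_profiler_metadata", True):
    compiled = torch.compile(lambda x: x + 1, backend="eager")
    compiled(torch.randn(4))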
@ -47,7 +47,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter

from . import config
from .utils import clone_inputs, get_debug_dir
from .utils import clone_inputs, get_debug_dir, warn_once


if TYPE_CHECKING:
@ -617,7 +617,7 @@ class InputReader:
        # way would be very mysterious! Would have been better
        # not to store device in the serialized format...
        return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        warn_once(f"could not load {storage_hash}, generating random data instead")
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()
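The replacement for log.warning is the deduplicating warn_once helper imported from .utils above, which takes the already-formatted message string. A rough stand-in with the same observable behaviour (a hypothetical sketch, not the actual torch._dynamo.utils implementation):

import logging

log = logging.getLogger(__name__)
_seen_messages: set[str] = set()

def warn_once_sketch(msg: str) -> None:
    # Emit each distinct message at most once per process.
    if msg in _seen_messages:
        return
    _seen_messages.add(msg)
    log.warning(msg)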
@ -2937,5 +2937,18 @@
                "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
            ]
        }
    ],
    "GB0288": [
        {
            "Gb_type": "Can't extract message from torch._check()",
            "Context": "str(message_vt)",
            "Explanation": "The second argument of torch._check() must be a function defined within the torch.compile region that does not reference a non-local variable.",
            "Hints": [
                "Make sure the message function is defined in the torch.compile region.",
                "Remove any closure variables, e.g. ",
                "remove references to closure variable `x` in `lambda: f'{x} failed check'`",
                "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
            ]
        }
    ]
}
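The new GB0288 entry matches the torch._check tests added earlier in this comparison: under torch.compile, the message callable may be a closure-free lambda (or reference a global), but capturing a local triggers this graph break (surfaced as an Unsupported error under fullgraph=True). A short sketch of the distinction (illustrative, not from the patch):

import torch

@torch.compile(backend="eager", fullgraph=True)
def ok(x):
    # Allowed: the message lambda does not capture any local variable.
    torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
    return x + 1

@torch.compile(backend="eager", fullgraph=True)
def breaks(x):
    # Not allowed: the lambda closes over the local `x`, so Dynamo cannot
    # extract the message and reports GB0288.
    torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3")
    return x + 1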
@ -1542,7 +1542,7 @@ class OutputGraph(OutputGraphCommon):
|
||||
)
|
||||
)
|
||||
tmp_vars = []
|
||||
for constructor in reversed(index_to_bytecode_constructor.values()):
|
||||
for constructor in index_to_bytecode_constructor.values():
|
||||
constructor(codegen)
|
||||
var_name = (
|
||||
self.new_var()
|
||||
|
||||
@ -3228,7 +3228,7 @@ class InstructionTranslatorBase(
|
||||
|
||||
def BUILD_SLICE(self, inst: Instruction) -> None:
|
||||
items = self.popn(inst.argval)
|
||||
self.push(SliceVariable(items, tx=self))
|
||||
self.push(SliceVariable(items, tx=self)) # type: ignore[arg-type]
|
||||
|
||||
def BUILD_LIST(self, inst: Instruction) -> None:
|
||||
items = self.popn(inst.argval)
|
||||
@ -3607,7 +3607,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, ListVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "extend", [v], {})
|
||||
obj.call_method(self, "extend", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
def LIST_TO_TUPLE(self, inst: Instruction) -> None:
|
||||
self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
|
||||
@ -3673,7 +3673,7 @@ class InstructionTranslatorBase(
|
||||
def MATCH_KEYS(self, inst: Instruction) -> None:
|
||||
tos = self.stack[-1]
|
||||
assert isinstance(tos, TupleVariable)
|
||||
keys = tos.unpack_var_sequence(self)
|
||||
keys = tos.unpack_var_sequence(self) # type: ignore[arg-type]
|
||||
tos1 = self.stack[-2]
|
||||
assert isinstance(tos1, ConstDictVariable)
|
||||
|
||||
|
||||
@ -180,6 +180,7 @@ manual_torch_name_rule_map: dict[
    "torch.compiler.is_exporting": TorchInGraphFunctionVariable,
    "torch._C._to_dlpack": SkipFunctionVariable,
    "torch.to_dlpack": SkipFunctionVariable,
    "torch._check": TorchInGraphFunctionVariable,
    # We graph break on RNG state setters or getters like
    # `torch.get_rng_state` or `torch.set_rng_state`. These functions
    # are not aten operations and therefore they are completely ignored
@ -2343,7 +2344,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
        "torch._check_type",
        "torch._check_value",
        "torch._check_with",
        "torch._check",
        "torch._compile._disable_dynamo",
        "torch._functorch.apis.chunk_vmap",
        "torch._functorch.batch_norm_replacement.batch_norm_without_running_stats",
@ -1061,9 +1061,7 @@ class VariableBuilder:
|
||||
)
|
||||
set_example_value(stream_proxy.node, value)
|
||||
var = StreamVariable(
|
||||
stream_proxy,
|
||||
value,
|
||||
source=self.source,
|
||||
stream_proxy, value, source=self.source, user_object_index=index
|
||||
)
|
||||
return self.tx.output.side_effects.track_object_existing(value, var)
|
||||
elif isinstance(value, (torch._C._SDPAParams)):
|
||||
@ -3006,14 +3004,16 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
||||
return SymNodeVariable(proxy, example_value, **options)
|
||||
elif (
|
||||
isinstance(example_value, torch.Stream)
|
||||
and proxy.node.target
|
||||
in (get_external_object_by_index, torch.accelerator.current_stream)
|
||||
and proxy.node.target == get_external_object_by_index
|
||||
) or proxy.node.target in [
|
||||
device_interface.current_stream
|
||||
for _, device_interface in get_registered_device_interfaces()
|
||||
]:
|
||||
set_example_value(proxy.node, example_value)
|
||||
return StreamVariable(proxy, example_value, **options)
|
||||
index = None
|
||||
if proxy.node.target == get_external_object_by_index:
|
||||
index = proxy.node.args[0]
|
||||
return StreamVariable(proxy, example_value, index, **options)
|
||||
elif (
|
||||
inspect.isclass(proxy.node.target)
|
||||
and issubclass(proxy.node.target, torch.Event)
|
||||
|
||||
@ -1513,7 +1513,7 @@ class WithExitFunctionVariable(VariableTracker):
|
||||
# Note here we reconstruct the context manager rather than the
|
||||
# exit function. The handler generated by BlockStackEntry
|
||||
# will re-enter the context in the resume function.
|
||||
self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined]
|
||||
self.ctx.reconstruct_type(codegen) # type: ignore[union-attr]
|
||||
if codegen.tx.output.partial_convert:
|
||||
if sys.version_info >= (3, 11):
|
||||
codegen.append_output(create_instruction("PUSH_NULL"))
|
||||
@ -1522,10 +1522,10 @@ class WithExitFunctionVariable(VariableTracker):
|
||||
# We rely on classes subtyping `GenericContextWrappingVariable`
|
||||
# to implement these fns and have these attributes
|
||||
codegen.extend_output(
|
||||
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type]
|
||||
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr]
|
||||
)
|
||||
codegen.extend_output(
|
||||
create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type]
|
||||
create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr]
|
||||
)
|
||||
codegen.append_output(create_setup_with(self.target))
|
||||
codegen.append_output(create_instruction("POP_TOP"))
|
||||
|
||||
@ -82,7 +82,8 @@ class ItertoolsVariable(VariableTracker):
|
||||
for item in itertools.product(*seqs, repeat=r)
|
||||
]
|
||||
return variables.ListIteratorVariable(
|
||||
items, mutation_type=ValueMutationNew()
|
||||
items, # type: ignore[arg-type]
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
elif (
|
||||
self.value is itertools.combinations
|
||||
@ -98,7 +99,8 @@ class ItertoolsVariable(VariableTracker):
|
||||
for item in itertools.combinations(iterable, r):
|
||||
items.append(variables.TupleVariable(list(item)))
|
||||
return variables.ListIteratorVariable(
|
||||
items, mutation_type=ValueMutationNew()
|
||||
items, # type: ignore[arg-type]
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
elif self.value is itertools.groupby:
|
||||
if any(kw != "key" for kw in kwargs.keys()):
|
||||
@ -181,7 +183,8 @@ class ItertoolsVariable(VariableTracker):
|
||||
from_exc=e,
|
||||
)
|
||||
return variables.ListIteratorVariable(
|
||||
result, mutation_type=ValueMutationNew()
|
||||
result, # type: ignore[arg-type]
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
elif self.value is itertools.repeat:
|
||||
if len(args) < 2:
|
||||
@ -212,7 +215,8 @@ class ItertoolsVariable(VariableTracker):
|
||||
)
|
||||
]
|
||||
return variables.ListIteratorVariable(
|
||||
items, mutation_type=ValueMutationNew()
|
||||
items, # type: ignore[arg-type]
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
else:
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""
|
||||
Variable tracking implementations for list-like data structures in Dynamo.
|
||||
|
||||
@ -20,7 +18,7 @@ import collections
|
||||
import inspect
|
||||
import operator
|
||||
import sys
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
@ -60,11 +58,11 @@ if TYPE_CHECKING:
|
||||
|
||||
class BaseListVariable(VariableTracker):
|
||||
@staticmethod
|
||||
def cls_for_instance(obj):
|
||||
def cls_for_instance(obj: Any) -> type["BaseListVariable"]:
|
||||
return BaseListVariable.cls_for(type(obj))
|
||||
|
||||
@staticmethod
|
||||
def cls_for(obj):
|
||||
def cls_for(obj: Any) -> type:
|
||||
return {
|
||||
iter: ListIteratorVariable,
|
||||
list: ListVariable,
|
||||
@ -80,34 +78,38 @@ class BaseListVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(items, list)
|
||||
assert all(isinstance(x, VariableTracker) for x in items)
|
||||
self.items: list[VariableTracker] = items
|
||||
|
||||
def _as_proxy(self):
|
||||
def _as_proxy(self) -> list[Any]:
|
||||
return [x.as_proxy() for x in self.items]
|
||||
|
||||
def modified(self, items, **kwargs):
|
||||
def modified(
|
||||
self, items: list[VariableTracker], **kwargs: Any
|
||||
) -> "BaseListVariable":
|
||||
return type(self)(items, **kwargs)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
def value(self) -> Any:
|
||||
return self.as_python_constant()
|
||||
|
||||
def debug_repr_helper(self, prefix, suffix):
|
||||
def debug_repr_helper(self, prefix: str, suffix: str) -> str:
|
||||
return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return self.python_type()([x.as_python_constant() for x in self.items])
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Any:
|
||||
assert self.python_type() is not SizeVariable
|
||||
return self.python_type()(self._as_proxy())
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
if isinstance(arg, SymNodeVariable):
|
||||
@ -134,16 +136,16 @@ class BaseListVariable(VariableTracker):
|
||||
IndexError, tx, args=["list index out of range"]
|
||||
)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return list(self.items)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__getitem__":
|
||||
from .tensor import TensorVariable
|
||||
|
||||
@ -224,15 +226,15 @@ class BaseListVariable(VariableTracker):
|
||||
if type(self) is not type(args[0]):
|
||||
tp_name = self.python_type_name()
|
||||
other = args[0].python_type_name()
|
||||
msg = ConstantVariable.create(
|
||||
msg_vt = ConstantVariable.create(
|
||||
f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
raise_observed_exception(TypeError, tx, args=[msg_vt])
|
||||
|
||||
if name == "__add__":
|
||||
return type(self)(self.items + args[0].items, source=self.source)
|
||||
return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined]
|
||||
else:
|
||||
self.items += args[0].items
|
||||
self.items += args[0].items # type: ignore[attr-defined]
|
||||
return self
|
||||
elif name in ("__mul__", "__imul__"):
|
||||
if kwargs or len(args) != 1:
|
||||
@ -244,10 +246,10 @@ class BaseListVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not (args[0].is_python_constant() and args[0].python_type() is int):
|
||||
msg = ConstantVariable.create(
|
||||
msg_vt = ConstantVariable.create(
|
||||
f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
raise_observed_exception(TypeError, tx, args=[msg_vt])
|
||||
|
||||
val = args[0].as_python_constant()
|
||||
|
||||
@ -301,7 +303,7 @@ class BaseListVariable(VariableTracker):
|
||||
|
||||
|
||||
class RangeVariable(BaseListVariable):
|
||||
def __init__(self, items, **kwargs) -> None:
|
||||
def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None:
|
||||
items_to_map = items
|
||||
start = variables.ConstantVariable.create(0)
|
||||
stop = None
|
||||
@ -316,7 +318,7 @@ class RangeVariable(BaseListVariable):
|
||||
else:
|
||||
raise AssertionError
|
||||
|
||||
def maybe_as_int(x):
|
||||
def maybe_as_int(x: VariableTracker) -> VariableTracker:
|
||||
return (
|
||||
ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
|
||||
)
|
||||
@ -329,22 +331,22 @@ class RangeVariable(BaseListVariable):
|
||||
assert stop is not None
|
||||
super().__init__([start, stop, step], **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return self.debug_repr_helper("range(", ")")
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return range
|
||||
|
||||
def start(self):
|
||||
def start(self) -> Any:
|
||||
return self.items[0].as_python_constant()
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> Any:
|
||||
return self.items[1].as_python_constant()
|
||||
|
||||
def step(self):
|
||||
def step(self) -> Any:
|
||||
return self.items[2].as_python_constant()
|
||||
|
||||
def range_length(self):
|
||||
def range_length(self) -> int:
|
||||
lo = self.start()
|
||||
hi = self.stop()
|
||||
step = self.step()
|
||||
@ -357,7 +359,7 @@ class RangeVariable(BaseListVariable):
|
||||
else:
|
||||
return 0
|
||||
|
||||
def _get_slice_indices(self, length, slice):
|
||||
def _get_slice_indices(self, length: int, slice: slice) -> list[int]:
|
||||
step_is_negative = 0
|
||||
|
||||
if slice.step is None:
|
||||
@ -406,7 +408,7 @@ class RangeVariable(BaseListVariable):
|
||||
|
||||
return [start, stop, step]
|
||||
|
||||
def apply_index(self, index):
|
||||
def apply_index(self, index: int) -> VariableTracker:
|
||||
length = self.range_length()
|
||||
if index < 0:
|
||||
index = length + index
|
||||
@ -421,12 +423,12 @@ class RangeVariable(BaseListVariable):
|
||||
|
||||
return variables.ConstantVariable.create(self.start() + (index * self.step()))
|
||||
|
||||
def apply_slice(self, slice):
|
||||
def apply_slice(self, slice: slice) -> "RangeVariable":
|
||||
(slice_start, slice_stop, slice_step) = self._get_slice_indices(
|
||||
self.range_length(), slice
|
||||
)
|
||||
|
||||
def compute_item(index):
|
||||
def compute_item(index: int) -> int:
|
||||
return self.start() + (index * self.step())
|
||||
|
||||
sub_step = self.step() * slice_step
|
||||
@ -442,10 +444,12 @@ class RangeVariable(BaseListVariable):
|
||||
)
|
||||
return result
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> range:
|
||||
return range(*[x.as_python_constant() for x in self.items])
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
# implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
|
||||
index = arg.as_python_constant()
|
||||
|
||||
@ -457,28 +461,30 @@ class RangeVariable(BaseListVariable):
|
||||
msg = ConstantVariable("range indices must be integers or slices")
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> range:
|
||||
return self.python_type()(*self._as_proxy())
|
||||
|
||||
def unpack_var_sequence(self, tx=None):
|
||||
def unpack_var_sequence(
|
||||
self, tx: Optional["InstructionTranslator"] = None
|
||||
) -> list[VariableTracker]:
|
||||
return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert "range" not in codegen.tx.f_globals
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(codegen.create_load_python_module(range))
|
||||
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
|
||||
)
|
||||
codegen.foreach(self.items)
|
||||
codegen.extend_output(create_call_function(3, False))
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is range:
|
||||
return variables.ConstantVariable.create(name in range.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
def range_equals(self, other: "RangeVariable"):
|
||||
def range_equals(self, other: "RangeVariable") -> bool:
|
||||
r0, r1 = self, other
|
||||
if (
|
||||
self.range_length() != r1.range_length()
|
||||
@ -487,12 +493,12 @@ class RangeVariable(BaseListVariable):
|
||||
):
|
||||
return False
|
||||
|
||||
if len(r0) == 1:
|
||||
if self.range_length() == 1:
|
||||
return True
|
||||
|
||||
return r0.step() == r1.step()
|
||||
|
||||
def range_count(self, x: VariableTracker):
|
||||
def range_count(self, x: VariableTracker) -> int:
|
||||
# Based on CPython
|
||||
# https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
|
||||
x = x.as_python_constant()
|
||||
@ -511,7 +517,13 @@ class RangeVariable(BaseListVariable):
|
||||
return int(re)
|
||||
return 0
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__iter__":
|
||||
if not all(var.is_python_constant() for var in self.items):
|
||||
# Can't represent a `range_iterator` without well defined bounds
|
||||
@ -545,7 +557,10 @@ class RangeVariable(BaseListVariable):
|
||||
if pt is not range:
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
|
||||
cmp = self.range_equals(other)
|
||||
if isinstance(other, RangeVariable):
|
||||
cmp = self.range_equals(other)
|
||||
else:
|
||||
cmp = False
|
||||
|
||||
# Two ranges are equal if they produce the same sequence of values
|
||||
if name == "__eq__":
|
||||
@ -554,7 +569,7 @@ class RangeVariable(BaseListVariable):
|
||||
return ConstantVariable(not cmp)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
fields = ["start", "stop", "step"]
|
||||
if name in fields:
|
||||
return self.items[fields.index(name)]
|
||||
@ -568,11 +583,11 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
if name == "append" and self.is_mutable():
|
||||
@ -676,9 +691,9 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
self.items[key.evaluate_expr()] = value
|
||||
elif isinstance(key, SliceVariable):
|
||||
if key.is_python_constant():
|
||||
self.items[key.as_python_constant()] = list(value.items)
|
||||
self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined]
|
||||
else:
|
||||
items = slice(
|
||||
items_slice = slice(
|
||||
*[
|
||||
(
|
||||
s.evaluate_expr()
|
||||
@ -688,7 +703,7 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
for s in key.items
|
||||
]
|
||||
)
|
||||
self.items[items] = list(value.items)
|
||||
self.items[items_slice] = list(value.items) # type: ignore[attr-defined]
|
||||
else:
|
||||
self.items[key.as_python_constant()] = value
|
||||
return ConstantVariable.create(None)
|
||||
@ -733,8 +748,8 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
"0 args and 0 kwargs",
|
||||
f"{len(args)} args and {len(kwargs)} kwargs",
|
||||
)
|
||||
items = list(self.items)
|
||||
return self.modified(items, mutation_type=ValueMutationNew())
|
||||
items_lst: list[VariableTracker] = list(self.items)
|
||||
return self.modified(items_lst, mutation_type=ValueMutationNew())
|
||||
elif name == "reverse" and self.is_mutable():
|
||||
if args or kwargs:
|
||||
raise_args_mismatch(
|
||||
@ -763,13 +778,13 @@ class CommonListMethodsVariable(BaseListVariable):
|
||||
|
||||
|
||||
class ListVariable(CommonListMethodsVariable):
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(length={len(self.items)})"
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return self.debug_repr_helper("[", "]")
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
@ -778,11 +793,11 @@ class ListVariable(CommonListMethodsVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
if name == "__setitem__" and self.is_mutable():
|
||||
@ -805,14 +820,14 @@ class ListVariable(CommonListMethodsVariable):
|
||||
msg = ConstantVariable.create("can only assign an iterable")
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
key = key.as_python_constant()
|
||||
if key.step == 0:
|
||||
key_as_const = key.as_python_constant()
|
||||
if key_as_const.step == 0:
|
||||
msg = ConstantVariable.create("slice step cannot be zero")
|
||||
raise_observed_exception(ValueError, tx, args=[msg])
|
||||
|
||||
value = value.force_unpack_var_sequence(tx)
|
||||
value_unpack = value.force_unpack_var_sequence(tx)
|
||||
try:
|
||||
self.items[key] = value
|
||||
self.items[key_as_const] = value_unpack
|
||||
except Exception as exc:
|
||||
raise_observed_exception(
|
||||
type(exc),
|
||||
@ -859,7 +874,7 @@ class ListVariable(CommonListMethodsVariable):
|
||||
assert first_non_constant_key is not None
|
||||
|
||||
try:
|
||||
python_type = first_non_constant_key.python_type()
|
||||
python_type = str(first_non_constant_key.python_type())
|
||||
except NotImplementedError:
|
||||
python_type = "unknown"
|
||||
|
||||
@ -904,7 +919,7 @@ class ListVariable(CommonListMethodsVariable):
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx, name):
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
if name == "__class__":
|
||||
source = AttrSource(self.source, name) if self.source else None
|
||||
class_type = self.python_type()
|
||||
@ -916,14 +931,19 @@ class ListVariable(CommonListMethodsVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is not list:
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
return variables.ConstantVariable.create(hasattr([], name))
|
||||
|
||||
|
||||
class DequeVariable(CommonListMethodsVariable):
|
||||
def __init__(self, items, maxlen=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
maxlen: Optional[VariableTracker] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if maxlen is None:
|
||||
maxlen = ConstantVariable.create(None)
|
||||
assert maxlen.is_python_constant(), (
|
||||
@ -935,17 +955,17 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
items = items[-maxlen.as_python_constant() :]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return collections.deque
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if self.maxlen.as_python_constant() is None:
|
||||
return self.debug_repr_helper(
|
||||
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
|
||||
)
|
||||
return self.debug_repr_helper("deque([", "])")
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> collections.deque[Any]:
|
||||
return self.python_type()(
|
||||
[x.as_python_constant() for x in self.items],
|
||||
maxlen=self.maxlen.as_python_constant(),
|
||||
@ -954,7 +974,7 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(
|
||||
codegen.create_load_python_module(collections.deque)
|
||||
codegen.create_load_python_module(collections.deque) # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
codegen.foreach(self.items)
|
||||
@ -962,18 +982,18 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
codegen(self.maxlen)
|
||||
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
if name == "maxlen":
|
||||
return self.maxlen
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if (
|
||||
name == "__setitem__"
|
||||
and self.is_mutable()
|
||||
@ -1068,20 +1088,20 @@ class DequeVariable(CommonListMethodsVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is collections.deque:
|
||||
return variables.ConstantVariable.create(name in collections.deque.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
|
||||
class TupleVariable(BaseListVariable):
|
||||
def python_type(self):
|
||||
def python_type(self) -> type[tuple]: # type: ignore[type-arg]
|
||||
return tuple
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(length={len(self.items)})"
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return self.debug_repr_helper("(", ")")
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
@ -1090,14 +1110,14 @@ class TupleVariable(BaseListVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx, name):
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
if name == "__class__":
|
||||
source = AttrSource(self.source, name) if self.source else None
|
||||
class_type = self.python_type()
|
||||
@ -1109,7 +1129,7 @@ class TupleVariable(BaseListVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is not tuple:
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
return variables.ConstantVariable.create(hasattr((), name))
|
||||
@ -1127,18 +1147,18 @@ class SizeVariable(TupleVariable):
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
proxy: Optional[torch.fx.Proxy] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.proxy = proxy
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return self.debug_repr_helper("torch.Size([", "])")
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return torch.Size
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Any:
|
||||
if self.proxy is not None:
|
||||
return self.proxy
|
||||
|
||||
@ -1193,10 +1213,10 @@ class SizeVariable(TupleVariable):
|
||||
] + create_call_function(1, False)
|
||||
codegen.extend_output(build_torch_size)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return list(self.items)
|
||||
|
||||
def numel(self, tx):
|
||||
def numel(self, tx: "InstructionTranslator") -> VariableTracker:
|
||||
from .builtin import BuiltinVariable
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
@ -1226,11 +1246,11 @@ class SizeVariable(TupleVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__getitem__":
|
||||
if kwargs or len(args) != 1:
|
||||
raise_args_mismatch(
|
||||
@ -1253,7 +1273,9 @@ class SizeVariable(TupleVariable):
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def get_item_dyn(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
from .tensor import SymNodeVariable
|
||||
|
||||
if isinstance(arg, SymNodeVariable):
|
||||
@ -1269,7 +1291,7 @@ class SizeVariable(TupleVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
return variables.ConstantVariable.create(hasattr(torch.Size, name))
|
||||
|
||||
|
||||
@ -1280,33 +1302,39 @@ class NamedTupleVariable(TupleVariable):
|
||||
*TupleVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
tuple_cls: type,
|
||||
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
self.tuple_cls = tuple_cls
|
||||
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
|
||||
|
||||
def is_namedtuple(self):
|
||||
def is_namedtuple(self) -> bool:
|
||||
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
|
||||
getattr(self.tuple_cls, "_make", None)
|
||||
)
|
||||
|
||||
def is_structseq(self):
|
||||
def is_structseq(self) -> bool:
|
||||
return not self.is_namedtuple()
|
||||
|
||||
def fields(self):
|
||||
def fields(self) -> tuple[str, ...]:
|
||||
return namedtuple_fields(self.tuple_cls)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
|
||||
# NamedTupleType(*iterable)
|
||||
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return self.tuple_cls
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
result = self.python_type()([x.as_python_constant() for x in self.items])
|
||||
@ -1328,7 +1356,7 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
return result
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Any:
|
||||
assert self.python_type() is not SizeVariable
|
||||
if self.is_structseq():
|
||||
# StructSequenceType(iterable)
|
||||
@ -1342,7 +1370,10 @@ class NamedTupleVariable(TupleVariable):
|
||||
# StructSequenceType(iterable)
|
||||
# NamedTupleType(*iterable)
|
||||
# NamedTupleType._make(iterable)
|
||||
create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
|
||||
if self.is_structseq():
|
||||
create_fn = self.tuple_cls
|
||||
else:
|
||||
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(
|
||||
codegen.create_load_const_unchecked(create_fn)
|
||||
@ -1384,8 +1415,8 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
@ -1446,7 +1477,9 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
if isinstance(arg, SliceVariable):
|
||||
# slicing a namedtuple produces a tuple
|
||||
return TupleVariable(
|
||||
@ -1455,8 +1488,8 @@ class NamedTupleVariable(TupleVariable):
|
||||
)
|
||||
return super().getitem_const(tx, arg)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
def check_and_create_method():
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
def check_and_create_method() -> Optional[VariableTracker]:
|
||||
method = inspect.getattr_static(self.tuple_cls, name, None)
|
||||
if isinstance(method, classmethod):
|
||||
# We need the unbound classmethod to avoid inlining the bound __self__
|
||||
@ -1489,8 +1522,8 @@ class NamedTupleVariable(TupleVariable):
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
if name == "_fields":
|
||||
source = NamedTupleFieldsSource(self.source) if self.source else None
|
||||
return VariableTracker.build(tx, self.fields(), source=source)
|
||||
result_source = NamedTupleFieldsSource(self.source) if self.source else None
|
||||
return VariableTracker.build(tx, self.fields(), source=result_source)
|
||||
|
||||
if name in self.dynamic_attributes:
|
||||
return self.dynamic_attributes[name]
|
||||
@ -1505,14 +1538,19 @@ class NamedTupleVariable(TupleVariable):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
return variables.ConstantVariable.create(
|
||||
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
|
||||
)
|
||||
|
||||
|
||||
class SliceVariable(VariableTracker):
|
||||
def __init__(self, items, tx=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
items: Sequence[VariableTracker],
|
||||
tx: Optional["InstructionTranslator"] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
items_to_map = items
|
||||
start, stop, step = [variables.ConstantVariable.create(None)] * 3
|
||||
|
||||
@ -1547,23 +1585,23 @@ class SliceVariable(VariableTracker):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
return self.debug_repr_helper("slice(", ")")
|
||||
def debug_repr(self) -> str:
|
||||
return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")"
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> slice:
|
||||
return slice(*[x.as_proxy() for x in self.items])
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return slice
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> slice:
|
||||
return slice(*[guard_if_dyn(x) for x in self.items])
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach(self.items)
|
||||
codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
if name in cmp_name_to_op_mapping:
|
||||
return variables.GetAttrVariable(self, name)
|
||||
fields = ["start", "stop", "step"]
|
||||
@ -1584,7 +1622,9 @@ class ListIteratorVariable(IteratorVariable):
|
||||
*IteratorVariable._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(self, items, index: int = 0, **kwargs) -> None:
|
||||
def __init__(
|
||||
self, items: list[VariableTracker], index: int = 0, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(items, list)
|
||||
# Removing this check as it slows things down too much
|
||||
@ -1598,7 +1638,7 @@ class ListIteratorVariable(IteratorVariable):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
|
||||
|
||||
def next_variable(self, tx):
|
||||
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
|
||||
assert self.is_mutable()
|
||||
old_index = self.index
|
||||
if old_index >= len(self.items) or self.is_exhausted:
|
||||
@ -1609,27 +1649,31 @@ class ListIteratorVariable(IteratorVariable):
|
||||
self.index += 1
|
||||
return self.items[old_index]
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
return variables.ConstantVariable.create(hasattr(iter([]), name))
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return type(iter([]))
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
if self.index > 0:
|
||||
raise NotImplementedError
|
||||
return iter([x.as_python_constant() for x in self.items])
|
||||
|
||||
def has_unpack_var_sequence(self, tx):
|
||||
def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
|
||||
return True
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
if self.is_exhausted:
|
||||
return []
|
||||
self.is_exhausted = True
|
||||
return list(self.items[self.index :])
|
||||
|
||||
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
|
||||
def force_unpack_var_sequence(
|
||||
self, tx: "InstructionTranslator"
|
||||
) -> list[VariableTracker]:
|
||||
return self.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
@ -1656,27 +1700,37 @@ class RangeIteratorVariable(IteratorVariable):
|
||||
"iter_obj",
|
||||
}
|
||||
|
||||
def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
|
||||
def __init__(
|
||||
self, start: int, stop: int, step: int, len_: int, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
self.len = len_
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__next__":
|
||||
return self.next_variable(tx)
|
||||
elif name == "__iter__":
|
||||
return self
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is range_iterator:
|
||||
ri = iter(range(0))
|
||||
return ConstantVariable(hasattr(ri, name))
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
|
||||
def next_variable(self, tx):
|
||||
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
|
||||
if self.len <= 0:
|
||||
raise_observed_exception(StopIteration, tx)
|
||||
|
||||
@ -1685,12 +1739,12 @@ class RangeIteratorVariable(IteratorVariable):
|
||||
self.start += self.step
|
||||
return ConstantVariable.create(current)
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return range_iterator
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(codegen.create_load_python_module(range))
|
||||
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
|
||||
)
|
||||
codegen.append_output(codegen.create_load_const(self.start))
|
||||
codegen.append_output(codegen.create_load_const(self.stop))
|
||||
|
||||
@ -204,11 +204,11 @@ class StreamVariable(StreamContextVariable):
|
||||
self,
|
||||
proxy: Proxy,
|
||||
value: torch.Stream,
|
||||
user_object_index: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Index into the user object table
|
||||
# used to pass arbitrary objects to the graph
|
||||
user_object_index = kwargs.pop("user_obj_index", None)
|
||||
if proxy is not None and "example_value" in proxy.node.meta:
|
||||
assert proxy.node.meta["example_value"] == value
|
||||
|
||||
@ -300,7 +300,7 @@ class StreamVariable(StreamContextVariable):
|
||||
codegen.append_output(codegen.create_load_const(self.user_object_index))
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
else:
|
||||
# TODO mlazos: evaluate if we still need this
|
||||
# This will support the legacy behavior
|
||||
prefix = f"_stream_{self.device}"
|
||||
name = codegen.tx.output.install_global_by_id(prefix, self.value)
|
||||
codegen.append_output(codegen.create_load_global(name, add=True))
|
||||
|
||||
@ -78,7 +78,7 @@ from .ctx_manager import (
|
||||
)
|
||||
from .dicts import ConstDictVariable
|
||||
from .distributed import DistributedVariable, ProcessGroupVariable
|
||||
from .functions import bind_args_cached
|
||||
from .functions import bind_args_cached, NestedUserFunctionVariable
|
||||
from .lists import ListVariable, TupleVariable
|
||||
from .torch_function import (
|
||||
can_dispatch_torch_function,
|
||||
@ -1318,6 +1318,86 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
@register(torch._check)
|
||||
def handle_check(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
predicate_vt = None
|
||||
message_vt = None
|
||||
|
||||
if args:
|
||||
predicate_vt = args[0]
|
||||
rest_args = args[1:]
|
||||
else:
|
||||
rest_args = ()
|
||||
|
||||
if predicate_vt is None and "cond" in kwargs:
|
||||
predicate_vt = kwargs.pop("cond")
|
||||
|
||||
if rest_args:
|
||||
message_vt = rest_args[0]
|
||||
elif "message" in kwargs:
|
||||
message_vt = kwargs.pop("message")
|
||||
|
||||
if predicate_vt is None:
|
||||
return wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.value,
|
||||
(),
|
||||
{},
|
||||
),
|
||||
)
|
||||
|
||||
message_eager = None
|
||||
message_graph_proxy = None
|
||||
if message_vt is not None:
|
||||
if (
|
||||
not isinstance(message_vt, NestedUserFunctionVariable)
|
||||
or message_vt.has_closure()
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Can't extract message from torch._check()",
|
||||
context=str(message_vt),
|
||||
explanation=(
|
||||
"The second argument of torch._check() must be a function"
|
||||
"defined within the torch.compile region"
|
||||
"that does not reference a non-local variable."
|
||||
),
|
||||
hints=[
|
||||
"Make sure the message function is defined in the torch.compile region.",
|
||||
"Remove any closure variables, e.g. "
|
||||
"remove references to closure variable `x` in `lambda: f'{x} failed check'`",
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
message_eager = message_vt.get_function()
|
||||
|
||||
message_graph_proxy = tx.output.register_static_attr_and_return_proxy(
|
||||
"_check_message", message_eager
|
||||
)
|
||||
|
||||
if predicate_vt.is_python_constant():
|
||||
self.value(predicate_vt.as_python_constant(), message_eager)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
predicate_proxy = predicate_vt.as_proxy()
|
||||
|
||||
proxy_args: tuple[Any, ...]
|
||||
if message_graph_proxy is None:
|
||||
proxy_args = (predicate_proxy,)
|
||||
else:
|
||||
proxy_args = (predicate_proxy, message_graph_proxy)
|
||||
|
||||
return wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.value,
|
||||
proxy_args,
|
||||
{},
|
||||
),
|
||||
)
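For context, a minimal user-level sketch of the pattern this handler traces. The shapes and message text are made up; per the handler above, the message callable must be defined inside the compiled region and capture no outer variables.

```python
import torch

@torch.compile
def f(x):
    # torch._check with an optional message callable; the lambda captures
    # nothing, which is what the handler above requires to keep the check
    # in the graph instead of graph breaking.
    torch._check(x.shape[0] > 1, lambda: "expected more than one row")
    return x * 2

f(torch.randn(4, 3))
```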
|
||||
|
||||
return handlers
|
||||
|
||||
def call_function(
|
||||
|
||||
@ -838,7 +838,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function", get_external_object_by_index, (ind,), {}
|
||||
),
|
||||
user_obj_index=ind,
|
||||
)
|
||||
else:
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
|
||||
@ -104,7 +104,7 @@ from .._dynamo.exc import ShortenTraceback, SkipFrame
|
||||
from ..fx._lazy_graph_module import _use_lazy_graph_module
|
||||
from ..fx.graph import _PyTreeCodeGen
|
||||
from ..utils._triton import has_triton
|
||||
from . import config, metrics
|
||||
from . import config, distributed_autotune, metrics
|
||||
from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration
|
||||
from .debug import DebugContext
|
||||
from .decomposition import select_decomp_table
|
||||
@ -1431,7 +1431,11 @@ class _InProcessFxCompile(FxCompile):
|
||||
# We are going to start code generating runtime asserts, so make sure
|
||||
# you don't start adding new ones in the lowering process
|
||||
graph.freeze_runtime_asserts()
|
||||
with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
|
||||
with (
|
||||
V.set_graph_handler(graph),
|
||||
V.set_extern_kernel_nodes([]),
|
||||
distributed_autotune.graph_context(),
|
||||
):
|
||||
graph.run(*example_inputs)
|
||||
output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
|
||||
if graph.graph_outputs is not None:
|
||||
|
||||
@ -447,6 +447,14 @@ use_experimental_benchmarker: bool = Config(
|
||||
justknob="pytorch/inductor:use_experimental_benchmarker",
|
||||
)
|
||||
|
||||
# Enable distributed autotuning. When this is enabled we will distribute the
|
||||
# autotuning across the ranks in the same process group - so instead of
|
||||
# each rank autotuning every kernel, each rank autotunes only 1/world_size of the kernels and
|
||||
# then shares the results.
|
||||
distributed_max_autotune_gemm = (
|
||||
os.environ.get("TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM") == "1"
|
||||
)
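A minimal sketch of how this flag could be switched on, using only the names visible in this hunk; whether it needs to be combined with `max_autotune` is an assumption.

```python
import os

# Set before torch._inductor.config is first imported, so the default picks it up...
os.environ["TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM"] = "1"

# ...or flip the fields directly on the config module at runtime.
import torch._inductor.config as inductor_config

inductor_config.distributed_max_autotune_gemm = True
inductor_config.max_autotune = True  # assumption: this shards max-autotune GEMM work
```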
|
||||
|
||||
# enable slow autotuning passes to select algorithms
|
||||
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
|
||||
|
||||
|
||||
torch/_inductor/distributed_autotune.py (new file, 386 lines)
@ -0,0 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
|
||||
import torch._logging
|
||||
import torch.distributed as dist
|
||||
import torch.fx
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from . import config, select_algorithm
|
||||
from .ir import (
|
||||
Buffer,
|
||||
ChoiceCaller,
|
||||
Layout,
|
||||
MultiTemplateBuffer,
|
||||
OperationBuffer,
|
||||
ShapeAsConstantBuffer,
|
||||
StorageBox,
|
||||
TensorBox,
|
||||
)
|
||||
from .kernel_inputs import KernelInputs, MMKernelInputs
|
||||
from .scheduler import SchedulerNode
|
||||
from .virtualized import NullHandler, V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Sequence
|
||||
|
||||
|
||||
_DISTRIBUTED_AUTOTUNE_KEY = "distributed_autotune"
|
||||
|
||||
_AUTOTUNE_PG: dist.ProcessGroup | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DistributedAutotuneState:
|
||||
"""
|
||||
State used to track autotuning during a graph_context()
|
||||
"""
|
||||
|
||||
# This is the next operator index. Used to figure out which rank should do
|
||||
# the autotuning.
|
||||
autotuned_index: int = 0
|
||||
|
||||
# For debugging - used to make sure that we autotune the same number of
|
||||
# local operators that we expected to.
|
||||
autotuned_local_count: int = 0
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DistributedAutotuneInfo:
|
||||
index: int
|
||||
local: bool
|
||||
|
||||
|
||||
def get_autotune_pg() -> dist.ProcessGroup | None:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
global _AUTOTUNE_PG
|
||||
if _AUTOTUNE_PG is None:
|
||||
_AUTOTUNE_PG = dist.distributed_c10d._new_group_with_tag(
|
||||
pg_tag="pt2_distributed_autotune_pg"
|
||||
)
|
||||
return _AUTOTUNE_PG
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def schedule(scheduler: torch._inductor.scheduler.Scheduler) -> None:
|
||||
"""
|
||||
Finish the distributed autotuning by propagating the autotuning results
|
||||
between the ranks and then replacing the placeholder with the real Buffer.
|
||||
"""
|
||||
assert config.distributed_max_autotune_gemm
|
||||
autotune_results = _autotune_local_nodes(scheduler)
|
||||
choices_by_index = _sync(autotune_results)
|
||||
_autotune_remote_nodes(scheduler, choices_by_index)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def graph_context() -> Generator[None, None, None]:
|
||||
"""
|
||||
Wrapped around processing a graph; sets up the per-graph state used to decide
|
||||
which ranks autotune which kernels.
|
||||
"""
|
||||
assert not isinstance(
|
||||
V.get_distributed_autotune_state(check_poisoned=False), # type: ignore[call-arg]
|
||||
_DistributedAutotuneState,
|
||||
)
|
||||
V.set_distributed_autotune_state(_DistributedAutotuneState())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
V.set_distributed_autotune_state(NullHandler())
|
||||
|
||||
|
||||
def maybe_autotune_remote(
|
||||
name: str, choices: list[ChoiceCaller], inputs: list[Buffer], layout: Layout
|
||||
) -> TensorBox | ShapeAsConstantBuffer | None:
|
||||
"""
|
||||
Used by an op (like `mm`) to determine if the op should be autotuned
|
||||
locally (returns None) or remotely (returns a placeholder Buffer).
|
||||
"""
|
||||
if not config.distributed_max_autotune_gemm:
|
||||
return None
|
||||
|
||||
if not (autotune_pg := get_autotune_pg()):
|
||||
return None
|
||||
|
||||
if len(choices) <= 1:
|
||||
return None
|
||||
|
||||
state = V.distributed_autotune_state
|
||||
index = state.autotuned_index
|
||||
state.autotuned_index += 1
|
||||
local = index % autotune_pg.size() == autotune_pg.rank()
|
||||
|
||||
V.current_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] = _DistributedAutotuneInfo(
|
||||
index, local
|
||||
)
|
||||
if local:
|
||||
state.autotuned_local_count += 1
|
||||
return None
|
||||
|
||||
return torch._inductor.ir.TensorBox.create(
|
||||
_DistributedAutotuneBuffer(name, inputs, layout)
|
||||
)
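A standalone illustration of the ownership rule used above, with made-up world size and rank: the i-th autotunable op is tuned by rank `i % world_size`, and every other rank inserts the placeholder buffer instead.

```python
world_size, rank = 4, 1  # illustrative values
for index in range(8):
    local = index % world_size == rank
    owner = index % world_size
    print(f"op {index}: {'autotune locally' if local else f'placeholder, tuned on rank {owner}'}")
```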
|
||||
|
||||
|
||||
class _DistributedAutotuneBuffer(MultiTemplateBuffer):
|
||||
"""
|
||||
A MultiTemplateBuffer which represents a kernel being autotuned on a
|
||||
different rank. When `schedule` is called this will be replaced by the
|
||||
"real" buffer.
|
||||
"""
|
||||
|
||||
# Name of the kernel being autotuned.
|
||||
_kernel_name: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
inputs: list[Buffer],
|
||||
layout: Layout,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
layout,
|
||||
inputs,
|
||||
choice_timings_fn=self._dummy_choice_timings,
|
||||
unfiltered_choices=[],
|
||||
allowed_prologue_inps=OrderedSet({}),
|
||||
)
|
||||
|
||||
self._kernel_name = kernel_name
|
||||
|
||||
def _dummy_choice_timings(
|
||||
self, _hint_override: int | None
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
# This should never get called. It means that a remote autotune was
|
||||
# scheduled but never filled in.
|
||||
raise NotImplementedError
|
||||
|
||||
def autotune(self, ser_choice: _SerializedChoice) -> TensorBox:
|
||||
"""
|
||||
Given a _SerializedChoice (autotune results from another rank)
|
||||
compute the final TensorBox.
|
||||
"""
|
||||
|
||||
from .select_algorithm import autotune_select_algorithm
|
||||
|
||||
with patch.object(V.graph, "scheduler", None):
|
||||
kernel_inputs = MMKernelInputs([*self.original_inputs])
|
||||
assert isinstance(self.layout, Layout)
|
||||
choice = ser_choice.get_choice(self.layout, kernel_inputs)
|
||||
buffer = autotune_select_algorithm(
|
||||
self._kernel_name,
|
||||
[choice],
|
||||
kernel_inputs.nodes(),
|
||||
self.layout,
|
||||
)
|
||||
assert isinstance(buffer, TensorBox)
|
||||
return buffer
|
||||
|
||||
|
||||
# Can we make this async?
|
||||
def _sync(autotune_results: list[_SerializedChoice]) -> Sequence[_SerializedChoice]:
|
||||
"""
|
||||
Perform the all_gather to collect the autotune results from all the ranks.
|
||||
"""
|
||||
|
||||
autotune_pg = get_autotune_pg()
|
||||
assert autotune_pg
|
||||
|
||||
# Perform allgather
|
||||
all_states: list[list[_SerializedChoice]] = [None] * autotune_pg.size() # type: ignore[list-item]
|
||||
torch.distributed.all_gather_object(all_states, autotune_results, group=autotune_pg)
|
||||
|
||||
node_count = sum(len(x) for x in all_states)
|
||||
# It's faster to briefly lie about the type than to unzip the results and append.
|
||||
choices_by_index: list[_SerializedChoice] = [None] * node_count # type: ignore[list-item]
|
||||
|
||||
check_count = 0
|
||||
for i, other_results in enumerate(all_states):
|
||||
for choice in other_results:
|
||||
assert isinstance(choice, _SerializedChoice)
|
||||
assert choices_by_index[choice.index] is None
|
||||
choices_by_index[choice.index] = choice
|
||||
check_count += 1
|
||||
|
||||
assert node_count == check_count, f"count mismatch: {node_count} != {check_count}"
|
||||
return choices_by_index
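A self-contained sketch of the merge `_sync` performs, with plain `(index, payload)` tuples standing in for `_SerializedChoice` objects and a hard-coded gather result instead of a real process group.

```python
gathered = [
    [(0, "mm_template cfg A"), (2, "aten_mm")],             # what rank 0 tuned
    [(1, "mm_template cfg B"), (3, "mm_template cfg C")],   # what rank 1 tuned
]
node_count = sum(len(per_rank) for per_rank in gathered)
choices_by_index = [None] * node_count
for per_rank in gathered:
    for index, payload in per_rank:
        assert choices_by_index[index] is None  # each op is owned by exactly one rank
        choices_by_index[index] = payload
print(choices_by_index)
# ['mm_template cfg A', 'mm_template cfg B', 'aten_mm', 'mm_template cfg C']
```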
|
||||
|
||||
|
||||
class _SerializedChoice:
|
||||
"""
|
||||
This is a serializer for the autotune choice. KernelTemplateChoice can't
|
||||
be serialized directly (the template and inputs prevent this) so we need to
|
||||
serialize it by parts and reconstruct later on.
|
||||
"""
|
||||
|
||||
def __init__(self, index: int, choice: ChoiceCaller) -> None:
|
||||
self.index = index
|
||||
self.template_uid = _SerializedChoice._template_uid_from_choice(choice)
|
||||
self.kwargs = self._compute_kwargs(choice.description)
|
||||
|
||||
def get_choice(self, layout: Layout, inputs: KernelInputs) -> ChoiceCaller | None:
|
||||
"""
|
||||
Deserialize the ChoiceCaller and return it.
|
||||
"""
|
||||
|
||||
template = self._template_from_uid()
|
||||
|
||||
kwargs = {**self.kwargs}
|
||||
if "BLOCK_K" in kwargs:
|
||||
# TODO: Do we really need to externally compute this value? If it's
|
||||
# needed I'm surprised it's not just part of the original template
|
||||
# description.
|
||||
# This needs the actual 'k' to figure out the value.
|
||||
k = inputs.nodes()[0].get_size()[1]
|
||||
kwargs["EVEN_K"] = sympy.gcd(k, kwargs["BLOCK_K"]) == kwargs["BLOCK_K"]
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
from .kernel_template_choice import (
|
||||
DictKernelTemplateParams,
|
||||
KernelTemplateChoice,
|
||||
)
|
||||
|
||||
params = DictKernelTemplateParams(kwargs)
|
||||
ktc = KernelTemplateChoice(template, params, extra_kwargs, layout, inputs)
|
||||
return ktc.choice
|
||||
|
||||
@staticmethod
|
||||
def _compute_kwargs(description: str) -> dict[str, Union[int, str, bool]]:
|
||||
"""
|
||||
Given a template description turn it into input kwargs.
|
||||
"""
|
||||
if not description:
|
||||
return {}
|
||||
|
||||
# TODO: It seems like it would be better if the template could provide
|
||||
# this directly instead of having to parse a string.
|
||||
kwargs: dict[str, Union[int, str, bool]] = {}
|
||||
for cfg in description.split(","):
|
||||
key, val = cfg.split("=", 1)
|
||||
key, val = key.strip(), val.strip()
|
||||
if val == "True":
|
||||
kwargs[key] = True
|
||||
elif val == "False":
|
||||
kwargs[key] = False
|
||||
elif val.isdigit():
|
||||
kwargs[key] = int(val)
|
||||
else:
|
||||
assert val.startswith("'") and val.endswith("'")
|
||||
kwargs[key] = val[1:-1]
|
||||
return kwargs
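A hedged, standalone example of the comma-separated `key=value` description string this parser expects; the field names below are illustrative, and real template descriptions may contain different keys.

```python
description = "BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, ALLOW_TF32=True, ACC_TYPE='tl.float32'"

kwargs = {}
for cfg in description.split(","):
    key, val = (part.strip() for part in cfg.split("=", 1))
    if val in ("True", "False"):
        kwargs[key] = val == "True"
    elif val.isdigit():
        kwargs[key] = int(val)
    else:
        kwargs[key] = val.strip("'")  # quoted strings, e.g. ACC_TYPE
print(kwargs)
# {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'ALLOW_TF32': True, 'ACC_TYPE': 'tl.float32'}
```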
|
||||
|
||||
@staticmethod
|
||||
def _template_uid_from_choice(choice: ChoiceCaller) -> str:
|
||||
"""
|
||||
Given a ChoiceCaller figure out which template represents it. This
|
||||
is reversed by _template_from_uid().
|
||||
"""
|
||||
|
||||
# We need a better way to do this - right now we need to add each
|
||||
# supported template directly.
|
||||
if isinstance(choice, select_algorithm.ExternKernelCaller):
|
||||
if choice.choice.name == "mm":
|
||||
return "torch._inductor.kernel.mm.aten_mm"
|
||||
else:
|
||||
raise RuntimeError(f"TODO: kernel {choice.choice.name!r}")
|
||||
elif isinstance(choice, select_algorithm.TritonTemplateCaller):
|
||||
return "torch._inductor.kernel.mm.mm_template"
|
||||
else:
|
||||
raise RuntimeError(f"TODO: {type(choice)}")
|
||||
|
||||
def _template_from_uid(self) -> Any:
|
||||
"""
|
||||
See _template_uid_from_choice().
|
||||
"""
|
||||
parts = self.template_uid.split(".")
|
||||
obj = globals()[parts[0]]
|
||||
for k in parts[1:]:
|
||||
obj = getattr(obj, k)
|
||||
return obj
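The uid is resolved by looking the first component up in the module's globals and then walking attributes; an equivalent standalone version using importlib, shown on a stdlib path so it runs anywhere, would look like this.

```python
import importlib

def resolve(uid: str):
    parts = uid.split(".")
    obj = importlib.import_module(parts[0])  # plays the role of globals()[parts[0]]
    for name in parts[1:]:
        obj = getattr(obj, name)
    return obj

print(resolve("os.path.join"))  # same walk as "torch._inductor.kernel.mm.mm_template"
```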
|
||||
|
||||
|
||||
def _autotune_local_nodes(
|
||||
scheduler: torch._inductor.scheduler.Scheduler,
|
||||
) -> list[_SerializedChoice]:
|
||||
"""
|
||||
Go through the nodes in the scheduler and autotune the kernels which
|
||||
should be autotuned by this rank.
|
||||
"""
|
||||
|
||||
autotune_results: list[_SerializedChoice] = []
|
||||
|
||||
for node in scheduler.nodes:
|
||||
if not isinstance(node, SchedulerNode):
|
||||
continue
|
||||
|
||||
if (inner_node := node.node) is None:
|
||||
continue
|
||||
|
||||
if isinstance(inner_node, _DistributedAutotuneBuffer):
|
||||
# This is marked for remote autotuning.
|
||||
continue
|
||||
|
||||
if not isinstance(inner_node, MultiTemplateBuffer):
|
||||
continue
|
||||
|
||||
if (origin_node := inner_node.origin_node) is None:
|
||||
continue
|
||||
|
||||
if (meta := origin_node.meta) is None:
|
||||
continue
|
||||
|
||||
info = meta.get(_DISTRIBUTED_AUTOTUNE_KEY)
|
||||
if info is None:
|
||||
continue
|
||||
|
||||
assert info.local
|
||||
|
||||
# We force autotuning here
|
||||
# Still takes advantage of async precompile
|
||||
# We need all the configs before fusion
|
||||
min_choice, _ = inner_node.get_min_choice()
|
||||
|
||||
choice = _SerializedChoice(info.index, min_choice)
|
||||
autotune_results.append(choice)
|
||||
|
||||
state = V.distributed_autotune_state
|
||||
assert len(autotune_results) == state.autotuned_local_count, (
|
||||
f"incorrect local autotuned nodes found ({len(autotune_results)} != {state.autotuned_local_count})"
|
||||
)
|
||||
return autotune_results
|
||||
|
||||
|
||||
def _autotune_remote_nodes(
|
||||
scheduler: torch._inductor.scheduler.Scheduler,
|
||||
choices_by_index: Sequence[_SerializedChoice],
|
||||
) -> None:
|
||||
"""
|
||||
Go through the nodes in the scheduler and fill in results for the nodes that were
|
||||
autotuned on remote ranks.
|
||||
"""
|
||||
|
||||
for i, node in enumerate(scheduler.nodes):
|
||||
if isinstance(node, SchedulerNode) and isinstance(
|
||||
(dist_node := node.node), _DistributedAutotuneBuffer
|
||||
):
|
||||
assert dist_node.origin_node is not None
|
||||
info = dist_node.origin_node.meta[_DISTRIBUTED_AUTOTUNE_KEY]
|
||||
out_tensorbox = dist_node.autotune(choices_by_index[info.index])
|
||||
|
||||
out_storage = out_tensorbox.data
|
||||
assert isinstance(out_storage, StorageBox)
|
||||
out_buffer = out_storage.data
|
||||
assert isinstance(out_buffer, OperationBuffer)
|
||||
|
||||
assert out_buffer.layout == dist_node.layout
|
||||
|
||||
scheduler._replace_node(out_buffer, dist_node, i, node)
|
||||
@ -52,8 +52,8 @@ from ..utils import (
|
||||
decode_device,
|
||||
get_all_devices,
|
||||
get_gpu_type,
|
||||
has_uses_tagged_as,
|
||||
is_gpu,
|
||||
is_pointwise_use,
|
||||
OPTIMUS_EXCLUDE_POST_GRAD,
|
||||
)
|
||||
from ..virtualized import V
|
||||
@ -1511,10 +1511,8 @@ def should_prefer_unfused_addmm(match):
|
||||
if not is_gpu(inp.meta["val"].device.type):
|
||||
return False
|
||||
|
||||
return has_uses_tagged_as(
|
||||
match.output_node(),
|
||||
(torch.Tag.pointwise, torch.Tag.reduction),
|
||||
)
|
||||
output = match.output_node()
|
||||
return all(is_pointwise_use(use) for use in output.users)
|
||||
|
||||
|
||||
@register_graph_pattern(
|
||||
|
||||
@ -19,7 +19,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.nn.functional import ScalingType # type: ignore[attr-defined]
|
||||
from torch.torch_version import TorchVersion
|
||||
|
||||
from .. import config as inductor_config
|
||||
from .. import config as inductor_config, distributed_autotune
|
||||
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
|
||||
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
|
||||
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
|
||||
@ -1315,6 +1315,11 @@ def tuned_mm(mat1, mat2, out_dtype=None, *, layout=None):
|
||||
# The future will be awaited at scheduling time in select_algorithm.py
|
||||
best_config_future = gen_best_config(mat1, mat2)
|
||||
|
||||
if box := distributed_autotune.maybe_autotune_remote(
|
||||
name, choices, kernel_inputs.nodes(), layout
|
||||
):
|
||||
return box
|
||||
|
||||
return autotune_select_algorithm(
|
||||
name,
|
||||
choices,
|
||||
|
||||
@ -449,7 +449,6 @@ class SchedulerDonatedBuffer(SchedulerBuffer):
|
||||
|
||||
class BaseSchedulerNode:
|
||||
ancestors: OrderedSet[str]
|
||||
debug_device_str: Callable[[BaseSchedulerNode], list[str]]
|
||||
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
|
||||
last_usage: OrderedSet[str]
|
||||
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
|
||||
@ -461,21 +460,26 @@ class BaseSchedulerNode:
|
||||
max_order: int
|
||||
mpi_node: MemoryPlanningInfoForNode
|
||||
mutation_renames: dict[str, str]
|
||||
node: Optional[ir.Operation]
|
||||
node: Optional[ir.Operation] = None
|
||||
outputs: list[SchedulerBuffer]
|
||||
outputs_by_name: dict[str, SchedulerBuffer]
|
||||
override_estimated_runtime: Optional[float] = None
|
||||
read_writes: dependencies.ReadWrites
|
||||
unmet_dependencies: OrderedSet[Dep]
|
||||
written: bool = False
|
||||
|
||||
def __init__(self, scheduler: Scheduler) -> None:
|
||||
self.scheduler = scheduler
|
||||
self.debug_device_str = lambda *args, **kwargs: []
|
||||
self.scheduler: Scheduler = scheduler
|
||||
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
|
||||
lambda *args, **kwargs: []
|
||||
)
|
||||
|
||||
def _init_from_node(self, node: ir.Operation) -> None:
|
||||
self.node = node
|
||||
self.ancestors = OrderedSet()
|
||||
self.last_usage = OrderedSet() # buffers that won't be used after this kernel
|
||||
self.last_usage = OrderedSet[
|
||||
str
|
||||
]() # buffers that won't be used after this kernel
|
||||
self.written = False
|
||||
self.outputs = [
|
||||
SchedulerBuffer(
|
||||
@ -2643,6 +2647,12 @@ class Scheduler:
|
||||
if config._pre_fusion_custom_pass is not None:
|
||||
self.nodes = config._pre_fusion_custom_pass(self.nodes)
|
||||
|
||||
if config.distributed_max_autotune_gemm:
|
||||
from . import distributed_autotune
|
||||
|
||||
distributed_autotune.schedule(self)
|
||||
self.compute_ancestors()
|
||||
|
||||
self.nodes = self.fuse_nodes(self.nodes)
|
||||
if config._post_fusion_custom_pass is not None:
|
||||
self.nodes = config._post_fusion_custom_pass(self.nodes)
|
||||
@ -3515,6 +3525,7 @@ class Scheduler:
|
||||
|
||||
new_scheduler_node.min_order = node.min_order
|
||||
new_scheduler_node.max_order = node.max_order
|
||||
new_scheduler_node.ancestors = node.ancestors
|
||||
new_scheduler_node.last_usage = node.last_usage
|
||||
|
||||
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
|
||||
|
||||
@ -549,70 +549,6 @@ def is_pointwise_use(
|
||||
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
|
||||
|
||||
|
||||
class LogicalConnective(enum.Enum):
|
||||
OR = enum.auto()
|
||||
AND = enum.auto()
|
||||
|
||||
|
||||
def has_uses(
|
||||
target: Node,
|
||||
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
|
||||
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a target, explore the uses of `target` by applying `use_selector_fn`
|
||||
on them, and then aggregate these booleans with the `use_aggregate_type`
|
||||
logical connective.
|
||||
|
||||
Uses in view ops will follow the views uses.
|
||||
"""
|
||||
|
||||
def get_use_aggregate_fn(
|
||||
use_aggregate_type: LogicalConnective,
|
||||
) -> Callable[[Iterator[Any]], bool]:
|
||||
match use_aggregate_type:
|
||||
case LogicalConnective.AND:
|
||||
return all
|
||||
case LogicalConnective.OR:
|
||||
return any
|
||||
case _:
|
||||
return any
|
||||
|
||||
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
|
||||
|
||||
def has_uses_impl(use: Node) -> bool:
|
||||
if use.op != "call_function":
|
||||
return False
|
||||
if not (
|
||||
isinstance(use.target, torch._ops.OpOverload)
|
||||
or use.target is operator.getitem
|
||||
):
|
||||
return False
|
||||
|
||||
target = cast(torch._ops.OpOverload, use.target)
|
||||
# Process getitem and view
|
||||
if target is operator.getitem or is_view(target):
|
||||
return use_aggregate_fn(has_uses_impl(user) for user in use.users)
|
||||
|
||||
return use_selector_fn(target)
|
||||
|
||||
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
|
||||
|
||||
|
||||
def has_uses_tagged_as(
|
||||
target: Node,
|
||||
use_tags: Collection[torch.Tag],
|
||||
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
|
||||
) -> bool:
|
||||
"""
|
||||
Is there a use with given tags?
|
||||
"""
|
||||
|
||||
return has_uses(
|
||||
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
|
||||
)
|
||||
|
||||
|
||||
def gen_gm_and_inputs(
|
||||
target: Any, args: list[Any], kwargs: dict[str, Any]
|
||||
) -> tuple[GraphModule, list[torch.Tensor]]:
|
||||
|
||||
@ -86,6 +86,8 @@ if TYPE_CHECKING:
|
||||
from torch._inductor.loop_body import InterpreterShim
|
||||
from torch._subclasses import FakeTensorMode
|
||||
|
||||
from .distributed_autotune import _DistributedAutotuneState
|
||||
|
||||
threadlocal = local()
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -201,6 +203,9 @@ _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHand
|
||||
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
|
||||
"local_buffer_context", NullHandler
|
||||
)
|
||||
_distributed_autotune_state: Virtualized[_DistributedAutotuneState] = Virtualized(
|
||||
"distributed_autotune_state", NullHandler
|
||||
)
|
||||
|
||||
|
||||
def _choices_default():
|
||||
@ -370,6 +375,12 @@ class _V:
|
||||
set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
|
||||
get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
|
||||
set_choices_handler: Callable[[Any], Any] = _choices._set_handler
|
||||
set_distributed_autotune_state: Callable[[Any], Any] = (
|
||||
_distributed_autotune_state._set_handler
|
||||
)
|
||||
get_distributed_autotune_state: Callable[[], Any] = (
|
||||
_distributed_autotune_state._get_handler
|
||||
)
|
||||
|
||||
@property
|
||||
def ops(self) -> OpsHandler[Any]:
|
||||
@ -429,5 +440,9 @@ class _V:
|
||||
def choices(self) -> InductorChoices:
|
||||
return _choices._get_handler()
|
||||
|
||||
@property
|
||||
def distributed_autotune_state(self):
|
||||
return _distributed_autotune_state._get_handler()
|
||||
|
||||
|
||||
V = _V()
|
||||
|
||||
@ -33,14 +33,6 @@ static inline int PyCode_GetNFreevars(PyCodeObject* code) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Provided by CPython but getting the header for them is very hard
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
// NOLINTNEXTLINE(readability-redundant-declaration)
|
||||
PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self);
|
||||
#else
|
||||
extern void _PyWeakref_ClearRef(PyWeakReference* self);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
File diff suppressed because it is too large
@ -20,7 +20,10 @@ from torch.distributed.tensor._tp_conv import (
|
||||
convolution_backward_handler,
|
||||
convolution_handler,
|
||||
)
|
||||
from torch.distributed.tensor._utils import try_find_mesh_from_args
|
||||
from torch.distributed.tensor._utils import (
|
||||
ExplicitRedistributionContext,
|
||||
try_find_mesh_from_args,
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
|
||||
from torch.utils._debug_mode import get_active_debug_mode
|
||||
from torch.utils._python_dispatch import return_and_correct_aliasing
|
||||
@ -199,6 +202,10 @@ class OpDispatcher:
|
||||
if participating:
|
||||
# computation that happens in the current rank of the mesh, normal case
|
||||
if output_sharding.needs_redistribute:
|
||||
if ExplicitRedistributionContext.is_active():
|
||||
raise RuntimeError(
|
||||
f"Implicit redistribution occurred while ExplicitRedistributionContext was active for {op_info.schema}"
|
||||
)
|
||||
# If sharding propagation decision needs redistribute, perform redistribute
|
||||
# on args first, which could potentially modify args (i.e. allgather certain arg)
|
||||
assert output_sharding.redistribute_schema is not None
|
||||
|
||||
@ -18,6 +18,33 @@ from torch.distributed.tensor.placement_types import (
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
|
||||
class ExplicitRedistributionContext:
|
||||
"""
|
||||
Within this context manager, DTensor will refuse to perform implicit redistribution,
|
||||
instead raising an error. Manual calls to ``redistribute()`` are required wherever a redistribution
|
||||
must occur to avoid erroring. This can be used to ensure that the user is aware of all redistribution.
|
||||
|
||||
Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backwards pass
|
||||
may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual
|
||||
calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-ops
|
||||
during forward and perform a manual redistribution during backwards.
|
||||
"""
|
||||
|
||||
_explicit_redistribute_mode = False
|
||||
|
||||
@classmethod
|
||||
def is_active(cls) -> bool:
|
||||
return cls._explicit_redistribute_mode
|
||||
|
||||
def __enter__(self):
|
||||
self.prev = ExplicitRedistributionContext._explicit_redistribute_mode
|
||||
ExplicitRedistributionContext._explicit_redistribute_mode = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
ExplicitRedistributionContext._explicit_redistribute_mode = self.prev
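A hedged usage sketch based on the docstring above; it assumes a 2-rank job (e.g. launched with torchrun), a CUDA device per rank, and that a Shard(0) @ Shard(0) matmul requires one operand to be redistributed.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
from torch.distributed.tensor._utils import ExplicitRedistributionContext

mesh = init_device_mesh("cuda", (2,))
x = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])
w = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

with ExplicitRedistributionContext():
    # y = x @ w  # would raise: an implicit redistribution of `w` is required
    y = x @ w.redistribute(mesh, [Replicate()])  # make the redistribution explicit
```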
|
||||
|
||||
|
||||
def _explicit_order_placements(
|
||||
mesh_shape: ShapeType, placements: Sequence[Placement]
|
||||
) -> Sequence[tuple[int, Placement]]:
|
||||
|
||||
@ -2,6 +2,8 @@ import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from torch.utils._config_module import Config, install_config_module
|
||||
|
||||
|
||||
# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
|
||||
no_data_dependent_graph_break = (
|
||||
@ -100,7 +102,11 @@ backed_size_oblivious = False
|
||||
# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
|
||||
skip_dtype_check_in_meta_registrations = False
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
# Experimental: If True, graph module will register fx metadata during recompile()
|
||||
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
|
||||
default=False,
|
||||
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
|
||||
)
|
||||
|
||||
|
||||
install_config_module(sys.modules[__name__])
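A hedged sketch of turning the relocated flag on before `recompile()`; the import path mirrors the one used by `graph_module.py` in this diff, and the flag remains experimental.

```python
import torch
import torch.fx
from torch.fx.experimental import _config as fx_experimental_config

fx_experimental_config.enrich_profiler_metadata = True  # or set TORCH_ENRICH_RPOFILER_STACK_TRACE=1

def f(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)
gm.recompile()  # regenerated code now records fx metadata for profiler augmentation
```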
|
||||
|
||||
@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
|
||||
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
|
||||
|
||||
from ._compatibility import compatibility
|
||||
from .experimental import _config as fx_experimental_config
|
||||
from .graph import (
|
||||
_BoxedCodeGen,
|
||||
_custom_builtins,
|
||||
@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
|
||||
called after editing the contained ``graph``, otherwise the generated
|
||||
code of this ``GraphModule`` will be out of date.
|
||||
"""
|
||||
# Do not import anything inside recompile, it might slow down the
|
||||
# function and cause perf regression. Import outside of the method instead.
|
||||
if isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
|
||||
from torch._dynamo import config as dynamo_config
|
||||
|
||||
python_code = self._graph.python_code(
|
||||
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
|
||||
root_module="self",
|
||||
record_func=fx_experimental_config.enrich_profiler_metadata,
|
||||
)
|
||||
self._code = python_code.src
|
||||
self._lineno_map = python_code._lineno_map
|
||||
@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
|
||||
cls = type(self)
|
||||
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
|
||||
|
||||
if dynamo_config.enrich_profiler_metadata:
|
||||
if fx_experimental_config.enrich_profiler_metadata:
|
||||
# Generate metadata and register for profiler augmentation
|
||||
node_metadata: dict[int, dict[str, Any]] = {}
|
||||
for i, node in enumerate(self._graph.nodes):
|
||||
|
||||