Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to a 6.2% speedup on vision_maskrcnn and a 5.8% speedup on yolov3 ([P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563)), a 39.5% speedup on speech_transformer inference ([P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200)), and an 85% speedup on speech_transformer training ([P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315)). We ran the same diff through the TorchInductor benchmark CI on two different days, and both runs show a speedup on average.

[First TorchInductor Benchmark CI run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)

<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[Second TorchInductor Benchmark CI run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)

<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: edaa151d0d
Commit: 5f1010fbb3
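For readers skimming the diff, here is a minimal sketch of the behavior this PR turns on by default. It is adapted from the `test_graph_partition_simple` test added in the diff below, and it assumes a CUDA device plus an Inductor build that exposes `torch._inductor.config.graph_partition`.

```python
# Minimal sketch, adapted from test_graph_partition_simple in this PR.
# Assumes a CUDA device and a build with torch._inductor.config.graph_partition.
import torch

torch._inductor.config.graph_partition = True  # on by default in OSS after this PR


def f(x, y):
    x1 = x + 1
    y1 = y + 1
    y_cpu = y1.cpu() + 1  # cpu op: previously this forced a cudagraph skip
    z = x @ y
    return x1 + y1 + z + y_cpu.to("cuda")


x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
f_compiled = torch.compile(f)

# With graph partition, the cpu section is split into its own partition and the
# remaining GPU computation can still be captured with CUDA graphs instead of
# the whole graph being skipped.
out = f_compiled(x, y)
```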
@@ -3085,7 +3085,16 @@ main()
         self.assertEqual(counters["compiled_autograd"]["captures"], 1)
         # Compiled autograd lifts custom autograd.Function bwd instead of tracing it.
         # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
-        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
+        if inductor_config.graph_partition:
+            # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops
+            # and cudagraphify the remaining computation. So there is no cudagraph skip.
+            expected_cudagraph_skips = 0
+        else:
+            expected_cudagraph_skips = 1
+
+        self.assertEqual(
+            counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips
+        )

     @scoped_load_inline
     @requires_cuda_and_triton
@@ -3150,9 +3159,18 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
         # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
         # In the future, we can consider having a cpu scalar movement pass sometime after we trace
         # into the custom C++ autograd::Function (like in AOTDispatcher)
+        if inductor_config.graph_partition:
+            # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops
+            # and cudagraphify the remaining computation. So there is no cudagraph skip.
+            expected_cudagraph_skips = 0
+        elif inductor_config.cpp_wrapper:
+            expected_cudagraph_skips = 2
+        else:
+            expected_cudagraph_skips = 1
+
         self.assertEqual(
             counters["inductor"]["cudagraph_skips"],
-            2 if inductor_config.cpp_wrapper else 1,
+            expected_cudagraph_skips,
         )

     def test_logs(self):
@@ -472,6 +472,9 @@ class CondTests(TestCase):
     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @torch._inductor.config.patch(size_asserts=False)
+    # TODO: graph partition does not support creating tensor
+    # with dynamic shape in conditional subgraph yet
+    @torch._inductor.config.patch(graph_partition=False)
     def test_cond_unbacked_symint_inner(self, device):
         class Model(torch.nn.Module):
             def forward(self, p, a):
@@ -189,9 +189,9 @@ class CudaReproTests(TestCase):
         # padded bias should have an expanded dim
         FileCheck().check("buf0 =").check_same(", 0, ").run(code[0])
         # single fused padded kernel
-        FileCheck().check("def call").check_count(
-            "empty_strided_cuda", 1, exactly=True
-        ).check("return").run(code[0])
+        FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check(
+            "return"
+        ).run(code[0])

         self.assertEqual(out, f(*inputs))

@@ -279,10 +279,14 @@ if HAS_CUDA_AND_TRITON:
            with capture_stderr() as captured_output:
                foo(torch.ones([10], device="cuda"), torch.ones([20]))

-            FileCheck().check(
-                "skipping cudagraphs due to cpu device (arg1_1). Found from"
-            ).check("y + 2").run(captured_output[0])
-            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
+            if torch._inductor.config.graph_partition:
+                # graph partition splits on cpu ops
+                self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
+            else:
+                FileCheck().check(
+                    "skipping cudagraphs due to cpu device (arg1_1). Found from"
+                ).check("y + 2").run(captured_output[0])
+                self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

            with capture_stderr() as captured_output:
                foo(
@@ -292,7 +296,10 @@ if HAS_CUDA_AND_TRITON:
            FileCheck().check("skipping cudagraphs due to multiple devices").run(
                captured_output[0]
            )
-            self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
+            self.assertEqual(
+                counters["inductor"]["cudagraph_skips"],
+                1 if torch._inductor.config.graph_partition else 2,
+            )

        @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True)
        def test_skip_symbolic(self):
@@ -807,10 +814,16 @@ if HAS_CUDA_AND_TRITON:
            # the three saved tensors should die in the backward
            # we kept alive the output
            self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
-            self.assertEqual(
-                self.curr_node().expected_dead_indices_after_graph,
-                [(0, 1), (0, 2)],
-            )
+            if torch._inductor.config.graph_partition:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 0), (0, 2)],
+                )
+            else:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 1), (0, 2)],
+                )
            self.assertFalse(self.get_manager().new_graph_id().id == 0)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)

@@ -1127,8 +1140,13 @@ if HAS_CUDA_AND_TRITON:

            node = self.curr_node()
            first_node = next(node._path_from_root)
-            self.assertFalse(first_node.unaliased_in_all_paths[0])
-            self.assertTrue(first_node.cached_tensor_outputs[0] is None)
+            if torch._inductor.config.graph_partition:
+                # graph partition may changed the order of outputs
+                self.assertFalse(first_node.unaliased_in_all_paths[1])
+                self.assertTrue(first_node.cached_tensor_outputs[1] is None)
+            else:
+                self.assertFalse(first_node.unaliased_in_all_paths[0])
+                self.assertTrue(first_node.cached_tensor_outputs[0] is None)

        @torch._inductor.config.patch("implicit_fallbacks", True)
        def test_multinomial(self):
@@ -1631,10 +1649,16 @@ if HAS_CUDA_AND_TRITON:
            # the three saved tensors should die in the backward
            # we kept alive the output
            self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
-            self.assertEqual(
-                self.curr_node().expected_dead_indices_after_graph,
-                [(0, 1), (0, 2)],
-            )
+            if torch._inductor.config.graph_partition:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 0), (0, 2)],
+                )
+            else:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 1), (0, 2)],
+                )
            self.assertFalse(self.get_manager().new_graph_id().id == 0)

        def test_separate_recordings(self):
@@ -2137,8 +2161,8 @@ if HAS_CUDA_AND_TRITON:
            with self.assertRaisesRegex(
                Exception,
                r"(?s)static input data pointer changed.\n"
-                r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*"
-                r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
+                r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*"
+                r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*,"
                r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
            ):
                self.curr_node().run(
@@ -3551,6 +3575,278 @@ if HAS_CUDA_AND_TRITON:

            self.assertEqual(self.get_manager().new_graph_id().id, 2)

+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_simple(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
+            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
+            eager_out = f(x, y)
+
+            f_compiled = torch.compile(f)
+            compiled_out = f_compiled(x_cloned, y_cloned)
+            self.assertEqual(eager_out, compiled_out)
+
+            _, code = run_and_get_code(f_compiled, x_cloned, y_cloned)
+
+            if not config.cpp_wrapper:
+                FileCheck().check("def partition_0(args):").check(
+                    "recursively_apply_fns = runner.recursively_apply_fns"
+                ).run(code[0])
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_foreach_op(self):
+            def fn(a0, a1):
+                c = torch._foreach_abs([a0, a1])
+                return torch.mul(c[0], a0)
+
+            compiled_fn = torch.compile(fn)
+
+            a0 = torch.randn(2, 3, device="cuda")
+            a1 = torch.randn(2, 3, device="cuda")
+            eager_out = fn(a0, a1)
+            compiled_out = compiled_fn(a0, a1)
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_condition_op(self):
+            def f(p, b):
+                def true_fn(x):
+                    return torch.cos(x)
+
+                def false_fn(x):
+                    return torch.sin(x)
+
+                return torch.cond(p, true_fn, false_fn, [b])
+
+            compiled_f = torch.compile(f)
+
+            # static shape
+            p = torch.tensor([True], device="cuda")
+            a = torch.ones([2, 3], device="cuda")
+            eager_out = f(p, a)
+            compiled_out = compiled_f(p, a)
+            self.assertEqual(eager_out, compiled_out)
+
+            # dynamic shape with backed symint
+            p = torch.tensor([True], device="cuda")
+            a = torch.ones([4, 5], device="cuda")
+            eager_out = f(p, a)
+            compiled_out = compiled_f(p, a)
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        @torch._dynamo.config.patch("capture_scalar_outputs", True)
+        def test_graph_partition_unbacked_symint_multi_output_layout(self):
+            def f(p, size_tensor):
+                size_val = size_tensor.item()
+                b = torch.ones([size_val, 3], device="cuda")
+
+                def true_fn(x):
+                    return torch.cos(x), torch.cos(x) + 1
+
+                def false_fn(x):
+                    return torch.sin(x), torch.sin(x) + 1
+
+                cond_out = torch.cond(p, true_fn, false_fn, [b])
+                return cond_out[0] + cond_out[1]
+
+            compiled_f = torch.compile(f)
+            p = torch.tensor([True], device="cuda")
+            size_tensor = torch.tensor(2, device="cuda")
+            eager_out = f(p, size_tensor)
+            compiled_out = compiled_f(p, size_tensor)
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            f_compiled = torch.compile(f)
+            x, y = (
+                torch.ones(3, 3, device="cuda"),
+                torch.randn(3, 3, device="cuda"),
+            )
+            compiled_out = f_compiled(x, y)
+            self.assertEqual(compiled_out, f(x, y))
+
+            x, y = (
+                torch.ones(4, 4, device="cuda"),
+                torch.randn(4, 4, device="cuda"),
+            )
+            compiled_out = f_compiled(x, y)
+            self.assertEqual(compiled_out, f(x, y))
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint_cat_backward(self):
+            def f(x, w):
+                y = torch.cat((x, x), dim=0)
+                z = y @ w
+                return z @ z.T
+
+            compiled_f = torch.compile(f)
+
+            for shape in (2, 3):
+                torch.manual_seed(42)
+                eager_x = torch.randn(shape, 2, device="cuda")
+                eager_w = torch.randn(2, 2, device="cuda", requires_grad=True)
+                torch.manual_seed(42)
+                compiled_x = torch.randn(shape, 2, device="cuda")
+                compiled_w = torch.randn(2, 2, device="cuda", requires_grad=True)
+
+                f(eager_x, eager_w).sum().backward()
+                compiled_f(compiled_x, compiled_w).sum().backward()
+                self.assertEqual(eager_w.grad, compiled_w.grad)
+
+        @dynamo_config.patch("capture_dynamic_output_shape_ops", True)
+        @config.patch(implicit_fallbacks=True)
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint_from_nested_indirect_indexing(self):
+            def nested(x, repeats):
+                rank = torch.arange(repeats.numel(), device=x.device)
+                index = rank.repeat_interleave(repeats, dim=0)
+                return torch.index_select(x, index=index, dim=0)
+
+            example_inputs = (
+                torch.randn((32, 64), device="cuda"),
+                repeats := torch.tensor([5, 10, 15], device="cuda"),
+            )
+            torch._dynamo.mark_dynamic(repeats, 0)  # create backed symint
+
+            nested_opt = torch.compile(nested, backend="inductor")
+
+            expect = nested(*example_inputs)
+            actual = nested_opt(*example_inputs)
+            self.assertEqual(expect, actual)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint_from_mutation_index(self):
+            x = torch.zeros(7, device="cuda")
+
+            def fn(n, a):
+                a[n] = -1
+                return a
+
+            opt_fn = torch.compile(fn, fullgraph=True)
+
+            for n in range(2, x.shape[0]):
+                opt_fn(n, x)
+                self.assertEqual(x[n], -1)
+
+            # Negative index triggers new compilation.
+            opt_fn(-x.shape[0], x)
+
+            self.assertEqual(x[0], -1)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_unbacked_symint(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            f_compiled = torch.compile(f)
+            x, y = (
+                torch.ones(3, 3, device="cuda"),
+                torch.randn(3, 3, device="cuda"),
+            )
+
+            torch._dynamo.decorators.mark_unbacked(x, 0)
+            torch._dynamo.decorators.mark_unbacked(y, 1)
+
+            compiled_out = f_compiled(x, y)
+            eager_out = f(x, y)
+            self.assertEqual(compiled_out, eager_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_dynamic_scalar_inputs(self):
+            def f(x, y, integer):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                z += integer
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            f_compiled = torch.compile(f)
+            x, y = (
+                torch.ones(3, 3, device="cuda"),
+                torch.randn(3, 3, device="cuda"),
+            )
+
+            torch._dynamo.decorators.mark_unbacked(x, 0)
+            torch._dynamo.decorators.mark_unbacked(y, 1)
+
+            compiled_out = f_compiled(x, y, 5)
+            self.assertEqual(compiled_out, f(x, y, 5))
+
+            compiled_out = f_compiled(x, y, 6)
+            self.assertEqual(compiled_out, f(x, y, 6))
+
+        @torch._inductor.config.patch("graph_partition", True)
+        @torch._dynamo.config.patch("capture_scalar_outputs", True)
+        def test_graph_partition_item(self):
+            def f(x):
+                y = x + 1
+                scalar = y.item()
+                return x + y + scalar
+
+            compiled_f = torch.compile(f)
+            compiled_out = compiled_f(torch.tensor(1, device="cuda"))
+            self.assertEqual(compiled_out, f(torch.tensor(1, device="cuda")))
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_buffer_reuse(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x1 + y1 + x @ y
+                u = (y_cpu.to("cuda") + 2) @ y + 3
+                u_cpu = u.cpu() + 2
+                return z + u_cpu.to("cuda")
+
+            x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
+            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
+            eager_out = f(x, y)
+
+            f_compiled = torch.compile(f)
+            compiled_out = f_compiled(x_cloned, y_cloned)
+
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_fused_scheduler_node(self):
+            def foo(x):
+                x = x * 20
+                x_alias = x[0]
+                y = x * 10
+                y_alias = y[0]
+                torch._dynamo.graph_break()
+                ind = torch.tensor(4, device="cuda")
+                x_alias2 = x[ind:]
+                y_alias2 = y[ind:]
+                return x, x_alias, x_alias2, y_alias, y_alias2
+
+            compiled_foo = torch.compile(foo)
+            x = torch.rand([20, 20], device="cuda")
+
+            eager_out = foo(x)
+            compiled_out = compiled_foo(x)
+            self.assertEqual(eager_out, compiled_out)
+
        def test_meta_tensor(self):
            def foobar(x, y):
                return x * 2, y * 3
@@ -31,10 +31,11 @@ class InductorAnnotationTestCase(TestCase):
        code = self.get_code()

        self.assertTrue("from torch.cuda import nvtx" in code)
-        self.assertEqual(
-            code.count("training_annotation = nvtx._device_range_start('inference')"), 1
+        self.assertTrue(
+            code.count("training_annotation = nvtx._device_range_start('inference')")
+            >= 1
        )
-        self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1)
+        self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1)


 if __name__ == "__main__":
@@ -68,9 +68,16 @@ class TestOperatorReorderForPeakMemory(TestCase):
        outp_corr = self.model(self.inputs)
        compiled_model = torch.compile(self.model)
        code = run_and_get_triton_code(compiled_model, self.inputs)
+
+        call_str = (
+            "def call(self, args):"
+            if torch._inductor.config.graph_partition
+            else "def call(args):"
+        )
+
        (
            FileCheck()
-            .check("def call(args):")
+            .check(call_str)
            .check("buf1 = ")
            .check("buf0 = ")
            .check("buf2 = ")
@@ -105,6 +112,12 @@ class TestOperatorReorderForPeakMemory(TestCase):
                methods=[memory.topological_sort_lpmf],
            )

+        call_str = (
+            "def call(self, args):"
+            if torch._inductor.config.graph_partition
+            else "def call(args):"
+        )
+
        with mock.patch.object(
            memory, "reorder_for_peak_memory", reorder_with_only_lpmf
        ):
@@ -113,7 +126,7 @@ class TestOperatorReorderForPeakMemory(TestCase):
            code = run_and_get_triton_code(compiled_model, self.inputs)
            (
                FileCheck()
-                .check("def call(args):")
+                .check(call_str)
                .check("buf1 = ")
                .check("buf0 = ")
                .check("buf2 = ")
@@ -148,15 +161,22 @@ class TestOperatorReorderForPeakMemory(TestCase):
                methods=[memory.topological_sort_bfs],
            )
+
+        call_str = (
+            "def call(self, args):"
+            if torch._inductor.config.graph_partition
+            else "def call(args):"
+        )
+
        with mock.patch.object(
            memory, "reorder_for_peak_memory", reorder_with_only_bfs
        ):
            compiled_model = torch.compile(self.model)

            code = run_and_get_triton_code(compiled_model, self.inputs)

            (
                FileCheck()
-                .check("def call(args):")
+                .check(call_str)
                .check("buf0 = ")
                .check("buf1 = ")
                .check("buf2 = ")
@@ -191,6 +211,12 @@ class TestOperatorReorderForPeakMemory(TestCase):
                methods=[memory.topological_sort_dfs],
            )

+        call_str = (
+            "def call(self, args):"
+            if torch._inductor.config.graph_partition
+            else "def call(args):"
+        )
+
        with mock.patch.object(
            memory, "reorder_for_peak_memory", reorder_with_only_dfs
        ):
@@ -199,7 +225,7 @@ class TestOperatorReorderForPeakMemory(TestCase):
            code = run_and_get_triton_code(compiled_model, self.inputs)
            (
                FileCheck()
-                .check("def call(args):")
+                .check(call_str)
                .check("buf0 = ")
                .check("buf2 = ")
                .check("buf4 = ")
@@ -15044,302 +15044,6 @@ if RUN_GPU:
                "'XBLOCK': 'constexpr'"
            ).run(code[0])

-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)]
-            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
-            eager_out = f(x, y)
-
-            f_compiled = torch.compile(f)
-            compiled_out = f_compiled(x_cloned, y_cloned)
-            self.assertEqual(eager_out, compiled_out)
-
-            _, code = run_and_get_code(f_compiled, x_cloned, y_cloned)
-
-            if not config.cpp_wrapper:
-                FileCheck().check("def partition_0(args):").check(
-                    "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)"
-                ).check("recursively_apply_fns = runner.recursively_apply_fns").run(
-                    code[0]
-                )
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_foreach_op(self):
-            def fn(a0, a1):
-                c = torch._foreach_abs([a0, a1])
-                return torch.mul(c[0], a0)
-
-            compiled_fn = torch.compile(fn)
-
-            a0 = torch.randn(2, 3, device=self.device)
-            a1 = torch.randn(2, 3, device=self.device)
-            eager_out = fn(a0, a1)
-            compiled_out = compiled_fn(a0, a1)
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_multiple_functions(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            def g(x):
-                return x + 1
-
-            x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)]
-            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
-            eager_out = g(f(x, y))
-
-            f_compiled = torch.compile(f)
-            g_compiled = torch.compile(g)
-            compiled_out = g_compiled(f_compiled(x_cloned, y_cloned))
-
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_condition_op(self):
-            def f(p, b):
-                def true_fn(x):
-                    return torch.cos(x)
-
-                def false_fn(x):
-                    return torch.sin(x)
-
-                return torch.cond(p, true_fn, false_fn, [b])
-
-            compiled_f = torch.compile(f)
-
-            # static shape
-            p = torch.tensor([True], device=self.device)
-            a = torch.ones([2, 3], device=self.device)
-            eager_out = f(p, a)
-            compiled_out = compiled_f(p, a)
-            self.assertEqual(eager_out, compiled_out)
-
-            # dynamic shape with backed symint
-            p = torch.tensor([True], device=self.device)
-            a = torch.ones([4, 5], device=self.device)
-            eager_out = f(p, a)
-            compiled_out = compiled_f(p, a)
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        @torch._dynamo.config.patch("capture_scalar_outputs", True)
-        def test_graph_partition_unbacked_symint_multi_output_layout(self):
-            def f(p, size_tensor):
-                size_val = size_tensor.item()
-                b = torch.ones([size_val, 3], device=GPU_TYPE)
-
-                def true_fn(x):
-                    return torch.cos(x), torch.cos(x) + 1
-
-                def false_fn(x):
-                    return torch.sin(x), torch.sin(x) + 1
-
-                cond_out = torch.cond(p, true_fn, false_fn, [b])
-                return cond_out[0] + cond_out[1]
-
-            compiled_f = torch.compile(f)
-            p = torch.tensor([True], device=GPU_TYPE)
-            size_tensor = torch.tensor(2, device=GPU_TYPE)
-            eager_out = f(p, size_tensor)
-            compiled_out = compiled_f(p, size_tensor)
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            f_compiled = torch.compile(f)
-            x, y = (
-                torch.ones(3, 3, device=self.device),
-                torch.randn(3, 3, device=self.device),
-            )
-            compiled_out = f_compiled(x, y)
-            self.assertEqual(compiled_out, f(x, y))
-
-            x, y = (
-                torch.ones(4, 4, device=self.device),
-                torch.randn(4, 4, device=self.device),
-            )
-            compiled_out = f_compiled(x, y)
-            self.assertEqual(compiled_out, f(x, y))
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint_cat_backward(self):
-            def f(x, w):
-                y = torch.cat((x, x), dim=0)
-                z = y @ w
-                return z @ z.T
-
-            compiled_f = torch.compile(f)
-
-            for shape in (2, 3):
-                torch.manual_seed(42)
-                eager_x = torch.randn(shape, 2, device=self.device)
-                eager_w = torch.randn(2, 2, device=self.device, requires_grad=True)
-                torch.manual_seed(42)
-                compiled_x = torch.randn(shape, 2, device=self.device)
-                compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True)
-
-                f(eager_x, eager_w).sum().backward()
-                compiled_f(compiled_x, compiled_w).sum().backward()
-                self.assertEqual(eager_w.grad, compiled_w.grad)
-
-        @dynamo_config.patch("capture_dynamic_output_shape_ops", True)
-        @config.patch(implicit_fallbacks=True)
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint_from_nested_indirect_indexing(self):
-            def nested(x, repeats):
-                rank = torch.arange(repeats.numel(), device=x.device)
-                index = rank.repeat_interleave(repeats, dim=0)
-                return torch.index_select(x, index=index, dim=0)
-
-            example_inputs = (
-                torch.randn((32, 64), device=self.device),
-                repeats := torch.tensor([5, 10, 15], device=self.device),
-            )
-            torch._dynamo.mark_dynamic(repeats, 0)  # create backed symint
-
-            nested_opt = torch.compile(nested, backend="inductor")
-
-            expect = nested(*example_inputs)
-            actual = nested_opt(*example_inputs)
-            self.assertEqual(expect, actual)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint_from_mutation_index(self):
-            x = torch.zeros(7, device=GPU_TYPE)
-
-            def fn(n, a):
-                a[n] = -1
-                return a
-
-            opt_fn = torch.compile(fn, fullgraph=True)
-
-            for n in range(2, x.shape[0]):
-                opt_fn(n, x)
-                self.assertEqual(x[n], -1)
-
-            # Negative index triggers new compilation.
-            opt_fn(-x.shape[0], x)
-
-            self.assertEqual(x[0], -1)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_unbacked_symint(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            f_compiled = torch.compile(f)
-            x, y = (
-                torch.ones(3, 3, device=self.device),
-                torch.randn(3, 3, device=self.device),
-            )
-
-            torch._dynamo.decorators.mark_unbacked(x, 0)
-            torch._dynamo.decorators.mark_unbacked(y, 1)
-
-            compiled_out = f_compiled(x, y)
-            eager_out = f(x, y)
-            self.assertEqual(compiled_out, eager_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_dynamic_scalar_inputs(self):
-            def f(x, y, integer):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                z += integer
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            f_compiled = torch.compile(f)
-            x, y = (
-                torch.ones(3, 3, device=self.device),
-                torch.randn(3, 3, device=self.device),
-            )
-
-            torch._dynamo.decorators.mark_unbacked(x, 0)
-            torch._dynamo.decorators.mark_unbacked(y, 1)
-
-            compiled_out = f_compiled(x, y, 5)
-            self.assertEqual(compiled_out, f(x, y, 5))
-
-            compiled_out = f_compiled(x, y, 6)
-            self.assertEqual(compiled_out, f(x, y, 6))
-
-        @torch._inductor.config.patch("graph_partition", True)
-        @torch._dynamo.config.patch("capture_scalar_outputs", True)
-        def test_graph_partition_item(self):
-            def f(x):
-                y = x + 1
-                scalar = y.item()
-                return x + y + scalar
-
-            compiled_f = torch.compile(f)
-            compiled_out = f(torch.tensor(1, device=GPU_TYPE))
-            self.assertEqual(compiled_out, f(torch.tensor(1, device=GPU_TYPE)))
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_buffer_reuse(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x1 + y1 + x @ y
-                u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3
-                u_cpu = u.cpu() + 2
-                return z + u_cpu.to(GPU_TYPE)
-
-            x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)]
-            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
-            eager_out = f(x, y)
-
-            f_compiled = torch.compile(f)
-            compiled_out = f_compiled(x_cloned, y_cloned)
-
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_fused_scheduler_node(self):
-            def foo(x):
-                x = x * 20
-                x_alias = x[0]
-                y = x * 10
-                y_alias = y[0]
-                torch._dynamo.graph_break()
-                ind = torch.tensor(4, device=GPU_TYPE)
-                x_alias2 = x[ind:]
-                y_alias2 = y[ind:]
-                return x, x_alias, x_alias2, y_alias, y_alias2
-
-            foo = torch.compile(foo)
-            x = torch.rand([20, 20], device=GPU_TYPE)
-            _, code = run_and_get_code(foo, x)
-
-            if not config.cpp_wrapper:
-                FileCheck().check("def partition_0(args):").run(code[0])
-
        @unittest.skipIf(TEST_WITH_ROCM or not IS_SM90, "no scaled_grouped_mm support")
        def test_respect_scaled_grouped_mm_layout_tag(self):
            # scaled_grouped_mm needs `mat2` to be column-major
@@ -50,6 +50,7 @@ from ..utils import (
    get_benchmark_name,
    IndentedBuffer,
    is_codegen_graph_partition_subgraph,
+    is_using_cudagraph_partition,
    LineContext,
    sympy_product,
    sympy_str,
@@ -1197,7 +1198,14 @@ class PythonWrapperCodegen(CodeGen):
        self.write_args(graph_input_names)

        self.codegen_inputs()
-        self.codegen_input_size_and_nan_asserts()
+
+        # avoid duplicating asserts for both partition functions and
+        # the call function when using cudagraph partition
+        if not (
+            is_using_cudagraph_partition()
+            and (not is_codegen_graph_partition_subgraph(self))
+        ):
+            self.codegen_input_size_and_nan_asserts()

    def codegen_input_size_and_nan_asserts(self) -> None:
        if config.size_asserts:
@@ -437,7 +437,11 @@ max_autotune_report_choices_stats = (
 )

 # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
-graph_partition = False
+graph_partition: bool = (
+    os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")
+    == "1"
+)
+

 # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
 # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
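The config hunk above flips graph partition on by default for OSS builds (it stays off in fbcode) via the `TORCHINDUCTOR_GRAPH_PARTITION` environment variable. A hedged sketch of how a user could opt back out, assuming the default-reading behavior shown above:

```python
# Two ways to opt out of the new default, based on the config entry above.
# (Sketch only; the environment variable is read once when the config module
# is imported, so it must be set before importing torch.)

# 1) Environment variable, set before launching Python:
#    TORCHINDUCTOR_GRAPH_PARTITION=0 python train.py

# 2) In Python, after import:
import torch

torch._inductor.config.graph_partition = False
```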
@@ -10,6 +10,8 @@ from torch._dynamo.utils import counters, get_metrics_context
 from torch._inductor.utils import GraphPartitionMap, InputType
 from torch.utils._ordered_set import OrderedSet

+from .utils import is_using_cudagraph_partition
+

 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -170,7 +172,8 @@ def check_multiple_devices_or_any_cpu_nodes(
    # meta tensors are supported since there is no compute
    device_node_mapping.pop(torch.device("meta"), None)

-    if torch._inductor.config.graph_partition:
+    # dynamo cudagraph does not support graph partition
+    if is_using_cudagraph_partition():
        # graph partition supports splitting on cpu op. So we can ignore cpu nodes.
        device_node_mapping.pop(torch.device("cpu"), None)

@@ -2179,7 +2179,10 @@ class Scheduler:
        self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
        self.process_grouped_nodes()

-        if torch._inductor.config.graph_partition:
+        if (
+            torch._inductor.config.graph_partition
+            and torch._inductor.config.triton.cudagraphs
+        ):
            self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
            self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)

@@ -4312,6 +4315,12 @@ class Scheduler:
    ) -> bool:
        """Return True if we should partition the inductor graph on this node"""

+        # When not using cudagraphs, keep all kernels in the `call` function
+        # instead of graph partition functions, since graph partition only brings
+        # benefit to cudagraph
+        if not torch._inductor.config.triton.cudagraphs:
+            return True
+
        # avoid duplicating logs when should_partition is called multiple times
        # on the same node
        def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None:
@@ -3329,6 +3329,13 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
    )


+def is_using_cudagraph_partition() -> bool:
+    return (
+        torch._inductor.config.triton.cudagraphs
+        and torch._inductor.config.graph_partition
+    )
+
+
 def dtype_from_size(size: int) -> torch.dtype:
    from .virtualized import V

|
Reference in New Issue
Block a user