Compare commits

...

16 Commits

Author SHA1 Message Date
35df547a0d Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.
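
As a rough sketch of the ownership rule (illustrative only, not code from this PR; the helper name is hypothetical): operators are assigned round-robin in the order they are encountered, each rank autotunes only the operators it owns, and the chosen configs are then shared with an all_gather.

    def owns_autotune(op_index: int, rank: int, world_size: int) -> bool:
        # Rank r tunes every world_size-th operator starting at r.
        return op_index % world_size == rank

    # 8 ranks, 8 matmuls: each rank tunes exactly one operator instead of all eight.
    owners = {op: op % 8 for op in range(8)}  # {operator_index: owning_rank}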

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Results:

In testing with llama3 8B on torchtitan, max-autotune time was reduced from 52s to 26s and exhaustive autotuning from 2009s to 613s.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.
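
A minimal way to enable it (a sketch based on the config flag added in this PR's diff; the PR's tests enable the regular max-autotune GEMM path alongside it):

    import os
    # The flag is read from the environment when torch._inductor.config is first
    # imported, so set it before Inductor is initialized.
    os.environ["TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM"] = "1"

    # Equivalent programmatic switches:
    import torch._inductor.config as inductor_config
    inductor_config.max_autotune_gemm = True
    inductor_config.distributed_max_autotune_gemm = True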

Co-authored-by: PaulZhang12 




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-27 14:24:40 -07:00
37fb351148 Update base for Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Results:

In testing with llama3 8B on torchtitan, max-autotune time was reduced from 52s to 26s and exhaustive autotuning from 2009s to 613s.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: PaulZhang12 




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-27 14:24:40 -07:00
7c963a2ec7 Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Results:

In testing with llama3 8B on torchtitan, max-autotune time was reduced from 52s to 26s and exhaustive autotuning from 2009s to 613s.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-27 10:24:33 -07:00
54b84f19b1 Update base for Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Results:

In testing with llama3 8B on torchtitan, max-autotune time was reduced from 52s to 26s and exhaustive autotuning from 2009s to 613s.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-27 10:24:32 -07:00
b7fc8dff8f Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Results:

In testing with llama3 8B on torchtitan, max-autotune time was reduced from 52s to 26s and exhaustive autotuning from 2009s to 613s.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-17 16:22:14 -07:00
f7a124f7df Update base for Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Results:

In testing with llama3 8B on torchtitan, max-autotune time was reduced from 52s to 26s and exhaustive autotuning from 2009s to 613s.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-17 16:22:13 -07:00
0b2257018b Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-16 10:51:55 -07:00
64b0463029 Update base for Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-16 10:51:54 -07:00
0c942a6209 Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-15 16:31:25 -07:00
92c525c408 Update base for Update on "Distributed Autotuning"
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.

Currently when we run an SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank, so for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.

Distributed autotuning uses collectives to distribute the autotuning across the ranks so that each rank autotunes 1/world_size of the total operators. In our 8-GPU example we would perform only 8 autotunes in total (one on each rank) rather than 64.

There are several advantages:
1. Faster autotuning times - each CPU/GPU does less total work
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.

Usage:

The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM.

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-15 16:31:25 -07:00
642991a107 Update on "distributed autotuning"
**Posted for internal discussion - not ready for general review yet.**

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-15 07:53:20 -07:00
d065d05826 Update base for Update on "distributed autotuning"
**Posted for internal discussion - not ready for general review yet.**

Co-authored-by: Paul Zhang <paulzhan@umich.edu>




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-15 07:53:19 -07:00
01087d7641 Update on "distributed autotuning"
**Posted for internal discussion - not ready for general review yet.**




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-21 21:12:36 -07:00
5cbb487556 Update base for Update on "distributed autotuning"
**Posted for internal discussion - not ready for general review yet.**




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-09-21 21:12:35 -07:00
ba7a3df09c distributed autotuning
[ghstack-poisoned]
2025-09-19 13:34:29 -07:00
8a12a07f3d refactor: move replace_operation_buffer to global scope
[ghstack-poisoned]
2025-09-19 13:34:22 -07:00
7 changed files with 633 additions and 68 deletions

View File

@@ -2,6 +2,7 @@
import contextlib
import copy
import functools
import logging
import random
import unittest
from contextlib import contextmanager
@@ -46,11 +47,18 @@ from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda_and_triton
log = logging.getLogger(__name__)
def reset_rng_state():
torch.manual_seed(1337)
random.seed(1337)
@@ -565,6 +573,7 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
# single process version; if it's just a problem in the Dynamo distributed
# # optimizer, you should be able to repro it single process!
@requires_accelerator_dist_backend(["nccl", "xccl"])
@instantiate_parametrized_tests
class TestMultiProc(DynamoDistributedMultiProcTestCase):
"""
Note: MultiProcTestCase spawns processes per test and is slow.
@@ -1200,6 +1209,124 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
@patch.object(torch._inductor.config, "max_autotune_gemm", True)
@patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
@parametrize("backend", ("TRITON", "ATEN"))
def test_multiproc_autotune(self, backend):
with patch.object(
torch._inductor.config, "max_autotune_gemm_backends", backend
):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(a, b, c):
res = (
torch.sum((a @ b) + 1.0)
+ torch.sum(torch.relu(b @ c))
+ torch.sum(c @ a)
)
return res
a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)
try:
f(a, b, c)
except Exception:
log.exception("Caught exception running f")
raise
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
print(f"Result from {self.rank} is {f(a, b, c)}")
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
@patch.object(torch._inductor.config, "max_autotune_gemm", True)
@patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
@parametrize("backend", ("TRITON", "ATEN"))
def test_multiproc_autotune_dynamic_shapes(self, backend):
with patch.object(
torch._inductor.config, "max_autotune_gemm_backends", backend
):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(a, b, c):
res = (
torch.sum((a @ b) + 1.0)
+ torch.sum(torch.relu(b @ c))
+ torch.sum(c @ a)
)
return res
a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)
# Mark tensors as dynamic on both dimensions
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(a, 1)
torch._dynamo.mark_dynamic(b, 0)
torch._dynamo.mark_dynamic(b, 1)
torch._dynamo.mark_dynamic(c, 0)
torch._dynamo.mark_dynamic(c, 1)
try:
f(a, b, c)
except Exception:
log.exception("Caught exception running f")
raise
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
print(f"Result from {self.rank} is {f(a, b, c)}")
# Store the initial compilation count
initial_compile_count = len(metrics)
# # Test with different sizes to ensure dynamic shapes work without recompilation
a2 = torch.randn(512, 512, device=self.rank, dtype=torch.bfloat16)
b2 = torch.randn(512, 2048, device=self.rank, dtype=torch.bfloat16)
c2 = torch.randn(2048, 512, device=self.rank, dtype=torch.bfloat16)
try:
result2 = f(a2, b2, c2)
print(f"Result2 from {self.rank} is {result2}")
except Exception:
log.exception("Caught exception running f with different sizes")
raise
# Verify no recompilation occurred
metrics_after = torch._dynamo.utils.get_compilation_metrics()
final_compile_count = len(metrics_after)
self.assertEqual(
initial_compile_count,
final_compile_count,
"Expected no recompilation with dynamic shapes",
)
# Verify all ranks have the same compilation count
res_after = [None] * self.world_size
torch.distributed.all_gather_object(res_after, final_compile_count)
for r in res_after[1:]:
self.assertEqual(res_after[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_get_pg_attr(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):

View File

@@ -104,7 +104,7 @@ from .._dynamo.exc import ShortenTraceback, SkipFrame
from ..fx._lazy_graph_module import _use_lazy_graph_module
from ..fx.graph import _PyTreeCodeGen
from ..utils._triton import has_triton
from . import config, metrics
from . import config, distributed_autotune, metrics
from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration
from .debug import DebugContext
from .decomposition import select_decomp_table
@@ -1431,7 +1431,11 @@ class _InProcessFxCompile(FxCompile):
# We are going to start code generating runtime asserts, so make sure
# you don't start adding new ones in the lowering process
graph.freeze_runtime_asserts()
with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
with (
V.set_graph_handler(graph),
V.set_extern_kernel_nodes([]),
distributed_autotune.graph_context(),
):
graph.run(*example_inputs)
output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
if graph.graph_outputs is not None:

View File

@@ -445,6 +445,14 @@ use_experimental_benchmarker: bool = Config(
justknob="pytorch/inductor:use_experimental_benchmarker",
)
# Enable distributed autotuning. When this is enabled we will distribute the
# autotuning across distributed ranks in the same program group - so instead of
# each rank autotuning every kernel they only autotune 1/world size kernels and
# then share the results.
distributed_max_autotune_gemm = (
os.environ.get("TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM") == "1"
)
# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

View File

@@ -0,0 +1,386 @@
from __future__ import annotations
import contextlib
import dataclasses
from typing import Any, TYPE_CHECKING, Union
from unittest.mock import patch
import sympy
import torch._logging
import torch.distributed as dist
import torch.fx
from torch.utils._ordered_set import OrderedSet
from . import config, select_algorithm
from .ir import (
Buffer,
ChoiceCaller,
Layout,
MultiTemplateBuffer,
OperationBuffer,
ShapeAsConstantBuffer,
StorageBox,
TensorBox,
)
from .kernel_inputs import KernelInputs, MMKernelInputs
from .scheduler import SchedulerNode
from .virtualized import NullHandler, V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
_DISTRIBUTED_AUTOTUNE_KEY = "distributed_autotune"
_AUTOTUNE_PG: dist.ProcessGroup | None = None
@dataclasses.dataclass
class _DistributedAutotuneState:
"""
State used to track autotuning during a graph_context()
"""
# This is the next operator index. Used to figure out which rank should do
# the autotuning.
autotuned_index: int = 0
# For debugging - used to make sure that we autotune the same number of
# local operators that we expected to.
autotuned_local_count: int = 0
@dataclasses.dataclass
class _DistributedAutotuneInfo:
index: int
local: bool
def get_autotune_pg() -> dist.ProcessGroup | None:
if dist.is_available() and dist.is_initialized():
global _AUTOTUNE_PG
if _AUTOTUNE_PG is None:
_AUTOTUNE_PG = dist.distributed_c10d._new_group_with_tag(
pg_tag="pt2_distributed_autotune_pg"
)
return _AUTOTUNE_PG
return None
def schedule(scheduler: torch._inductor.scheduler.Scheduler) -> None:
"""
Finish the distributed autotuning by propagating the autotuning results
between the ranks and then replacing the placeholder with the real Buffer.
"""
assert config.distributed_max_autotune_gemm
autotune_results = _autotune_local_nodes(scheduler)
choices_by_index = _sync(autotune_results)
_autotune_remote_nodes(scheduler, choices_by_index)
@contextlib.contextmanager
def graph_context() -> Generator[None, None, None]:
"""
Wrapped around processing a graph, sets up figuring out which ranks tune
which shapes.
"""
assert not isinstance(
V.get_distributed_autotune_state(check_poisoned=False), # type: ignore[call-arg]
_DistributedAutotuneState,
)
V.set_distributed_autotune_state(_DistributedAutotuneState())
try:
yield
finally:
V.set_distributed_autotune_state(NullHandler())
def maybe_autotune_remote(
name: str, choices: list[ChoiceCaller], inputs: list[Buffer], layout: Layout
) -> TensorBox | ShapeAsConstantBuffer | None:
"""
Used by an op (like `mm`) to determine if the op should be autotuned
locally (returns None) or remotely (returns a placeholder Buffer).
"""
if not config.distributed_max_autotune_gemm:
return None
if not (autotune_pg := get_autotune_pg()):
return None
if len(choices) <= 1:
return None
state = V.distributed_autotune_state
index = state.autotuned_index
state.autotuned_index += 1
local = index % autotune_pg.size() == autotune_pg.rank()
V.current_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] = _DistributedAutotuneInfo(
index, local
)
if local:
state.autotuned_local_count += 1
return None
return torch._inductor.ir.TensorBox.create(
_DistributedAutotuneBuffer(name, inputs, layout)
)
class _DistributedAutotuneBuffer(MultiTemplateBuffer):
"""
A MultiTemplateBuffer which represents a kernel being autotuned on a
different rank. When `schedule` is called this will be replaced by the
"real" buffer.
"""
# Name of the kernel being autotuned.
_kernel_name: str
def __init__(
self,
kernel_name: str,
inputs: list[Buffer],
layout: Layout,
) -> None:
super().__init__(
layout,
inputs,
choice_timings_fn=self._dummy_choice_timings,
unfiltered_choices=[],
allowed_prologue_inps=OrderedSet({}),
)
self._kernel_name = kernel_name
def _dummy_choice_timings(
self, _hint_override: int | None
) -> dict[ChoiceCaller, float]:
# This should never get called. It means that a remote autotune was
# scheduled but never filled in.
raise NotImplementedError
def autotune(self, ser_choice: _SerializedChoice) -> TensorBox:
"""
Given a _SerializedChoice (autotune results from another rank)
compute the final TensorBox.
"""
from .select_algorithm import autotune_select_algorithm
with patch.object(V.graph, "scheduler", None):
kernel_inputs = MMKernelInputs([*self.original_inputs])
assert isinstance(self.layout, Layout)
choice = ser_choice.get_choice(self.layout, kernel_inputs)
buffer = autotune_select_algorithm(
self._kernel_name,
[choice],
kernel_inputs.nodes(),
self.layout,
)
assert isinstance(buffer, TensorBox)
return buffer
# Can we make this async?
def _sync(autotune_results: list[_SerializedChoice]) -> Sequence[_SerializedChoice]:
"""
Perform the all_gather to collect the autotune results from all the ranks.
"""
autotune_pg = get_autotune_pg()
assert autotune_pg
# Perform allgather
all_states: list[list[_SerializedChoice]] = [None] * autotune_pg.size() # type: ignore[list-item]
torch.distributed.all_gather_object(all_states, autotune_results, group=autotune_pg)
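# Illustrative example (hypothetical sizes): with 2 ranks and 4 tuned operators,
# rank 0 contributes the choices for operator indices 0 and 2 and rank 1 for
# indices 1 and 3, so all_states == [[c0, c2], [c1, c3]]; the loop below
# flattens that into choices_by_index == [c0, c1, c2, c3], ordered by index.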
node_count = sum(len(x) for x in all_states)
# It's faster to briefly lie about the type than to unzip the results and append.
choices_by_index: list[_SerializedChoice] = [None] * node_count # type: ignore[list-item]
check_count = 0
for i, other_results in enumerate(all_states):
for choice in other_results:
assert isinstance(choice, _SerializedChoice)
assert choices_by_index[choice.index] is None
choices_by_index[choice.index] = choice
check_count += 1
assert node_count == check_count, f"count mismatch: {node_count} != {check_count}"
return choices_by_index
class _SerializedChoice:
"""
This is a serializer for the autotune choice. KernelTemplateChoice can't
be serialized directly (the template and inputs prevent this) so we need to
serialize it by parts and reconstruct later on.
"""
def __init__(self, index: int, choice: ChoiceCaller) -> None:
self.index = index
self.template_uid = _SerializedChoice._template_uid_from_choice(choice)
self.kwargs = self._compute_kwargs(choice.description)
def get_choice(self, layout: Layout, inputs: KernelInputs) -> ChoiceCaller | None:
"""
Deserialize the ChoiceCaller and return it.
"""
template = self._template_from_uid()
kwargs = {**self.kwargs}
if "BLOCK_K" in kwargs:
# TODO: Do we really need to externally compute this value? If it's
# needed I'm surprised it's not just part of the original template
# description.
# This needs the actual 'k' to figure out the value.
k = inputs.nodes()[0].get_size()[1]
kwargs["EVEN_K"] = sympy.gcd(k, kwargs["BLOCK_K"]) == kwargs["BLOCK_K"]
extra_kwargs: dict[str, Any] = {}
from .kernel_template_choice import (
DictKernelTemplateParams,
KernelTemplateChoice,
)
params = DictKernelTemplateParams(kwargs)
ktc = KernelTemplateChoice(template, params, extra_kwargs, layout, inputs)
return ktc.choice
@staticmethod
def _compute_kwargs(description: str) -> dict[str, Union[int, str, bool]]:
"""
Given a template description turn it into input kwargs.
"""
if not description:
return {}
# TODO: It seems like it would be better if the template could provide
# this directly instead of having to parse a string.
kwargs: dict[str, Union[int, str, bool]] = {}
for cfg in description.split(","):
key, val = cfg.split("=", 1)
key, val = key.strip(), val.strip()
if val == "True":
kwargs[key] = True
elif val == "False":
kwargs[key] = False
elif val.isdigit():
kwargs[key] = int(val)
else:
assert val.startswith("'") and val.endswith("'")
kwargs[key] = val[1:-1]
return kwargs
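# Illustrative example (hypothetical values): a description string such as
#   "BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, EVEN_K=True, ACC_TYPE='tl.float32'"
# parses to
#   {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "EVEN_K": True, "ACC_TYPE": "tl.float32"}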
@staticmethod
def _template_uid_from_choice(choice: ChoiceCaller) -> str:
"""
Given a ChoiceCaller figure out which template represents it. This
is reversed by _template_from_uid().
"""
# We need a better way to do this - right now we need to add each
# supported template directly.
if isinstance(choice, select_algorithm.ExternKernelCaller):
if choice.choice.name == "mm":
return "torch._inductor.kernel.mm.aten_mm"
else:
raise RuntimeError(f"TODO: kernel {choice.choice.name!r}")
elif isinstance(choice, select_algorithm.TritonTemplateCaller):
return "torch._inductor.kernel.mm.mm_template"
else:
raise RuntimeError(f"TODO: {type(choice)}")
def _template_from_uid(self) -> Any:
"""
See _template_uid_from_choice().
"""
parts = self.template_uid.split(".")
obj = globals()[parts[0]]
for k in parts[1:]:
obj = getattr(obj, k)
return obj
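# For example, "torch._inductor.kernel.mm.mm_template" resolves back to the mm
# Triton template object via globals()["torch"] followed by successive getattr calls.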
def _autotune_local_nodes(
scheduler: torch._inductor.scheduler.Scheduler,
) -> list[_SerializedChoice]:
"""
Go through the nodes in the scheduler and autotune the kernels which
should be autotuned by this rank.
"""
autotune_results: list[_SerializedChoice] = []
for node in scheduler.nodes:
if not isinstance(node, SchedulerNode):
continue
if (inner_node := node.node) is None:
continue
if isinstance(inner_node, _DistributedAutotuneBuffer):
# This is marked for remote autotuning.
continue
if not isinstance(inner_node, MultiTemplateBuffer):
continue
if (origin_node := inner_node.origin_node) is None:
continue
if (meta := origin_node.meta) is None:
continue
info = meta.get(_DISTRIBUTED_AUTOTUNE_KEY)
if info is None:
continue
assert info.local
# We force autotuning here
# Still takes advantage of async precompile
# We need all the configs before fusion
min_choice, _ = inner_node.get_min_choice()
choice = _SerializedChoice(info.index, min_choice)
autotune_results.append(choice)
state = V.distributed_autotune_state
assert len(autotune_results) == state.autotuned_local_count, (
f"incorrect local autotuned nodes found ({len(autotune_results)} != {state.autotuned_local_count})"
)
return autotune_results
def _autotune_remote_nodes(
scheduler: torch._inductor.scheduler.Scheduler,
choices_by_index: Sequence[_SerializedChoice],
) -> None:
"""
Go through the nodes in the scheduler and autotune the nodes that were
autotuned on remote ranks.
"""
for i, node in enumerate(scheduler.nodes):
if isinstance(node, SchedulerNode) and isinstance(
(dist_node := node.node), _DistributedAutotuneBuffer
):
assert dist_node.origin_node is not None
info = dist_node.origin_node.meta[_DISTRIBUTED_AUTOTUNE_KEY]
out_tensorbox = dist_node.autotune(choices_by_index[info.index])
out_storage = out_tensorbox.data
assert isinstance(out_storage, StorageBox)
out_buffer = out_storage.data
assert isinstance(out_buffer, OperationBuffer)
assert out_buffer.layout == dist_node.layout
scheduler._replace_node(out_buffer, dist_node, i, node)

View File

@@ -19,7 +19,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.functional import ScalingType # type: ignore[attr-defined]
from torch.torch_version import TorchVersion
from .. import config as inductor_config
from .. import config as inductor_config, distributed_autotune
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
@@ -1096,6 +1096,11 @@ def tuned_mm(mat1, mat2, out_dtype=None, *, layout=None):
# The future will be awaited at scheduling time in select_algorithm.py
best_config_future = gen_best_config(mat1, mat2)
if box := distributed_autotune.maybe_autotune_remote(
name, choices, kernel_inputs.nodes(), layout
):
return box
return autotune_select_algorithm(
name,
choices,

View File

@@ -230,9 +230,9 @@ class SchedulerDonatedBuffer(SchedulerBuffer):
class BaseSchedulerNode:
ancestors: OrderedSet[str]
group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
last_usage: OrderedSet[str]
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
@@ -241,7 +241,14 @@
min_order: int
max_order: int
mpi_node: MemoryPlanningInfoForNode
mutation_renames: dict[str, str]
node: Optional[ir.Operation] = None
outputs: list[SchedulerBuffer]
outputs_by_name: dict[str, SchedulerBuffer]
override_estimated_runtime: Optional[float] = None
read_writes: dependencies.ReadWrites
unmet_dependencies: OrderedSet[Dep]
written: bool = False
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
@@ -250,13 +257,13 @@
)
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
self.ancestors: OrderedSet[str] = OrderedSet()
self.node = node
self.ancestors = OrderedSet()
self.last_usage = OrderedSet[
str
]() # buffers that won't be used after this kernel
self.written = False
self.outputs: list[SchedulerBuffer] = [
self.outputs = [
SchedulerBuffer(
scheduler=self.scheduler,
node=output,
@@ -264,16 +271,14 @@
)
for output in node.get_outputs()
]
self.outputs_by_name: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for buf in self.outputs
}
self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
# mutation_renames for the current node. Due to potential
# more mutations happening later, this can be different
# to Scheduler.mutation_renames. Also this dict should be small
# since only mutation information relevant to the deps for this
# node is stored here.
self.mutation_renames: dict[str, str] = {}
self.mutation_renames = {}
def __repr__(self) -> str:
return f"{type(self).__name__}(name={self.get_name()!r})"
@@ -2218,6 +2223,34 @@
return order
def _replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
@dataclasses.dataclass
class NodeUser:
node: Union[BaseSchedulerNode, OutputNode]
@@ -2347,6 +2380,12 @@
if config._pre_fusion_custom_pass is not None:
self.nodes = config._pre_fusion_custom_pass(self.nodes)
if config.distributed_max_autotune_gemm:
from . import distributed_autotune
distributed_autotune.schedule(self)
self.compute_ancestors()
self.nodes = self.fuse_nodes(self.nodes)
if config._post_fusion_custom_pass is not None:
self.nodes = config._post_fusion_custom_pass(self.nodes)
@@ -3119,33 +3158,6 @@
will force completion of compilation and benchmarking.
"""
def replace_operation_buffer(
orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
orig_buf_name = orig_node.get_name()
assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)
replaced_op_name = new_node.get_operation_name()
orig_op_name = orig_node.get_operation_name()
assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)
del V.graph.name_to_buffer[replaced_buf_name]
new_node.name = orig_buf_name
del V.graph.name_to_op[replaced_op_name]
new_node.operation_name = orig_op_name
orig = V.graph.buffers.index(orig_node)
V.graph.buffers.remove(new_node)
V.graph.buffers[orig] = new_node
V.graph.name_to_buffer[orig_buf_name] = new_node
orig = V.graph.operations.index(orig_node)
V.graph.operations.remove(new_node)
V.graph.operations[orig] = new_node
V.graph.name_to_op[orig_op_name] = new_node
for i, node in enumerate(self.nodes):
if isinstance(node, SchedulerNode) and isinstance(
node.node, ir.MultiTemplateBuffer
@@ -3195,40 +3207,48 @@
assert isinstance(out_buffer, ir.OperationBuffer)
out_buffer.layout = multi_node.layout
replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
self._replace_node(out_buffer, multi_node, i, node)
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def _replace_node(
self,
out_buffer: ir.OperationBuffer,
multi_node: ir.MultiTemplateBuffer,
i: int,
node: SchedulerNode,
) -> None:
_replace_operation_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(
node.read_writes.reads, node.unmet_dependencies
):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
# We need to reflect the mutation renames that were recorded in the original node
mutation_renames = {}
for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
if real_name := self.mutation_real_name.get(dep.name, None):
mutation_renames[real_name] = dep.name
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]:
return OrderedSet(dep.rename(mutation_renames) for dep in deps)
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.unmet_dependencies = rename_deps(
new_scheduler_node.unmet_dependencies
)
new_scheduler_node.read_writes.reads = rename_deps(
new_scheduler_node.read_writes.reads
)
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
for new_out, old_out in zip(
new_scheduler_node.get_outputs(), node.get_outputs()
):
self.name_to_buf[old_out.get_name()] = new_out
new_out.users = old_out.users
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.ancestors = node.ancestors
new_scheduler_node.last_usage = node.last_usage
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
return any(

View File

@@ -84,6 +84,8 @@ if TYPE_CHECKING:
from torch._inductor.loop_body import InterpreterShim
from torch._subclasses import FakeTensorMode
from .distributed_autotune import _DistributedAutotuneState
threadlocal = local()
T = TypeVar("T")
@@ -199,6 +201,9 @@ _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHand
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
"local_buffer_context", NullHandler
)
_distributed_autotune_state: Virtualized[_DistributedAutotuneState] = Virtualized(
"distributed_autotune_state", NullHandler
)
def _choices_default():
@@ -364,6 +369,12 @@ class _V:
set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
set_choices_handler: Callable[[Any], Any] = _choices._set_handler
set_distributed_autotune_state: Callable[[Any], Any] = (
_distributed_autotune_state._set_handler
)
get_distributed_autotune_state: Callable[[], Any] = (
_distributed_autotune_state._get_handler
)
@property
def ops(self) -> OpsHandler[Any]:
@@ -423,5 +434,9 @@
def choices(self) -> InductorChoices:
return _choices._get_handler()
@property
def distributed_autotune_state(self):
return _distributed_autotune_state._get_handler()
V = _V()