Compare commits

...

4 Commits

4 changed files with 357 additions and 24 deletions

View File

@@ -0,0 +1,110 @@
# Owner(s): ["module: inductor"]
"""
Test collective op autotuning - Phase 1: basic functionality with 2 ranks.

This test validates that:
1. Collective ops are detected correctly
2. The collective benchmarking path is used for collective ops
3. 2 ranks can sync and benchmark successfully
4. Results are correct
"""
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests


class TestCollectiveAutotuning(MultiProcessTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    @skip_if_lt_x_gpu(2)
    def test_single_allreduce_2ranks(self):
        """Test a single all_reduce with 2 ranks."""
        # Initialize distributed
        dist.init_process_group(
            backend="nccl",
            init_method=f"file:///tmp/test_collective_autotune_{self.id()}",
            world_size=self.world_size,
            rank=self.rank,
        )
        rank = dist.get_rank()
        device = f"cuda:{rank}"

        # Register the default process group under the name "default"
        from torch._C._distributed_c10d import _register_process_group

        _register_process_group("default", dist.group.WORLD)

        # Define custom collective op
        @torch.library.custom_op("test::my_allreduce", mutates_args=())
        def my_allreduce(x: torch.Tensor) -> torch.Tensor:
            result = x.clone()
            return torch.ops._c10d_functional.all_reduce_(result, "sum", "default")

        # Fake implementation for the abstract/meta kernel
        @my_allreduce.register_fake
        def _(x):
            return torch.empty_like(x)

        # Implementation 1: direct NCCL
        def allreduce_nccl(x):
            result = x.clone()
            return torch.ops._c10d_functional.all_reduce_(result, "sum", "default")

        # Implementation 2: simulated chunked variant (for testing multiple choices)
        def allreduce_chunked(x, chunk_size=1024):
            # For now, just call the regular allreduce
            result = x.clone()
            return torch.ops._c10d_functional.all_reduce_(result, "sum", "default")

        # Register autotuning
        from torch._inductor.kernel.custom_op import (
            CustomOpConfig,
            register_custom_op_autotuning,
        )

        register_custom_op_autotuning(
            my_allreduce,
            configs=[
                CustomOpConfig(allreduce_nccl),
                CustomOpConfig(allreduce_chunked, chunk_size=1024),
            ],
        )

        # Test model
        class SimpleModel(torch.nn.Module):
            def forward(self, x):
                return my_allreduce(x)

        model = torch.compile(SimpleModel()).to(device)

        # Run. Seed every rank identically so the all_reduce sum below is
        # exactly world_size * x and the closed-form check holds.
        torch.manual_seed(1234)
        x = torch.randn(128, 128, device=device)
        x_copy = x.clone()
        y = model(x)

        # Verify: sum across 2 ranks
        expected = x_copy * 2
        torch.testing.assert_close(y, expected, rtol=1e-3, atol=1e-3)

        if rank == 0:
            print("Single allreduce test passed!")

        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
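The two CustomOpConfig entries give the autotuner two candidate implementations to benchmark. As a sketch only (not part of the diff), further chunk sizes could be registered the same way, assuming CustomOpConfig keeps forwarding keyword arguments such as chunk_size to allreduce_chunked:

register_custom_op_autotuning(
    my_allreduce,
    configs=[
        CustomOpConfig(allreduce_nccl),
        CustomOpConfig(allreduce_chunked, chunk_size=1024),
        CustomOpConfig(allreduce_chunked, chunk_size=4096),  # assumed extra choice
    ],
)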

View File

@@ -81,6 +81,10 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
import torch._inductor.config as inductor_config
from torch._inductor.graph import GraphLowering
# Sanitize name to be a valid Python identifier
# Replace :: and other invalid characters with _
safe_name = self.name.replace("::", "_").replace(".", "_")
bm_graph_lowering = GraphLowering(
gm=self.gm,
example_inputs=self.example_inputs,
@@ -90,7 +94,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
extern_node_serializer=V.graph.extern_node_serializer,
is_inference=V.graph.is_inference,
is_backward=V.graph.is_backward,
name=f"benchmark_{self.name}",
name=f"benchmark_{safe_name}",
)
for sym_inp in self.sym_inputs:

View File

@@ -321,6 +321,43 @@ def autotune_custom_op(
)
input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns)
# Detect collective operations for specialized benchmarking
is_collective = False
process_group = None
if op_overload:
from torch._inductor.select_algorithm import is_collective_op
op_name = str(op_overload)
is_collective = is_collective_op(op_name)
if is_collective:
# Extract process_group from non_tensor_args
for kwargs_dict in non_tensor_args:
if "group" in kwargs_dict:
process_group = kwargs_dict["group"]
break
elif "process_group" in kwargs_dict:
process_group = kwargs_dict["process_group"]
break
# Log collective op detection for debugging
import torch.distributed as dist
if dist.is_initialized():
rank = dist.get_rank(process_group)
log.debug(
"Detected collective op on rank %d: %s (process_group=%s)",
rank,
op_name,
"default" if process_group is None else "custom",
)
else:
log.debug(
"Detected collective op: %s (distributed not initialized)",
op_name,
)
# Run autotuning and get both result and winning choice
selected_result, winning_choice = autotune_select_algorithm(
name=name,
@@ -329,6 +366,8 @@ def autotune_custom_op(
layout=choices[0].layout,
input_gen_fns=input_gen_fns,
return_choice=True,
is_collective=is_collective,
process_group=process_group,
)
# Apply inlining for fusion if winning_choice has a graph; otherwise return the result as-is (default fallback impl)
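The collective-detection block added in this hunk can be sketched standalone as follows; the op name and non_tensor_args values are made-up stand-ins, and is_collective_op is the helper this change adds to select_algorithm.py.

from torch._inductor.select_algorithm import is_collective_op

op_name = "torch.ops._c10d_functional.all_reduce_.default"    # example op name
non_tensor_args = [{"reduce_op": "sum", "group": "default"}]  # stand-in kwargs dicts

process_group = None
is_collective = is_collective_op(op_name)
if is_collective:
    for kwargs_dict in non_tensor_args:
        if "group" in kwargs_dict:
            process_group = kwargs_dict["group"]
            break
        elif "process_group" in kwargs_dict:
            process_group = kwargs_dict["process_group"]
            break
# Here is_collective is True and process_group is "default"; the real code
# additionally logs the rank when torch.distributed is initialized.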

View File

@@ -27,7 +27,6 @@ import sympy
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import (
counters,
@@ -82,7 +81,6 @@ from .utils import (
do_bench_using_profiling,
FakeIndentedBuffer,
get_dtype_size,
is_gpu,
Placeholder,
restore_stdout_stderr,
sympy_dot,
@@ -102,6 +100,33 @@ VERIFY: dict[str, Any] = {}
PRINT_AUTOTUNE = True
DEBUG = False
# Collective operation names for specialized benchmarking
COLLECTIVE_OPS = OrderedSet(
[
"torch.ops._c10d_functional.all_reduce.default",
"torch.ops._c10d_functional.all_reduce_.default",
"torch.ops._c10d_functional.all_gather_into_tensor.default",
"torch.ops._c10d_functional.reduce_scatter_tensor.default",
"torch.ops._c10d_functional.all_to_all_single.default",
"torch.ops._c10d_functional_autograd.all_reduce.default",
"torch.ops._c10d_functional_autograd.all_gather_into_tensor.default",
"torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default",
"torch.ops._c10d_functional_autograd.all_to_all_single.default",
]
)
def is_collective_op(op_name: str) -> bool:
"""Check if an operation is a collective operation.
Args:
op_name: Name of the operation to check
Returns:
True if the operation is a collective op, False otherwise
"""
return op_name in COLLECTIVE_OPS
if TYPE_CHECKING:
import concurrent
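For example, the is_collective_op helper added above is a plain membership test against COLLECTIVE_OPS:

is_collective_op("torch.ops._c10d_functional.all_reduce_.default")              # True
is_collective_op("torch.ops._c10d_functional_autograd.all_to_all_single.default")  # True
is_collective_op("torch.ops.aten.mm.default")                                   # False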
@@ -2704,6 +2729,8 @@ class AlgorithmSelectorCache(PersistentCache):
return_multi_template=False,
best_config_future=None,
return_choice=False, # TODO: return_choice is temporary and will be refactored soon
is_collective=False, # Flag for collective operations
process_group=None, # Process group for collective ops
):
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
@@ -2763,7 +2790,13 @@ class AlgorithmSelectorCache(PersistentCache):
# TODO(nmacchioni): remove this layer of abstraction
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
benchmark_fn = self.make_benchmark_fn(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
choices,
input_nodes,
layout,
input_gen_fns,
hint_override=hint_override,
is_collective=is_collective,
process_group=process_group,
)
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
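As that comment says, benchmark_fn maps each choice to its runtime in milliseconds; conceptually the selection then reduces to taking the minimum (illustrative sketch only, not the literal code in this file):

timings = benchmark_fn(choices)              # {ChoiceCaller: runtime in ms}
best_choice = min(timings, key=timings.get)  # fastest choice wins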
@@ -3297,29 +3330,148 @@ class AlgorithmSelectorCache(PersistentCache):
benchmark_tensors = autotune_args.get_benchmark_tensors(cls._is_extern(choice))
inputs, output = benchmark_tensors.unpack()
output.zero_()
result = choice.benchmark(*inputs, out=output)
device_type = next(
(tensor.device.type for tensor in inputs if is_gpu(tensor.device.type)),
"cuda",
)
device_interface = get_interface_for_device(device_type)
if device_interface.is_available():
device_interface.synchronize() # shake out any CUDA errors
if VERIFY and autotune_args.expected is not None:
autotune_args.verify(**VERIFY)
return result
timing = choice.benchmark(*inputs, out=output)
return timing
@classmethod
def benchmark_collective_choice(
cls,
choice: ChoiceCaller,
autotune_args: AutotuneArgs,
process_group,
timeout_seconds: float = 30.0,
) -> float:
"""
Benchmark a choice for collective operations with cross-rank synchronization.
This method ensures all ranks synchronize before and during benchmarking
to get accurate timing measurements for distributed collective operations.
Uses barrier synchronization and collects max time across all ranks.
Args:
choice: The choice to benchmark
autotune_args: Autotuning arguments containing input/output tensors
process_group: Process group for collective synchronization
timeout_seconds: Timeout for benchmarking (unused in current impl)
Returns:
Benchmark time in milliseconds (averaged over runs, then max-reduced across ranks), in the same units as the regular benchmarking path
"""
import torch.distributed as dist
rank = dist.get_rank(process_group)
# Get benchmark tensors
benchmark_tensors = autotune_args.get_benchmark_tensors(cls._is_extern(choice))
inputs, output = benchmark_tensors.unpack()
output.zero_()
# For SubgraphChoiceCaller, use barrier-synchronized benchmarking
if hasattr(choice, "gm") and choice.gm is not None:
# Warmup with sync
dist.barrier(group=process_group)
torch.cuda.synchronize()
# Benchmark with multiple runs (using a reasonable default)
nruns = 10
total_time = 0.0
for _ in range(nruns):
# Critical: barrier ensures all ranks start simultaneously
dist.barrier(group=process_group)
torch.cuda.synchronize()
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
start_evt.record()
choice.benchmark(*inputs, out=output)
end_evt.record()
end_evt.synchronize()
total_time += start_evt.elapsed_time(end_evt)
# Average time in milliseconds (same units as the regular benchmark_choice path)
avg_time = total_time / nruns
# All-reduce to get max time across ranks (conservative estimate)
time_tensor = torch.tensor(
[avg_time], dtype=torch.float32, device=f"cuda:{rank}"
)
dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX, group=process_group)
timing = time_tensor.item()
log.debug(
"Collective benchmark for %s on rank %d: %.2f us",
choice.name,
rank,
timing,
)
return timing
else:
# Fallback to regular benchmark for non-subgraph choices
log.debug(
"Choice %s on rank %d does not have gm attribute, using regular benchmark",
choice.name,
rank,
)
return cls.benchmark_choice(choice, autotune_args)
@classmethod
def benchmark_choices(
cls,
choices: Sequence[ChoiceCaller],
autotune_args: AutotuneArgs,
is_collective: bool = False,
process_group=None,
) -> dict[ChoiceCaller, float]:
"""
Benchmark a list of choices and return timing dict.
For collective operations, uses specialized benchmarking with
cross-rank synchronization to ensure accurate timing.
Args:
choices: List of choices to benchmark
autotune_args: Autotuning arguments
is_collective: Whether this is a collective operation
process_group: Process group for collective synchronization
Returns:
Dictionary mapping choices to their benchmark times
"""
# Check if this is a collective operation requiring special handling
if is_collective:
import torch.distributed as dist
if not dist.is_initialized():
log.warning(
"Collective op detected but distributed not initialized. "
"Falling back to regular benchmarking."
)
is_collective = False
else:
rank = dist.get_rank(process_group)
log.debug(
"Using collective benchmarking for %d choices on rank %d",
len(choices),
rank,
)
timings = {}
for choice in choices:
try:
timing = cls.benchmark_choice(choice, autotune_args)
if is_collective:
# Use collective benchmarking with cross-rank synchronization
timing = cls.benchmark_collective_choice(
choice, autotune_args, process_group
)
else:
# Regular benchmarking
timing = cls.benchmark_choice(choice, autotune_args)
except CUDACompileError:
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
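A self-contained sketch of the measurement pattern implemented by benchmark_collective_choice above: a barrier so every rank launches together, CUDA events around the candidate, then a MAX all-reduce so all ranks agree on a conservative timing. run_candidate is a stand-in for choice.benchmark(*inputs, out=output), and the sketch assumes torch.distributed and CUDA are already initialized on every rank.

import torch
import torch.distributed as dist

def timed_collective(run_candidate, process_group=None, nruns=10):
    dist.barrier(group=process_group)        # warmup alignment across ranks
    torch.cuda.synchronize()
    total_ms = 0.0
    for _ in range(nruns):
        dist.barrier(group=process_group)    # all ranks start together
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        run_candidate()
        end.record()
        end.synchronize()
        total_ms += start.elapsed_time(end)  # milliseconds
    avg = torch.tensor([total_ms / nruns], device="cuda")
    # Conservative estimate: every rank reports the slowest rank's time.
    dist.all_reduce(avg, op=dist.ReduceOp.MAX, group=process_group)
    return avg.item()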
@@ -3383,11 +3535,18 @@ class AlgorithmSelectorCache(PersistentCache):
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
is_collective=False,
process_group=None,
) -> dict[ChoiceCaller, float]:
inputs = cls.get_inputs(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
return cls.benchmark_choices(choices, inputs)
return cls.benchmark_choices(
choices,
inputs,
is_collective=is_collective,
process_group=process_group,
)
@classmethod
def benchmark_in_sub_process(
@@ -3419,18 +3578,37 @@ class AlgorithmSelectorCache(PersistentCache):
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
is_collective=False,
process_group=None,
):
if DEBUG:
print(f"{len(choices)} tuning requests:")
if config.autotune_in_subproc:
return functools.partial(
cls.benchmark_in_sub_process,
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
# Collective ops in a subprocess require special handling.
# For now, fall back to the current process for collective ops.
if is_collective:
log.debug(
"Collective op autotuning in subprocess not yet supported. "
"Falling back to current process."
)
return functools.partial(
cls.benchmark_in_current_process,
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
is_collective=is_collective,
process_group=process_group,
)
else:
return functools.partial(
cls.benchmark_in_sub_process,
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
else:
return functools.partial(
cls.benchmark_in_current_process,
@@ -3438,6 +3616,8 @@ class AlgorithmSelectorCache(PersistentCache):
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
is_collective=is_collective,
process_group=process_group,
)
@staticmethod
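Condensed view of the dispatch make_benchmark_fn performs after this change (illustrative pseudologic only; the real method returns functools.partial objects bound to the arguments shown above):

def pick_benchmark_path(autotune_in_subproc: bool, is_collective: bool) -> str:
    if autotune_in_subproc and not is_collective:
        return "benchmark_in_sub_process"
    # Collective ops always stay in the current process for now, since
    # subprocess autotuning of collectives is not yet supported.
    return "benchmark_in_current_process"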