Compare commits


14 Commits

Author SHA1 Message Date
8c6c066024 Update on "[CP] Refactor CP sharding rules into separate module and only register when CP is enabled"
Previously, CP-specific sharding strategies (which shard on the sequence dimension) were directly
included in the base sharding strategies for scaled_dot_product_attention operators in
`_matrix_ops.py`. This meant these strategies were always available, even when CP was not enabled,
which could lead to incorrect sharding behavior, as these sharding rules are not mathematically correct without CP.

1. **Created new module**:
`torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py`
   - Implements `op_strategy_context()` - a context manager for temporarily
registering/unregistering strategies
   - Defines CP-enhanced strategy functions for all 6 scaled_dot_product_attention ops (forward and
 backward for flash, efficient, and cudnn variants)
   - Provides `register_cp_sharding_rules()` and `unregister_cp_sharding_rules()` APIs (a usage sketch follows this list)

2. **Updated `_matrix_ops.py`**
   - Removed all CP-specific sharding rules (sequence dimension sharding strategies)
   - Base strategies now only contain replicate, tensor parallelism, and batch sharding
strategies

3. **Updated `_attention.py`**
   - `_enable_cp_dtensor_dispatcher()` now calls `register_cp_sharding_rules()` to dynamically add
CP strategies
   - ~`_disable_cp_dtensor_dispatcher()` now calls `unregister_cp_sharding_rules()` to restore
original strategies~ Unregistering would invalidate all DTensor sharding propagation caches, so this call is disabled for now.
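
A usage sketch of the new APIs (illustrative only: the module path and the private `_op_strategy_context` name are taken from this PR's diff, `my_flash_attention_strategy` is a hypothetical placeholder, and in practice the registration happens automatically via `_enable_cp_dtensor_dispatcher()`):

```python
import torch
from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from torch.distributed.tensor.experimental._context_parallel._sharding_rules import (
    _op_strategy_context,
    register_cp_sharding_rules,
)

aten = torch.ops.aten

def my_flash_attention_strategy(op_schema):
    # hypothetical strategy function; a real one builds and returns an OpStrategy
    ...

# Scoped registration: the custom strategy is visible only inside the block and
# the original strategy (if any) is restored on exit.
with _op_strategy_context(
    aten._scaled_dot_product_flash_attention.default,
    my_flash_attention_strategy,
    schema_info=RuntimeSchemaInfo(5),
):
    ...  # DTensor sharding propagation for SDPA uses the custom strategy here

# What enabling CP does under the hood: install the sequence-dim (Shard(2))
# strategies for all six SDPA ops on top of the base strategies.
register_cp_sharding_rules()
```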

cc H-Huang awgu wanchaol fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-11 08:53:59 -08:00
87ae43a0d6 Update base for Update on "[CP] Refactor CP sharding rules into separate module and only register when CP is enabled"
Previously, CP-specific sharding strategies (which shard on the sequence dimension) were directly
included in the base sharding strategies for scaled_dot_product_attention operators in
`_matrix_ops.py`. This meant these strategies were always available, even when CP was not enabled,
which could lead to incorrect sharding behavior, as these sharding rules are not mathematically correct without CP.

1. **Created new module**:
`torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py`
   - Implements `op_strategy_context()` - a context manager for temporarily
registering/unregistering strategies
   - Defines CP-enhanced strategy functions for all 6 scaled_dot_product_attention ops (forward and
 backward for flash, efficient, and cudnn variants)
   - Provides `register_cp_sharding_rules()` and `unregister_cp_sharding_rules()` APIs

2. **Updated `_matrix_ops.py`**
   - Removed all CP-specific sharding rules (sequence dimension sharding strategies)
   - Base strategies now only contain replicate, tensor parallelism, and batch sharding
strategies

3. **Updated `_attention.py`**
   - `_enable_cp_dtensor_dispatcher()` now calls `register_cp_sharding_rules()` to dynamically add
CP strategies
   - ~`_disable_cp_dtensor_dispatcher()` now calls `unregister_cp_sharding_rules()` to restore
original strategies~ Unregistering would invalidate all DTensor sharding propagation caches, so this call is disabled for now.

cc H-Huang awgu wanchaol fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-11 08:53:59 -08:00
2cff58cc2d Update
[ghstack-poisoned]
2025-11-11 00:04:25 -08:00
46bd412746 Update
[ghstack-poisoned]
2025-11-10 23:22:33 -08:00
5d07795b28 Update
[ghstack-poisoned]
2025-11-10 17:39:24 -08:00
306914e071 Update (base update)
[ghstack-poisoned]
2025-11-10 09:50:54 -08:00
fb6807bf86 Update
[ghstack-poisoned]
2025-11-10 09:50:54 -08:00
f6a79b2a4a [inductor] Wrap pallas_call in jax.jit (#167441)
My understanding is this is needed for performance.
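
For context, a minimal standalone sketch (assuming a JAX install with Pallas) of the shape the generated code now takes, mirroring the codegen change further down: `pl.pallas_call` is hoisted into a small wrapper compiled with `jax.jit`, with the output shape and dtype passed as static arguments. The names are illustrative, not the exact generated identifiers.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(in_ptr0, in_ptr1, out_ptr0):
    # element-wise add over whole refs (grid=(1,))
    out_ptr0[...] = in_ptr0[...] + in_ptr1[...]

@functools.partial(jax.jit, static_argnums=(0, 1))
def add_jit_wrapper(out_shape, out_dtype, *kernel_refs):
    out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)
    return pl.pallas_call(
        add_kernel,
        out_shape=out_spec,
        # interpret=True keeps the sketch runnable on CPU; the generated code
        # picks this based on the codegen device
        interpret=True,
        grid=(1,),
    )(*kernel_refs)

x = jnp.ones((32,), dtype=jnp.float32)
y = jnp.ones((32,), dtype=jnp.float32)
res = add_jit_wrapper((32,), jnp.float32, x, y)
```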

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167441
Approved by: https://github.com/oulgen
2025-11-10 17:29:56 +00:00
2fcf41dd8e Add the ruff rule and skip everything for now (#167360)
Part of https://github.com/pytorch/pytorch/issues/164878
We can start narrowing the skips and remove them as PRs keep landing.

This PR just sets up the scaffolding; the fixes will land in follow-ups.
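
For reference, S101 is Ruff's flake8-bandit rule that flags plain `assert` statements (they are stripped under `python -O`). A small, hypothetical example of code the rule would report once the blanket per-directory skips below are narrowed:

```python
def get_item(items: list[int], i: int) -> int:
    # S101: plain asserts disappear under `python -O`, so Ruff flags this line
    assert 0 <= i < len(items), "index out of range"
    return items[i]
```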
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167360
Approved by: https://github.com/janeyx99
2025-11-10 17:10:15 +00:00
31ccd8f13e [AOTI] Fix a mixed-device bug for scatter_add (#167341)
Summary: Fix https://github.com/pytorch/pytorch/issues/166841. AOTI incorrectly generates a call to aoti_torch_cuda_scatter_reduce_two_out even though the op should actually run on CPU. The fix uses the correct device when calling _generate_scatter_fallback in the wrapper codegen.
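
A condensed sketch of the failure mode, modeled on the new `test_mixed_device_1` test further down (class and variable names here are illustrative):

```python
import torch

class MixedDeviceModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # buffers deliberately live on CPU and feed a CPU scatter_add
        self.register_buffer("index", torch.tensor([1, 4, 1, 7], dtype=torch.int64))
        self.register_buffer("src", torch.ones(4, dtype=torch.int64))

    def forward(self, matrix: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
        # 1. scatter_add on CPU tensors, even though matrix/vector are CUDA tensors
        z = torch.zeros((vector.shape[0],), dtype=torch.int64)
        counts = z.scatter_add(0, self.index, self.src)
        # 2. move the CPU result onto the inputs' device and continue on CUDA
        return torch.matmul(matrix, vector + counts.to(vector.dtype).to(vector.device))

if torch.cuda.is_available():
    model = MixedDeviceModel()  # buffers stay on CPU, as in the new test
    matrix = torch.randn(10, 10, device="cuda")
    vector = torch.randn(10, device="cuda")
    out = model(matrix, vector)  # eager is fine; AOTI used to emit the CUDA scatter shim here
```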

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167341
Approved by: https://github.com/yushangdi
2025-11-10 16:59:44 +00:00
1209123500 Update
[ghstack-poisoned]
2025-11-07 16:21:29 -08:00
168fc7cfc5 Update
[ghstack-poisoned]
2025-11-07 15:37:40 -08:00
639276822f Update (base update)
[ghstack-poisoned]
2025-11-07 15:19:16 -08:00
8cbe13ad31 Update
[ghstack-poisoned]
2025-11-07 15:19:16 -08:00
15 changed files with 1089 additions and 280 deletions

View File

@ -1308,8 +1308,319 @@ coverage_ignore_functions = [
# torch.onnx.symbolic_opset7
"max",
"min",
# torch.onnx.symbolic_opset8
"addmm",
"bmm",
"empty",
"empty_like",
"flatten",
"full",
"full_like",
"gt",
"lt",
"matmul",
"mm",
"ones",
"ones_like",
"prelu",
"repeat",
"zeros",
"zeros_like",
# torch.onnx.symbolic_opset9
"abs",
"acos",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"adaptive_max_pool1d",
"adaptive_max_pool2d",
"adaptive_max_pool3d",
"add",
"addcmul",
"addmm",
"alias",
"amax",
"amin",
"aminmax",
"arange",
"argmax",
"argmin",
"as_strided",
"as_tensor",
"asin",
"atan",
"atan2",
"avg_pool1d",
"avg_pool2d",
"avg_pool3d",
"baddbmm",
"batch_norm",
"bernoulli",
"bitwise_not",
"bitwise_or",
"bmm",
"broadcast_tensors",
"broadcast_to",
"bucketize",
"cat",
"cdist",
"ceil",
"clamp",
"clamp_max",
"clamp_min",
"clone",
"constant_pad_nd",
"contiguous",
"conv1d",
"conv2d",
"conv3d",
"conv_tbc",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
"convert_element_type",
"convolution",
"cos",
"cosine_similarity",
"cross",
"cumsum",
"detach",
"dim",
"div",
"dot",
"dropout",
"elu",
"embedding",
"embedding_bag",
"empty",
"empty_like",
"eq",
"erf",
"exp",
"expand",
"expand_as",
"eye",
"fill",
"flatten",
"floor",
"floor_divide",
"floordiv",
"frobenius_norm",
"full",
"full_like",
"gather",
"ge",
"gelu",
"get_pool_ceil_padding",
"glu",
"group_norm",
"gru",
"gt",
"hann_window",
"hardshrink",
"hardsigmoid",
"hardswish",
"hardtanh",
"index",
"index_add",
"index_copy",
"index_fill",
"index_put",
"index_select",
"instance_norm",
"is_floating_point",
"is_pinned",
"isnan",
"item",
"kl_div",
"layer_norm",
"le",
"leaky_relu",
"lerp",
"lift",
"linalg_cross",
"linalg_matrix_norm",
"linalg_norm",
"linalg_vector_norm",
"linear",
"linspace",
"log",
"log10",
"log1p",
"log2",
"log_sigmoid",
"log_softmax",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"logit",
"logsumexp",
"lstm",
"lstm_cell",
"lt",
"masked_fill",
"masked_fill_",
"matmul",
"max",
"max_pool1d",
"max_pool1d_with_indices",
"max_pool2d",
"max_pool2d_with_indices",
"max_pool3d",
"max_pool3d_with_indices",
"maximum",
"meshgrid",
"min",
"minimum",
"mish",
"mm",
"movedim",
"mse_loss",
"mul",
"multinomial",
"mv",
"narrow",
"native_layer_norm",
"ne",
"neg",
"new_empty",
"new_full",
"new_ones",
"new_zeros",
"nonzero",
"nonzero_numpy",
"noop_complex_operators",
"norm",
"numel",
"numpy_T",
"one_hot",
"ones",
"ones_like",
"onnx_placeholder",
"overload_by_arg_count",
"pad",
"pairwise_distance",
"permute",
"pixel_shuffle",
"pixel_unshuffle",
"pow",
"prelu",
"prim_constant",
"prim_constant_chunk",
"prim_constant_split",
"prim_data",
"prim_device",
"prim_dtype",
"prim_if",
"prim_layout",
"prim_list_construct",
"prim_list_unpack",
"prim_loop",
"prim_max",
"prim_min",
"prim_shape",
"prim_tolist",
"prim_tuple_construct",
"prim_type",
"prim_unchecked_cast",
"prim_uninitialized",
"rand",
"rand_like",
"randint",
"randint_like",
"randn",
"randn_like",
"reciprocal",
"reflection_pad",
"relu",
"relu6",
"remainder",
"repeat",
"repeat_interleave",
"replication_pad",
"reshape",
"reshape_as",
"rnn_relu",
"rnn_tanh",
"roll",
"rrelu",
"rsqrt",
"rsub",
"scalar_tensor",
"scatter",
"scatter_add",
"select",
"selu",
"sigmoid",
"sign",
"silu",
"sin",
"size",
"slice",
"softmax",
"softplus",
"softshrink",
"sort",
"split",
"split_with_sizes",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"std_mean",
"sub",
"t",
"take",
"tan",
"tanh",
"tanhshrink",
"tensor",
"threshold",
"to",
"topk",
"transpose",
"true_divide",
"type_as",
"unbind",
"unfold",
"unsafe_chunk",
"unsafe_split",
"unsafe_split_with_sizes",
"unsqueeze",
"unsupported_complex_operators",
"unused",
"upsample_bilinear2d",
"upsample_linear1d",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"upsample_trilinear3d",
"var",
"var_mean",
"view",
"view_as",
"where",
"wrap_logical_op_with_cast_to",
"wrap_logical_op_with_negation",
"zero",
"zeros",
"zeros_like",
# torch.onnx.utils
"disable_apex_o2_state_dict_hook",
"export",
"export_to_pretty_string",
"exporter_context",
"is_in_onnx_export",
"model_signature",
"register_custom_op_symbolic",
"select_model_mode_for_export",
"setup_onnx_logging",
"unconvertible_ops",
"unpack_quantized_tensor",
"warn_on_static_input_change",
# torch.onnx.verification
"check_export_model_diff",
"verify",
"verify_aten_graph",
@ -1400,6 +1711,32 @@ coverage_ignore_functions = [
"noop_context_fn",
"set_checkpoint_early_stop",
"set_device_states",
# torch.utils.collect_env
"check_release_file",
"get_cachingallocator_config",
"get_clang_version",
"get_cmake_version",
"get_conda_packages",
"get_cpu_info",
"get_cuda_module_loading_config",
"get_cudnn_version",
"get_env_info",
"get_gcc_version",
"get_gpu_info",
"get_libc_version",
"get_lsb_version",
"get_mac_version",
"get_nvidia_driver_version",
"get_nvidia_smi",
"get_os",
"get_pip_packages",
"get_platform",
"get_pretty_env_info",
"get_python_platform",
"get_running_cuda_version",
"get_windows_version",
"is_xnnpack_available",
"pretty_str",
# torch.utils.cpp_backtrace
"get_cpp_backtrace",
# torch.utils.cpp_extension
@ -1463,6 +1800,52 @@ coverage_ignore_functions = [
"apply_shuffle_seed",
"apply_shuffle_settings",
"get_all_graph_pipes",
# torch.utils.flop_counter
"addmm_flop",
"baddbmm_flop",
"bmm_flop",
"conv_backward_flop",
"conv_flop",
"conv_flop_count",
"convert_num_with_suffix",
"get_shape",
"get_suffix_str",
"mm_flop",
"normalize_tuple",
"register_flop_formula",
"sdpa_backward_flop",
"sdpa_backward_flop_count",
"sdpa_flop",
"sdpa_flop_count",
"shape_wrapper",
"transpose_shape",
# torch.utils.hipify.hipify_python
"add_dim3",
"compute_stats",
"extract_arguments",
"file_add_header",
"file_specific_replacement",
"find_bracket_group",
"find_closure_group",
"find_parentheses_group",
"fix_static_global_kernels",
"get_hip_file_path",
"hip_header_magic",
"hipify",
"is_caffe2_gpu_file",
"is_cusparse_file",
"is_out_of_place",
"is_pytorch_file",
"is_special_file",
"match_extensions",
"matched_files_iter",
"openf",
"preprocess_file_and_save_result",
"preprocessor",
"processKernelLaunches",
"replace_extern_shared",
"replace_math_functions",
"str2bool",
# torch.utils.hooks
"unserializable_hook",
"warn_if_has_hooks",

View File

@ -19,91 +19,6 @@
swap_tensors
```
# torch.utils.collect_env
```{eval-rst}
.. automodule:: torch.utils.collect_env
```
```{eval-rst}
.. currentmodule:: torch.utils.collect_env
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
check_release_file
is_xnnpack_available
pretty_str
```
# torch.utils.flop_counter
```{eval-rst}
.. automodule:: torch.utils.flop_counter
```
```{eval-rst}
.. currentmodule:: torch.utils.flop_counter
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
baddbmm_flop
bmm_flop
conv_backward_flop
conv_flop
conv_flop_count
register_flop_formula
sdpa_backward_flop
sdpa_backward_flop_count
sdpa_flop
sdpa_flop_count
shape_wrapper
```
# torch.utils.hipify.hipify_python
```{eval-rst}
.. automodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. currentmodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
compute_stats
extract_arguments
file_add_header
file_specific_replacement
find_bracket_group
find_closure_group
find_parentheses_group
fix_static_global_kernels
hip_header_magic
hipify
is_caffe2_gpu_file
is_cusparse_file
is_out_of_place
is_pytorch_file
is_special_file
openf
preprocess_file_and_save_result
preprocessor
processKernelLaunches
replace_extern_shared
replace_math_functions
str2bool
```
<!-- This module needs to be documented. Adding here in the meantime
for tracking purposes -->
```{eval-rst}
@ -128,6 +43,7 @@ for tracking purposes -->
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
@ -164,8 +80,10 @@ for tracking purposes -->
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract

View File

@ -260,6 +260,7 @@ select = [
"TRY401", # verbose-log-message
"UP",
"YTT",
"S101",
]
[tool.ruff.lint.pyupgrade]
@ -339,6 +340,39 @@ keep-runtime-typing = true
"tools/linter/**" = [
"LOG015" # please fix
]
"benchmarks/**" = [
"S101"
]
"test/**" = [
"S101"
]
"torchgen/**" = [
"S101"
]
"torch/**" = [
"S101"
]
"tools/**" = [
"S101"
]
"setup.py" = [
"S101"
]
"functorch/**" = [
"S101"
]
"docs/**" = [
"S101"
]
"android/**" = [
"S101"
]
".github/**" = [
"S101"
]
".ci/**" = [
"S101"
]
[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"

View File

@ -813,6 +813,50 @@ class TestSharding(DTensorTestBase):
),
)
@skip_if_lt_x_gpu(2)
@with_comms
@unittest.skipIf(
not PLATFORM_SUPPORTS_FUSED_ATTENTION,
"Does not support flash nor efficient attention",
)
def test_attention_shard_without_cp(self) -> None:
"""Test that sharding on sequence dimension without CP enabled is not supported."""
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
B = 2
nheads = 4
seq_len = 256
dim = 32
device_mesh = init_device_mesh(
mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
)
# Create q, k, v tensors with shape (B, nheads, seq_len, dim)
q = torch.randn(
B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
)
k = torch.randn(
B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
)
v = torch.randn(
B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
)
q_dt = distribute_tensor(q, device_mesh, [Shard(2)])
k_dt = distribute_tensor(k, device_mesh, [Shard(2)])
v_dt = distribute_tensor(v, device_mesh, [Shard(2)])
# Run SDPA with sequence-sharded tensors WITHOUT enabling CP
# Without CP enabled, DTensor should select a different strategy
# (not sequence-sharded) because Shard(2) strategy is only available with CP
out = F.scaled_dot_product_attention(q_dt, k_dt, v_dt)
# Verify the output is NOT sharded on sequence dimension (dim 2)
# This proves that CP sharding rules were not used
self.assertNotEqual(out.placements[0], Shard(2))
# The output should be replicated or sharded on batch head dimensions.
self.assertIn(out.placements[0], [Replicate(), Shard(0), Shard(1)])
RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
RingAttentionTest,

View File

@ -7522,6 +7522,38 @@ class AOTInductorTestsTemplate:
eager_outputs = model(*example_inputs)
torch.testing.assert_close(eager_outputs, compiled_outputs)
@requires_gpu
def test_mixed_device_1(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("Mixed-device test requires GPU")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
# Buffers are on CPU
self.register_buffer(
"index", torch.tensor([1, 4, 1, 7], device="cpu", dtype=torch.int64)
)
self.register_buffer(
"src", torch.ones(4, device="cpu", dtype=torch.int64)
)
def forward(self, matrix, vector):
# Inputs are on CUDA
# 1. Operation on CPU tensors
z = torch.zeros((vector.shape[0],), device="cpu", dtype=torch.int64)
scatter_result = z.scatter_add(0, self.index, self.src)
# 2. Move result to CUDA and continue on CUDA
v = vector + scatter_result.to(vector.dtype).to(GPU_TYPE)
return torch.matmul(matrix, v)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, device=self.device),
)
self.check_model(Model(), example_inputs, move_model_to_device=False)
class AOTInductorLoggingTest(LoggingTestCase):
@make_logging_test(dynamic=logging.DEBUG)

View File

@ -218,6 +218,7 @@ def check_model(
dynamic_shapes=None,
atol=None,
rtol=None,
move_model_to_device=True,
):
with (
torch.no_grad(),
@ -229,7 +230,7 @@ def check_model(
),
):
torch.manual_seed(0)
if not isinstance(model, types.FunctionType):
if not isinstance(model, types.FunctionType) and move_model_to_device:
model = model.to(self.device)
# For non mixed device inputs with default "cpu",set the device manually.

View File

@ -1,5 +1,6 @@
# Owner(s): ["oncall: pt2"]
import functools
import re
import sys
import unittest
@ -230,6 +231,33 @@ class PallasTestsMixin:
self.assertIn("import jax.numpy as jnp", code)
self.assertIn("from jax.experimental import pallas as pl", code)
def test_jax_jit_wrapper_is_emitted(self):
"""Ensure generated Pallas code wraps pl.pallas_call in jax.jit."""
key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
@torch.compile(backend="inductor", options={key: "pallas"})
def pallas_fn(a, b):
return a + b
_, (code,) = run_and_get_code(
pallas_fn,
torch.randn(32, device=self.DEVICE),
torch.randn(32, device=self.DEVICE),
)
kernel_match = re.search(r"def (pallas_[A-Za-z0-9_]+)_kernel", code)
self.assertIsNotNone(kernel_match)
kernel_name = kernel_match.group(1)
wrapper_name = f"{kernel_name}_jit_wrapper"
self.assertIn(wrapper_name, code)
start = code.index(f"def {wrapper_name}")
end = code.index(f"def {kernel_name}_main", start)
wrapper_block = code[start:end]
self.assertIn("jax.jit", code)
self.assertNotIn("torch.", wrapper_block)
def test_2d_tensor(self):
"""Test with 2D tensors (though current implementation flattens)."""

View File

@ -221,7 +221,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
"""
)
self.add_device_include(self.device)
for device in V.graph.device_types:
if device != "meta":
self.add_device_include(device)
if V.graph.aot_mode:
if config.aot_inductor.dynamic_linkage:
@ -1423,11 +1425,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
src_is_tensor,
reduce,
kwargs,
device,
):
reduce = self._get_scatter_reduce_enum(reduce)
# call the ABI shim function instead of the ATen one
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
self.add_device_include(device)
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
# TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
inputs_wrapped = [str(x) for x in inputs]

View File

@ -708,11 +708,14 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
src_is_tensor,
reduce,
kwargs,
device,
):
reduce = self._get_scatter_reduce_enum(reduce)
# call the ABI shim function instead of the ATen one
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
self.add_device_include(device)
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
# TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()

View File

@ -287,6 +287,7 @@ class PallasKernel(SIMDKernel):
code = IndentedBuffer()
code.splice(
"""
import functools
import torch
import jax
import jax.numpy as jnp
@ -301,6 +302,9 @@ class PallasKernel(SIMDKernel):
kernel_params = [a.name for a in arg_defs]
kernel_name = name or "<KERNEL_NAME>"
interpret_literal = (
"True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
)
code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
with code.indent():
# Emit compute (CSE) and store lines; they reference *_ptr[...] directly
@ -309,16 +313,22 @@ class PallasKernel(SIMDKernel):
for line in self.stores._lines:
code.writeline(str(line))
jit_wrapper_name = f"{kernel_name}_jit_wrapper"
code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))")
code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):")
with code.indent():
code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)")
code.writeline("return pl.pallas_call(")
code.writeline(f" {kernel_name}_kernel,")
code.writeline(" out_shape=out_spec,")
code.writeline(f" interpret={interpret_literal},")
code.writeline(" grid=(1,),")
code.writeline(")(*kernel_refs)")
# Host entry: convert torch tensors <-> jax, call pallas_call and copy back
main_name = f"{kernel_name}_main"
code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
with code.indent():
# Determine interpret statically based on codegen device
interpret_literal = (
"True"
if V.graph.get_current_device_or_throw().type == "cpu"
else "False"
)
# Identify inputs (in_ptr*) and output (out_ptr*)
input_params = [
p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@ -337,9 +347,9 @@ class PallasKernel(SIMDKernel):
for inp in input_params:
code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
# Get output spec from PyTorch tensor
code.writeline("# Prepare output spec from PyTorch tensor")
code.writeline("# Map PyTorch dtype to JAX dtype string")
# Get output metadata from PyTorch tensor
code.writeline("# Prepare output metadata from PyTorch tensor")
code.writeline("# Map PyTorch dtype to JAX dtype")
code.writeline("_torch_dtype_to_jax = {")
code.writeline(
" torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"
@ -349,21 +359,14 @@ class PallasKernel(SIMDKernel):
)
code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
code.writeline("}")
code.writeline(
f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])"
)
code.writeline(f"out_shape = tuple({output_param}.shape)")
code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")
# Call pallas
# Pass interpret=True on CPU, False otherwise (single call, no duplication)
code.writeline("compiled = pl.pallas_call(")
code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),")
code.writeline(" out_shape=out_spec,")
code.writeline(f" interpret={interpret_literal},")
code.writeline(" grid=(1,),")
code.writeline(")")
jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params])
code.writeline(f"res = compiled({jax_input_args})")
call_args = ["out_shape", "out_dtype"] + [
f"{inp}_jax" for inp in input_params
]
call_arg_str = ", ".join(call_args)
code.writeline(f"res = {jit_wrapper_name}({call_arg_str})")
# Copy result back
code.writeline("# Copy result back into the provided torch output tensor")

View File

@ -971,6 +971,7 @@ class ScatterFallbackLine(WrapperLine):
else:
(x, index) = (t.codegen_reference() for t in node.inputs)
src = node.constant_args[1]
device = d.type if (d := node.get_device()) else V.graph.device_type
self.wrapper._generate_scatter_fallback(
x,
[x, node.constant_args[0], index, src],
@ -979,6 +980,7 @@ class ScatterFallbackLine(WrapperLine):
node.src_is_tensor,
node.kwargs["reduce"],
node.codegen_kwargs(),
device,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
@ -1632,6 +1634,7 @@ class PythonWrapperCodegen(CodeGen):
src_is_tensor,
reduce,
kwargs,
device,
):
line = f"{python_kernel_name}({','.join(map(str, inputs))}"
if python_kernel_name.startswith("aten.scatter_reduce"):

View File

@ -267,16 +267,10 @@ def scaled_mm_strategy(op_schema: OpSchema) -> OpStrategy:
return _scaled_mm_like_strategy("mk,kn->mn", mesh, op_schema)
@register_op_strategy(
aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
)
def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
# TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
# as it involves: matmul, pointwise, reduction ops together.
mesh = op_schema.get_mesh_from_args()
def _scaled_dot_product_flash_attention_base_strategies(
op_schema: OpSchema,
) -> list[PlacementList]:
"""Helper that returns list of base placement strategies (without CP)."""
return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
q_input_strategy = op_schema.args_schema[0]
if not isinstance(q_input_strategy, OpStrategy):
@ -349,37 +343,30 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate
Shard(0), # v
]
)
return single_mesh_dim_strategies
# Context Parallelism: shards on the sequence dim
debug_attn_mask_sharding = Shard(2) if return_debug_mask else Replicate()
single_mesh_dim_strategies.append(
[
Shard(2), # output
Shard(2), # logsumexp
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
Replicate(), # rng_state
None, # unused
debug_attn_mask_sharding, # debugattn
Shard(2), # q
Shard(2), # k
Shard(2), # v
]
@register_op_strategy(
aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
)
def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
# TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
# as it involves: matmul, pointwise, reduction ops together.
mesh = op_schema.get_mesh_from_args()
single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies(
op_schema
)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=9
)
@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
def scaled_dot_product_flash_attention_backward_strategy(
def _scaled_dot_product_flash_attention_backward_base_strategies(
op_schema: OpSchema,
) -> OpStrategy:
# backward op does not need to validate the mesh since forward op has already done it
mesh = op_schema.get_mesh_from_args(validate=False)
) -> list[PlacementList]:
"""Helper that returns list of base placement strategies (without CP)."""
q_input_strategy = op_schema.args_schema[1]
if not isinstance(q_input_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
@ -444,24 +431,18 @@ def scaled_dot_product_flash_attention_backward_strategy(
batch_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
single_mesh_dim_strategies.append(batch_dim_sharding)
# Context Parallelism: shards on the sequence dim
seq_dim_sharding: PlacementList = [
Shard(2), # grad_q
Shard(2), # grad_k
Shard(2), # grad_v
Shard(2), # grad_output
Shard(2), # q
Shard(2), # k
Shard(2), # v
Shard(2), # output
Shard(2), # logsumexp
]
# accept replicate on the rest tensor inputs, potentially
# cum_seq_q, cum_seq_k, philox_seed, philox_offset
# at indices 6, 7, 12, 13, respectively
seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
single_mesh_dim_strategies.append(seq_dim_sharding)
return single_mesh_dim_strategies
@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
def scaled_dot_product_flash_attention_backward_strategy(
op_schema: OpSchema,
) -> OpStrategy:
# backward op does not need to validate the mesh since forward op has already done it
mesh = op_schema.get_mesh_from_args(validate=False)
single_mesh_dim_strategies = (
_scaled_dot_product_flash_attention_backward_base_strategies(op_schema)
)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=3
)
@ -486,13 +467,10 @@ def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy:
)
@register_op_strategy(
aten._scaled_dot_product_efficient_attention.default,
schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
mesh = op_schema.get_mesh_from_args()
def _scaled_dot_product_efficient_attention_base_strategies(
op_schema: OpSchema,
) -> list[PlacementList]:
"""Helper that returns list of base placement strategies (without CP)."""
q_input_strategy = op_schema.args_schema[0]
if not isinstance(q_input_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
@ -518,19 +496,6 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
if has_attn_bias:
all_replicate.append(Replicate()) # attn bias
# Context Parallelism: shards on the sequence dim
single_mesh_dim_strategies.append(
[
Shard(2), # output
Shard(2), # logsumexp
None, # philox_seed
None, # philox_offset
Shard(2), # q
Shard(2), # k
Shard(2), # v
]
)
single_mesh_dim_strategies.append(all_replicate)
# second we can accept the sharding pattern of tensor parallelism, which
@ -576,6 +541,19 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
single_mesh_dim_strategies.append(batch_sharding)
return single_mesh_dim_strategies
@register_op_strategy(
aten._scaled_dot_product_efficient_attention.default,
schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
mesh = op_schema.get_mesh_from_args()
single_mesh_dim_strategies = (
_scaled_dot_product_efficient_attention_base_strategies(op_schema)
)
return expand_to_full_mesh_op_strategy(
mesh,
op_schema,
@ -584,13 +562,10 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
)
@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
def scaled_dot_product_efficient_attention_backward_strategy(
def _scaled_dot_product_efficient_attention_backward_base_strategies(
op_schema: OpSchema,
) -> OpStrategy:
# backward op does not need to validate the mesh since forward op has already done it
mesh = op_schema.get_mesh_from_args(validate=False)
) -> list[PlacementList]:
"""Helper that returns list of base placement strategies (without CP)."""
q_input_strategy = op_schema.args_schema[1]
if not isinstance(q_input_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
@ -662,27 +637,18 @@ def scaled_dot_product_efficient_attention_backward_strategy(
batch_dim_sharding.extend([Replicate(), Replicate()])
single_mesh_dim_strategies.append(batch_dim_sharding)
# Context Parallelism: shards on the sequence dim
seq_dim_sharding: PlacementList = [
Shard(2), # grad_q
Shard(2), # grad_k
Shard(2), # grad_v
Shard(1) if has_attn_bias else None, # grad_bias
Shard(2), # grad_output
Shard(2), # q
Shard(2), # k
Shard(2), # v
Shard(2), # output
Shard(2), # logsumexp
]
# accept replicate on the rest tensor inputs, potentially
# cum_seq_q, cum_seq_k, philox_seed, philox_offset
# at indices 6, 7, 12, 13, respectively
if has_attn_bias:
num_heads_dim_sharding.insert(8, Shard(1))
seq_dim_sharding.extend([Replicate(), Replicate()])
single_mesh_dim_strategies.append(seq_dim_sharding)
return single_mesh_dim_strategies
@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
def scaled_dot_product_efficient_attention_backward_strategy(
op_schema: OpSchema,
) -> OpStrategy:
# backward op does not need to validate the mesh since forward op has already done it
mesh = op_schema.get_mesh_from_args(validate=False)
single_mesh_dim_strategies = (
_scaled_dot_product_efficient_attention_backward_base_strategies(op_schema)
)
return expand_to_full_mesh_op_strategy(
mesh,
op_schema,
@ -691,13 +657,10 @@ def scaled_dot_product_efficient_attention_backward_strategy(
)
@register_op_strategy(
aten._scaled_dot_product_cudnn_attention.default,
schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args()
def _scaled_dot_product_cudnn_attention_base_strategies(
op_schema: OpSchema,
) -> list[PlacementList]:
"""Helper that returns list of base placement strategies (without CP)."""
(
query_strategy, # query
_, # key
@ -785,39 +748,27 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate
]
single_mesh_dim_strategies.append(batch_dim_sharding)
# Context Parallelism: shards on the sequence dim
cp_sharding = Shard(2) # seq dim
logsumexp_sharding = cp_sharding if compute_log_sumexp else Replicate()
debug_attn_mask_sharding = cp_sharding if return_debug_mask else None
return single_mesh_dim_strategies
single_mesh_dim_strategies.append(
[
cp_sharding, # output
logsumexp_sharding, # logsumexp
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
debug_attn_mask_sharding, # debug_attn_mask
cp_sharding, # q
cp_sharding, # k
cp_sharding, # v
]
@register_op_strategy(
aten._scaled_dot_product_cudnn_attention.default,
schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args()
single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies(
op_schema
)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=9
)
@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
def _scaled_dot_product_cudnn_attention_backward_base_strategies(
op_schema: OpSchema,
) -> OpStrategy:
# backward op does not need to validate the mesh since forward op has already done it
mesh = op_schema.get_mesh_from_args(validate=False)
) -> list[PlacementList]:
"""Helper that returns list of base placement strategies (without CP)."""
if len(op_schema.args_schema) < 15:
raise AssertionError(
f"Expected at least 15 args_schema, got {len(op_schema.args_schema)}"
@ -892,23 +843,7 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp
single_mesh_dim_strategies.append(num_heads_dim_sharding)
# case 3: Context Parallelism which shards on the sequence dim
context_parallel_sharding_out: PlacementList = [Shard(2)] * 3
context_parallel_sharding_inp: PlacementList = [Shard(2)] * 6
context_parallel_sharding_inp += [
Replicate()
] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor
context_parallel_sharding_inp += [Shard(2) if has_attn_bias else None]
context_parallel_sharding_inp += [None] * 6
if has_scale:
context_parallel_sharding_inp.append(None)
context_parallel_sharding = (
context_parallel_sharding_out + context_parallel_sharding_inp
)
single_mesh_dim_strategies.append(context_parallel_sharding)
# case 4: we can accept the sharding pattern of batch parallelism, which
# case 3: we can accept the sharding pattern of batch parallelism, which
# shards on the batch dimension
qkv_sharding = Shard(0)
output_sharding = Shard(0)
@ -929,6 +864,18 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
batch_dim_sharding = batch_dim_sharding_out + batch_dim_sharding_inp
single_mesh_dim_strategies.append(batch_dim_sharding)
return single_mesh_dim_strategies
@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
op_schema: OpSchema,
) -> OpStrategy:
# backward op does not need to validate the mesh since forward op has already done it
mesh = op_schema.get_mesh_from_args(validate=False)
single_mesh_dim_strategies = (
_scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema)
)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=3
)

View File

@ -989,16 +989,31 @@ def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:
def _enable_cp_dtensor_dispatcher() -> None:
"""Enables DTensor dispatcher to dispatch SDPA to CP."""
# Enable custom op handlers for CP
DTensor._op_dispatcher._custom_op_handlers = {
**exitsing_custom_ops,
**custom_ops,
}
# Register CP-specific sharding rules
from ._sharding_rules import register_cp_sharding_rules
register_cp_sharding_rules()
def _disable_cp_dtensor_dispatcher() -> None:
"""Disables DTensor dispatcher to dispatch SDPA to CP."""
# Restore original custom op handlers
DTensor._op_dispatcher._custom_op_handlers = exitsing_custom_ops
# TODO: unregister_cp_sharding_rules() will cause all DTensor sharding
# propagation cache being invalidated. It is not easy to achieve
# selectively invalidating lru cache without rewriting the sharding
# propagation wrapper. Disable unregister_cp_sharding_rules() call
# for now.
# from ._sharding_rules import unregister_cp_sharding_rules
# unregister_cp_sharding_rules()
def _enable_context_parallel_dispatcher_impl(seq_dim: int, mesh: DeviceMesh) -> None:
sdpa_cp = _ContextParallel(
@ -1032,9 +1047,7 @@ def _disable_context_parallel_dispatcher_impl() -> None:
_disable_cp_dtensor_dispatcher()
_compiled_create_block_mask = torch.compile(
create_block_mask, dynamic=False, fullgraph=True
)
_compiled_create_block_mask = None
def _context_parallel_buffers(
@ -1187,9 +1200,12 @@ def _create_cp_block_mask(
f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. "
)
compiled_create_block_mask = torch.compile(
create_block_mask, dynamic=False, fullgraph=True
)
global _compiled_create_block_mask
if _compiled_create_block_mask is None:
_compiled_create_block_mask = torch.compile(
create_block_mask, dynamic=False, fullgraph=True
)
compiled_create_block_mask = _compiled_create_block_mask
def _rewrite_mask_mod(
mask_mod: _mask_mod_signature,

View File

@ -0,0 +1,399 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
"""
Context Parallelism sharding rules for scaled_dot_product attention operators.
The sharding rules for CP cannot be embedded by default because Shard(2) is not
a valid sharding for SDPA without CP enabled. This module provides utilities to
dynamically install Shard(2) sharding rules when CP is activated.
"""
from contextlib import contextmanager
import torch
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementList,
RuntimeSchemaInfo,
)
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import Replicate, Shard
aten = torch.ops.aten
SEQ_DIM = 2
@contextmanager
def _op_strategy_context(op_overload, strategy_func, schema_info=None):
"""
Context manager for setting and clearing op strategies for Context Parallelism.
Args:
op_overload: The operator overload to set or clear the strategy for.
strategy_func: The strategy function to set for the operator overload.
schema_info: Optional schema information for the operator overload.
Yields:
None
"""
from torch.distributed.tensor import DTensor
propagator = DTensor._op_dispatcher.sharding_propagator
_origin_op_strategy_funcs = None
_origin_op_strategy_schema = None
try:
# Save original strategy if exists
if op_overload in propagator.op_strategy_funcs:
_origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
if op_overload in propagator.op_to_schema_info:
_origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
# Register the new op strategy
register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
yield (_origin_op_strategy_funcs, _origin_op_strategy_schema)
finally:
# Restore original strategy
if _origin_op_strategy_funcs is None:
if op_overload in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[op_overload]
else:
propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
if _origin_op_strategy_schema is None:
if op_overload in propagator.op_to_schema_info:
del propagator.op_to_schema_info[op_overload]
else:
propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
# Clear cache
propagator.propagate_op_sharding.cache.cache_clear()
# ==================== Flash Attention Strategies ====================
def _scaled_dot_product_flash_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy:
"""
Strategy for flash attention forward with Context Parallelism support.
This includes the base strategies plus CP-specific sequence dimension sharding.
"""
# Import here to avoid circular dependency
from torch.distributed.tensor._ops._matrix_ops import (
_scaled_dot_product_flash_attention_base_strategies,
)
# Get the base strategies (without CP modifications)
mesh = op_schema.get_mesh_from_args()
single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies(
op_schema
)
# Add Context Parallelism strategy: shards on the sequence dim
return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else Replicate()
cp_strategy: PlacementList = [
Shard(SEQ_DIM), # output
Shard(SEQ_DIM), # logsumexp
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
Replicate(), # rng_state
None, # unused
debug_attn_mask_sharding, # debugattn
Shard(SEQ_DIM), # q
Shard(SEQ_DIM), # k
Shard(SEQ_DIM), # v
]
single_mesh_dim_strategies.append(cp_strategy)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=9
)
def _scaled_dot_product_flash_attention_backward_cp_strategy(
op_schema: OpSchema,
) -> OpStrategy:
"""
Strategy for flash attention backward with Context Parallelism support.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_scaled_dot_product_flash_attention_backward_base_strategies,
)
mesh = op_schema.get_mesh_from_args(validate=False)
single_mesh_dim_strategies = (
_scaled_dot_product_flash_attention_backward_base_strategies(op_schema)
)
tensor_input_indices = [
i
for i, arg_spec in enumerate(op_schema.args_schema)
if isinstance(arg_spec, OpStrategy)
]
num_tensor_inputs = len(tensor_input_indices)
# Context Parallelism: shards on the sequence dim
cp_strategy: PlacementList = [
Shard(SEQ_DIM), # grad_q
Shard(SEQ_DIM), # grad_k
Shard(SEQ_DIM), # grad_v
Shard(SEQ_DIM), # grad_output
Shard(SEQ_DIM), # q
Shard(SEQ_DIM), # k
Shard(SEQ_DIM), # v
Shard(SEQ_DIM), # output
Shard(SEQ_DIM), # logsumexp
]
cp_strategy.extend([Replicate()] * (num_tensor_inputs - 6))
single_mesh_dim_strategies.append(cp_strategy)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=3
)
# ==================== Efficient Attention Strategies ====================
def _scaled_dot_product_efficient_attention_cp_strategy(
op_schema: OpSchema,
) -> OpStrategy:
"""
Strategy for efficient attention forward with Context Parallelism support.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_scaled_dot_product_efficient_attention_base_strategies,
)
mesh = op_schema.get_mesh_from_args()
single_mesh_dim_strategies = (
_scaled_dot_product_efficient_attention_base_strategies(op_schema)
)
# Add Context Parallelism strategy
has_attn_bias = op_schema.args_schema[3] is not None
cp_strategy: PlacementList = [
Shard(SEQ_DIM), # output
Shard(SEQ_DIM), # logsumexp
None, # philox_seed
None, # philox_offset
Shard(SEQ_DIM), # q
Shard(SEQ_DIM), # k
Shard(SEQ_DIM), # v
]
if has_attn_bias:
cp_strategy.append(Replicate()) # attn bias - not sharded for CP
single_mesh_dim_strategies.append(cp_strategy)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=4
)
def _scaled_dot_product_efficient_attention_backward_cp_strategy(
op_schema: OpSchema,
) -> OpStrategy:
"""
Strategy for efficient attention backward with Context Parallelism support.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_scaled_dot_product_efficient_attention_backward_base_strategies,
)
mesh = op_schema.get_mesh_from_args(validate=False)
single_mesh_dim_strategies = (
_scaled_dot_product_efficient_attention_backward_base_strategies(op_schema)
)
has_attn_bias = op_schema.args_schema[4] is not None
# Context Parallelism: shards on the sequence dim
cp_strategy: PlacementList = [
Shard(SEQ_DIM), # grad_q
Shard(SEQ_DIM), # grad_k
Shard(SEQ_DIM), # grad_v
Shard(1) if has_attn_bias else None, # grad_bias
Shard(SEQ_DIM), # grad_output
Shard(SEQ_DIM), # q
Shard(SEQ_DIM), # k
Shard(SEQ_DIM), # v
Shard(SEQ_DIM), # output
Shard(SEQ_DIM), # logsumexp
]
if has_attn_bias:
cp_strategy.insert(8, Shard(1)) # attn_bias input
cp_strategy.extend([Replicate(), Replicate()])
single_mesh_dim_strategies.append(cp_strategy)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=4
)
# ==================== cuDNN Attention Strategies ====================
def _scaled_dot_product_cudnn_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy:
"""
Strategy for cudnn attention forward with Context Parallelism support.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_scaled_dot_product_cudnn_attention_base_strategies,
)
mesh = op_schema.get_mesh_from_args()
single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies(
op_schema
)
(
query_strategy,
_,
_,
attn_bias_strategy,
compute_log_sumexp,
*rest_args,
) = op_schema.args_schema
return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
has_attn_bias = attn_bias_strategy is not None
# Context Parallelism: shards on the sequence dim
logsumexp_sharding = Shard(SEQ_DIM) if compute_log_sumexp else Replicate()
debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else None
cp_strategy: PlacementList = [
Shard(SEQ_DIM), # output
logsumexp_sharding, # logsumexp
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
debug_attn_mask_sharding, # debug_attn_mask
Shard(SEQ_DIM), # q
Shard(SEQ_DIM), # k
Shard(SEQ_DIM), # v
]
if has_attn_bias:
cp_strategy.append(Replicate()) # attn_bias - not sharded for CP
single_mesh_dim_strategies.append(cp_strategy)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=9
)
def _scaled_dot_product_cudnn_attention_backward_cp_strategy(
op_schema: OpSchema,
) -> OpStrategy:
"""
Strategy for cudnn attention backward with Context Parallelism support.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_scaled_dot_product_cudnn_attention_backward_base_strategies,
)
mesh = op_schema.get_mesh_from_args(validate=False)
single_mesh_dim_strategies = (
_scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema)
)
has_attn_bias = op_schema.args_schema[8] is not None
has_scale = len(op_schema.args_schema) >= 16 and False
# Context Parallelism: shards on the sequence dim
cp_sharding_gout: PlacementList = [Shard(SEQ_DIM)] * 3 # grad_q, grad_k, grad_v
cp_sharding_ginp: PlacementList = [
Shard(SEQ_DIM)
] * 6 # grad_output, q, k, v, output, logsumexp
cp_sharding_ginp += [Replicate()] * 2 # philox_seed, philox_offset
cp_sharding_ginp += [Shard(SEQ_DIM) if has_attn_bias else None] # attn_bias
cp_sharding_ginp += [
None
] * 6 # cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal
if has_scale:
cp_sharding_ginp.append(None)
cp_sharding = cp_sharding_gout + cp_sharding_ginp
single_mesh_dim_strategies.append(cp_sharding)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=3
)
# Store context managers and original strategies
_cp_strategy_contexts = {}
_original_strategies = {}
def register_cp_sharding_rules():
"""Register Context Parallelism sharding rules for all scaled_dot_product ops."""
global _cp_strategy_contexts, _original_strategies
# If already registered, don't register again
if _cp_strategy_contexts:
return
# Define ops and their corresponding CP strategy functions
cp_strategies = [
(
aten._scaled_dot_product_flash_attention.default,
_scaled_dot_product_flash_attention_cp_strategy,
RuntimeSchemaInfo(5),
),
(
aten._scaled_dot_product_flash_attention_backward.default,
_scaled_dot_product_flash_attention_backward_cp_strategy,
None,
),
(
aten._scaled_dot_product_efficient_attention.default,
_scaled_dot_product_efficient_attention_cp_strategy,
RuntimeSchemaInfo(4),
),
(
aten._scaled_dot_product_efficient_attention_backward.default,
_scaled_dot_product_efficient_attention_backward_cp_strategy,
None,
),
(
aten._scaled_dot_product_cudnn_attention.default,
_scaled_dot_product_cudnn_attention_cp_strategy,
RuntimeSchemaInfo(4),
),
(
aten._scaled_dot_product_cudnn_attention_backward.default,
_scaled_dot_product_cudnn_attention_backward_cp_strategy,
None,
),
]
# Register each strategy
for op_overload, strategy_func, schema_info in cp_strategies:
ctx = _op_strategy_context(op_overload, strategy_func, schema_info)
orig_funcs, orig_schema = ctx.__enter__()
_cp_strategy_contexts[op_overload] = ctx
_original_strategies[op_overload] = (orig_funcs, orig_schema)
def unregister_cp_sharding_rules():
"""Unregister Context Parallelism sharding rules and restore original strategies."""
global _cp_strategy_contexts, _original_strategies
# Exit all context managers
for ctx in _cp_strategy_contexts.values():
ctx.__exit__(None, None, None)
_cp_strategy_contexts = {}
_original_strategies = {}

View File

@ -529,14 +529,11 @@ RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+
def replace_extern_shared(input_string):
"""
Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Examples:
"extern __shared__ char smemChar[];"
=> "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];"
=> "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Example:
"extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""
output_string = input_string
output_string = RE_EXTERN_SHARED.sub(
@ -1046,17 +1043,14 @@ RE_INCLUDE = re.compile(r"#include .*\n")
def extract_arguments(start, string):
"""
Return the list of arguments in the upcoming function parameter closure.
Example:
""" Return the list of arguments in the upcoming function parameter closure.
Example:
string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
arguments (output):
[
{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}
]
'[{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}]'
"""
arguments = []