mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-12 14:54:55 +08:00

Compare commits: documentat ... ciflow/ind

14 Commits

| SHA1 |
|---|
| 8c6c066024 |
| 87ae43a0d6 |
| 2cff58cc2d |
| 46bd412746 |
| 5d07795b28 |
| 306914e071 |
| fb6807bf86 |
| f6a79b2a4a |
| 2fcf41dd8e |
| 31ccd8f13e |
| 1209123500 |
| 168fc7cfc5 |
| 639276822f |
| 8cbe13ad31 |
@ -1308,8 +1308,319 @@ coverage_ignore_functions = [
|
||||
# torch.onnx.symbolic_opset7
|
||||
"max",
|
||||
"min",
|
||||
# torch.onnx.symbolic_opset8
|
||||
"addmm",
|
||||
"bmm",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"flatten",
|
||||
"full",
|
||||
"full_like",
|
||||
"gt",
|
||||
"lt",
|
||||
"matmul",
|
||||
"mm",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"prelu",
|
||||
"repeat",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.symbolic_opset9
|
||||
"abs",
|
||||
"acos",
|
||||
"adaptive_avg_pool1d",
|
||||
"adaptive_avg_pool2d",
|
||||
"adaptive_avg_pool3d",
|
||||
"adaptive_max_pool1d",
|
||||
"adaptive_max_pool2d",
|
||||
"adaptive_max_pool3d",
|
||||
"add",
|
||||
"addcmul",
|
||||
"addmm",
|
||||
"alias",
|
||||
"amax",
|
||||
"amin",
|
||||
"aminmax",
|
||||
"arange",
|
||||
"argmax",
|
||||
"argmin",
|
||||
"as_strided",
|
||||
"as_tensor",
|
||||
"asin",
|
||||
"atan",
|
||||
"atan2",
|
||||
"avg_pool1d",
|
||||
"avg_pool2d",
|
||||
"avg_pool3d",
|
||||
"baddbmm",
|
||||
"batch_norm",
|
||||
"bernoulli",
|
||||
"bitwise_not",
|
||||
"bitwise_or",
|
||||
"bmm",
|
||||
"broadcast_tensors",
|
||||
"broadcast_to",
|
||||
"bucketize",
|
||||
"cat",
|
||||
"cdist",
|
||||
"ceil",
|
||||
"clamp",
|
||||
"clamp_max",
|
||||
"clamp_min",
|
||||
"clone",
|
||||
"constant_pad_nd",
|
||||
"contiguous",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"conv_tbc",
|
||||
"conv_transpose1d",
|
||||
"conv_transpose2d",
|
||||
"conv_transpose3d",
|
||||
"convert_element_type",
|
||||
"convolution",
|
||||
"cos",
|
||||
"cosine_similarity",
|
||||
"cross",
|
||||
"cumsum",
|
||||
"detach",
|
||||
"dim",
|
||||
"div",
|
||||
"dot",
|
||||
"dropout",
|
||||
"elu",
|
||||
"embedding",
|
||||
"embedding_bag",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"eq",
|
||||
"erf",
|
||||
"exp",
|
||||
"expand",
|
||||
"expand_as",
|
||||
"eye",
|
||||
"fill",
|
||||
"flatten",
|
||||
"floor",
|
||||
"floor_divide",
|
||||
"floordiv",
|
||||
"frobenius_norm",
|
||||
"full",
|
||||
"full_like",
|
||||
"gather",
|
||||
"ge",
|
||||
"gelu",
|
||||
"get_pool_ceil_padding",
|
||||
"glu",
|
||||
"group_norm",
|
||||
"gru",
|
||||
"gt",
|
||||
"hann_window",
|
||||
"hardshrink",
|
||||
"hardsigmoid",
|
||||
"hardswish",
|
||||
"hardtanh",
|
||||
"index",
|
||||
"index_add",
|
||||
"index_copy",
|
||||
"index_fill",
|
||||
"index_put",
|
||||
"index_select",
|
||||
"instance_norm",
|
||||
"is_floating_point",
|
||||
"is_pinned",
|
||||
"isnan",
|
||||
"item",
|
||||
"kl_div",
|
||||
"layer_norm",
|
||||
"le",
|
||||
"leaky_relu",
|
||||
"lerp",
|
||||
"lift",
|
||||
"linalg_cross",
|
||||
"linalg_matrix_norm",
|
||||
"linalg_norm",
|
||||
"linalg_vector_norm",
|
||||
"linear",
|
||||
"linspace",
|
||||
"log",
|
||||
"log10",
|
||||
"log1p",
|
||||
"log2",
|
||||
"log_sigmoid",
|
||||
"log_softmax",
|
||||
"logical_and",
|
||||
"logical_not",
|
||||
"logical_or",
|
||||
"logical_xor",
|
||||
"logit",
|
||||
"logsumexp",
|
||||
"lstm",
|
||||
"lstm_cell",
|
||||
"lt",
|
||||
"masked_fill",
|
||||
"masked_fill_",
|
||||
"matmul",
|
||||
"max",
|
||||
"max_pool1d",
|
||||
"max_pool1d_with_indices",
|
||||
"max_pool2d",
|
||||
"max_pool2d_with_indices",
|
||||
"max_pool3d",
|
||||
"max_pool3d_with_indices",
|
||||
"maximum",
|
||||
"meshgrid",
|
||||
"min",
|
||||
"minimum",
|
||||
"mish",
|
||||
"mm",
|
||||
"movedim",
|
||||
"mse_loss",
|
||||
"mul",
|
||||
"multinomial",
|
||||
"mv",
|
||||
"narrow",
|
||||
"native_layer_norm",
|
||||
"ne",
|
||||
"neg",
|
||||
"new_empty",
|
||||
"new_full",
|
||||
"new_ones",
|
||||
"new_zeros",
|
||||
"nonzero",
|
||||
"nonzero_numpy",
|
||||
"noop_complex_operators",
|
||||
"norm",
|
||||
"numel",
|
||||
"numpy_T",
|
||||
"one_hot",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"onnx_placeholder",
|
||||
"overload_by_arg_count",
|
||||
"pad",
|
||||
"pairwise_distance",
|
||||
"permute",
|
||||
"pixel_shuffle",
|
||||
"pixel_unshuffle",
|
||||
"pow",
|
||||
"prelu",
|
||||
"prim_constant",
|
||||
"prim_constant_chunk",
|
||||
"prim_constant_split",
|
||||
"prim_data",
|
||||
"prim_device",
|
||||
"prim_dtype",
|
||||
"prim_if",
|
||||
"prim_layout",
|
||||
"prim_list_construct",
|
||||
"prim_list_unpack",
|
||||
"prim_loop",
|
||||
"prim_max",
|
||||
"prim_min",
|
||||
"prim_shape",
|
||||
"prim_tolist",
|
||||
"prim_tuple_construct",
|
||||
"prim_type",
|
||||
"prim_unchecked_cast",
|
||||
"prim_uninitialized",
|
||||
"rand",
|
||||
"rand_like",
|
||||
"randint",
|
||||
"randint_like",
|
||||
"randn",
|
||||
"randn_like",
|
||||
"reciprocal",
|
||||
"reflection_pad",
|
||||
"relu",
|
||||
"relu6",
|
||||
"remainder",
|
||||
"repeat",
|
||||
"repeat_interleave",
|
||||
"replication_pad",
|
||||
"reshape",
|
||||
"reshape_as",
|
||||
"rnn_relu",
|
||||
"rnn_tanh",
|
||||
"roll",
|
||||
"rrelu",
|
||||
"rsqrt",
|
||||
"rsub",
|
||||
"scalar_tensor",
|
||||
"scatter",
|
||||
"scatter_add",
|
||||
"select",
|
||||
"selu",
|
||||
"sigmoid",
|
||||
"sign",
|
||||
"silu",
|
||||
"sin",
|
||||
"size",
|
||||
"slice",
|
||||
"softmax",
|
||||
"softplus",
|
||||
"softshrink",
|
||||
"sort",
|
||||
"split",
|
||||
"split_with_sizes",
|
||||
"sqrt",
|
||||
"square",
|
||||
"squeeze",
|
||||
"stack",
|
||||
"std",
|
||||
"std_mean",
|
||||
"sub",
|
||||
"t",
|
||||
"take",
|
||||
"tan",
|
||||
"tanh",
|
||||
"tanhshrink",
|
||||
"tensor",
|
||||
"threshold",
|
||||
"to",
|
||||
"topk",
|
||||
"transpose",
|
||||
"true_divide",
|
||||
"type_as",
|
||||
"unbind",
|
||||
"unfold",
|
||||
"unsafe_chunk",
|
||||
"unsafe_split",
|
||||
"unsafe_split_with_sizes",
|
||||
"unsqueeze",
|
||||
"unsupported_complex_operators",
|
||||
"unused",
|
||||
"upsample_bilinear2d",
|
||||
"upsample_linear1d",
|
||||
"upsample_nearest1d",
|
||||
"upsample_nearest2d",
|
||||
"upsample_nearest3d",
|
||||
"upsample_trilinear3d",
|
||||
"var",
|
||||
"var_mean",
|
||||
"view",
|
||||
"view_as",
|
||||
"where",
|
||||
"wrap_logical_op_with_cast_to",
|
||||
"wrap_logical_op_with_negation",
|
||||
"zero",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.utils
|
||||
"disable_apex_o2_state_dict_hook",
|
||||
"export",
|
||||
"export_to_pretty_string",
|
||||
"exporter_context",
|
||||
"is_in_onnx_export",
|
||||
"model_signature",
|
||||
"register_custom_op_symbolic",
|
||||
"select_model_mode_for_export",
|
||||
"setup_onnx_logging",
|
||||
"unconvertible_ops",
|
||||
"unpack_quantized_tensor",
|
||||
"warn_on_static_input_change",
|
||||
# torch.onnx.verification
|
||||
"check_export_model_diff",
|
||||
"verify",
|
||||
"verify_aten_graph",
|
||||
@ -1400,6 +1711,32 @@ coverage_ignore_functions = [
|
||||
"noop_context_fn",
|
||||
"set_checkpoint_early_stop",
|
||||
"set_device_states",
|
||||
# torch.utils.collect_env
|
||||
"check_release_file",
|
||||
"get_cachingallocator_config",
|
||||
"get_clang_version",
|
||||
"get_cmake_version",
|
||||
"get_conda_packages",
|
||||
"get_cpu_info",
|
||||
"get_cuda_module_loading_config",
|
||||
"get_cudnn_version",
|
||||
"get_env_info",
|
||||
"get_gcc_version",
|
||||
"get_gpu_info",
|
||||
"get_libc_version",
|
||||
"get_lsb_version",
|
||||
"get_mac_version",
|
||||
"get_nvidia_driver_version",
|
||||
"get_nvidia_smi",
|
||||
"get_os",
|
||||
"get_pip_packages",
|
||||
"get_platform",
|
||||
"get_pretty_env_info",
|
||||
"get_python_platform",
|
||||
"get_running_cuda_version",
|
||||
"get_windows_version",
|
||||
"is_xnnpack_available",
|
||||
"pretty_str",
|
||||
# torch.utils.cpp_backtrace
|
||||
"get_cpp_backtrace",
|
||||
# torch.utils.cpp_extension
|
||||
@ -1463,6 +1800,52 @@ coverage_ignore_functions = [
|
||||
"apply_shuffle_seed",
|
||||
"apply_shuffle_settings",
|
||||
"get_all_graph_pipes",
|
||||
# torch.utils.flop_counter
|
||||
"addmm_flop",
|
||||
"baddbmm_flop",
|
||||
"bmm_flop",
|
||||
"conv_backward_flop",
|
||||
"conv_flop",
|
||||
"conv_flop_count",
|
||||
"convert_num_with_suffix",
|
||||
"get_shape",
|
||||
"get_suffix_str",
|
||||
"mm_flop",
|
||||
"normalize_tuple",
|
||||
"register_flop_formula",
|
||||
"sdpa_backward_flop",
|
||||
"sdpa_backward_flop_count",
|
||||
"sdpa_flop",
|
||||
"sdpa_flop_count",
|
||||
"shape_wrapper",
|
||||
"transpose_shape",
|
||||
# torch.utils.hipify.hipify_python
|
||||
"add_dim3",
|
||||
"compute_stats",
|
||||
"extract_arguments",
|
||||
"file_add_header",
|
||||
"file_specific_replacement",
|
||||
"find_bracket_group",
|
||||
"find_closure_group",
|
||||
"find_parentheses_group",
|
||||
"fix_static_global_kernels",
|
||||
"get_hip_file_path",
|
||||
"hip_header_magic",
|
||||
"hipify",
|
||||
"is_caffe2_gpu_file",
|
||||
"is_cusparse_file",
|
||||
"is_out_of_place",
|
||||
"is_pytorch_file",
|
||||
"is_special_file",
|
||||
"match_extensions",
|
||||
"matched_files_iter",
|
||||
"openf",
|
||||
"preprocess_file_and_save_result",
|
||||
"preprocessor",
|
||||
"processKernelLaunches",
|
||||
"replace_extern_shared",
|
||||
"replace_math_functions",
|
||||
"str2bool",
|
||||
# torch.utils.hooks
|
||||
"unserializable_hook",
|
||||
"warn_if_has_hooks",
|
||||
|
||||
@@ -19,91 +19,6 @@
    swap_tensors
```

# torch.utils.collect_env
```{eval-rst}
.. automodule:: torch.utils.collect_env
```

```{eval-rst}
.. currentmodule:: torch.utils.collect_env
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    check_release_file
    is_xnnpack_available
    pretty_str
```

# torch.utils.flop_counter
```{eval-rst}
.. automodule:: torch.utils.flop_counter
```

```{eval-rst}
.. currentmodule:: torch.utils.flop_counter
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    baddbmm_flop
    bmm_flop
    conv_backward_flop
    conv_flop
    conv_flop_count
    register_flop_formula
    sdpa_backward_flop
    sdpa_backward_flop_count
    sdpa_flop
    sdpa_flop_count
    shape_wrapper
```

# torch.utils.hipify.hipify_python
```{eval-rst}
.. automodule:: torch.utils.hipify.hipify_python
```

```{eval-rst}
.. currentmodule:: torch.utils.hipify.hipify_python
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    compute_stats
    extract_arguments
    file_add_header
    file_specific_replacement
    find_bracket_group
    find_closure_group
    find_parentheses_group
    fix_static_global_kernels
    hip_header_magic
    hipify
    is_caffe2_gpu_file
    is_cusparse_file
    is_out_of_place
    is_pytorch_file
    is_special_file
    openf
    preprocess_file_and_save_result
    preprocessor
    processKernelLaunches
    replace_extern_shared
    replace_math_functions
    str2bool
```


<!-- This module needs to be documented. Adding here in the meantime
for tracking purposes -->
```{eval-rst}
@@ -128,6 +43,7 @@ for tracking purposes -->
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
@@ -164,8 +80,10 @@ for tracking purposes -->
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract
@@ -260,6 +260,7 @@ select = [
    "TRY401", # verbose-log-message
    "UP",
    "YTT",
    "S101",
]

[tool.ruff.lint.pyupgrade]
@@ -339,6 +340,39 @@ keep-runtime-typing = true
"tools/linter/**" = [
    "LOG015" # please fix
]
"benchmarks/**" = [
    "S101"
]
"test/**" = [
    "S101"
]
"torchgen/**" = [
    "S101"
]
"torch/**" = [
    "S101"
]
"tools/**" = [
    "S101"
]
"setup.py" = [
    "S101"
]
"functorch/**" = [
    "S101"
]
"docs/**" = [
    "S101"
]
"android/**" = [
    "S101"
]
".github/**" = [
    "S101"
]
".ci/**" = [
    "S101"
]

[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
@@ -813,6 +813,50 @@ class TestSharding(DTensorTestBase):
            ),
        )

    @skip_if_lt_x_gpu(2)
    @with_comms
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Does not support flash nor efficient attention",
    )
    def test_attention_shard_without_cp(self) -> None:
        """Test that sharding on the sequence dimension without CP enabled is not supported."""
        from torch.distributed.tensor import distribute_tensor, Replicate, Shard

        B = 2
        nheads = 4
        seq_len = 256
        dim = 32

        device_mesh = init_device_mesh(
            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
        )

        # Create q, k, v tensors with shape (B, nheads, seq_len, dim)
        q = torch.randn(
            B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
        )
        k = torch.randn(
            B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
        )
        v = torch.randn(
            B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
        )
        q_dt = distribute_tensor(q, device_mesh, [Shard(2)])
        k_dt = distribute_tensor(k, device_mesh, [Shard(2)])
        v_dt = distribute_tensor(v, device_mesh, [Shard(2)])

        # Run SDPA with sequence-sharded tensors WITHOUT enabling CP
        # Without CP enabled, DTensor should select a different strategy
        # (not sequence-sharded) because the Shard(2) strategy is only available with CP
        out = F.scaled_dot_product_attention(q_dt, k_dt, v_dt)

        # Verify the output is NOT sharded on the sequence dimension (dim 2)
        # This proves that CP sharding rules were not used
        self.assertNotEqual(out.placements[0], Shard(2))
        # The output should be replicated or sharded on the batch or head dimension.
        self.assertIn(out.placements[0], [Replicate(), Shard(0), Shard(1)])


RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
    RingAttentionTest,
@@ -7522,6 +7522,38 @@ class AOTInductorTestsTemplate:
        eager_outputs = model(*example_inputs)
        torch.testing.assert_close(eager_outputs, compiled_outputs)

    @requires_gpu
    def test_mixed_device_1(self):
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("Mixed-device test requires GPU")

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Buffers are on CPU
                self.register_buffer(
                    "index", torch.tensor([1, 4, 1, 7], device="cpu", dtype=torch.int64)
                )
                self.register_buffer(
                    "src", torch.ones(4, device="cpu", dtype=torch.int64)
                )

            def forward(self, matrix, vector):
                # Inputs are on CUDA
                # 1. Operation on CPU tensors
                z = torch.zeros((vector.shape[0],), device="cpu", dtype=torch.int64)
                scatter_result = z.scatter_add(0, self.index, self.src)

                # 2. Move result to CUDA and continue on CUDA
                v = vector + scatter_result.to(vector.dtype).to(GPU_TYPE)
                return torch.matmul(matrix, v)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, device=self.device),
        )
        self.check_model(Model(), example_inputs, move_model_to_device=False)


class AOTInductorLoggingTest(LoggingTestCase):
    @make_logging_test(dynamic=logging.DEBUG)
@@ -218,6 +218,7 @@ def check_model(
    dynamic_shapes=None,
    atol=None,
    rtol=None,
    move_model_to_device=True,
):
    with (
        torch.no_grad(),
@@ -229,7 +230,7 @@ def check_model(
        ),
    ):
        torch.manual_seed(0)
        if not isinstance(model, types.FunctionType):
        if not isinstance(model, types.FunctionType) and move_model_to_device:
            model = model.to(self.device)

        # For non-mixed-device inputs with the default "cpu", set the device manually.
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: pt2"]
import functools
import re
import sys
import unittest

@@ -230,6 +231,33 @@ class PallasTestsMixin:
        self.assertIn("import jax.numpy as jnp", code)
        self.assertIn("from jax.experimental import pallas as pl", code)

    def test_jax_jit_wrapper_is_emitted(self):
        """Ensure generated Pallas code wraps pl.pallas_call in jax.jit."""

        key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"

        @torch.compile(backend="inductor", options={key: "pallas"})
        def pallas_fn(a, b):
            return a + b

        _, (code,) = run_and_get_code(
            pallas_fn,
            torch.randn(32, device=self.DEVICE),
            torch.randn(32, device=self.DEVICE),
        )

        kernel_match = re.search(r"def (pallas_[A-Za-z0-9_]+)_kernel", code)
        self.assertIsNotNone(kernel_match)
        kernel_name = kernel_match.group(1)
        wrapper_name = f"{kernel_name}_jit_wrapper"
        self.assertIn(wrapper_name, code)
        start = code.index(f"def {wrapper_name}")
        end = code.index(f"def {kernel_name}_main", start)
        wrapper_block = code[start:end]

        self.assertIn("jax.jit", code)
        self.assertNotIn("torch.", wrapper_block)

    def test_2d_tensor(self):
        """Test with 2D tensors (though current implementation flattens)."""
@@ -221,7 +221,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
            """
        )

        self.add_device_include(self.device)
        for device in V.graph.device_types:
            if device != "meta":
                self.add_device_include(device)

        if V.graph.aot_mode:
            if config.aot_inductor.dynamic_linkage:
@@ -1423,11 +1425,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
        src_is_tensor,
        reduce,
        kwargs,
        device,
    ):
        reduce = self._get_scatter_reduce_enum(reduce)

        # call the ABI shim function instead of the ATen one
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
        self.add_device_include(device)
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
        # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
        cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
        inputs_wrapped = [str(x) for x in inputs]

@@ -708,11 +708,14 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
        src_is_tensor,
        reduce,
        kwargs,
        device,
    ):
        reduce = self._get_scatter_reduce_enum(reduce)

        # call the ABI shim function instead of the ATen one
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
        self.add_device_include(device)
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)

        # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
        cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
        self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
@@ -287,6 +287,7 @@ class PallasKernel(SIMDKernel):
        code = IndentedBuffer()
        code.splice(
            """
            import functools
            import torch
            import jax
            import jax.numpy as jnp
@@ -301,6 +302,9 @@ class PallasKernel(SIMDKernel):
        kernel_params = [a.name for a in arg_defs]

        kernel_name = name or "<KERNEL_NAME>"
        interpret_literal = (
            "True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
        )
        code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
        with code.indent():
            # Emit compute (CSE) and store lines; they reference *_ptr[...] directly
@@ -309,16 +313,22 @@ class PallasKernel(SIMDKernel):
            for line in self.stores._lines:
                code.writeline(str(line))

        jit_wrapper_name = f"{kernel_name}_jit_wrapper"
        code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))")
        code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):")
        with code.indent():
            code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)")
            code.writeline("return pl.pallas_call(")
            code.writeline(f" {kernel_name}_kernel,")
            code.writeline(" out_shape=out_spec,")
            code.writeline(f" interpret={interpret_literal},")
            code.writeline(" grid=(1,),")
            code.writeline(")(*kernel_refs)")

        # Host entry: convert torch tensors <-> jax, call pallas_call and copy back
        main_name = f"{kernel_name}_main"
        code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
        with code.indent():
            # Determine interpret statically based on codegen device
            interpret_literal = (
                "True"
                if V.graph.get_current_device_or_throw().type == "cpu"
                else "False"
            )
            # Identify inputs (in_ptr*) and output (out_ptr*)
            input_params = [
                p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@@ -337,9 +347,9 @@ class PallasKernel(SIMDKernel):
            for inp in input_params:
                code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")

            # Get output spec from PyTorch tensor
            code.writeline("# Prepare output spec from PyTorch tensor")
            code.writeline("# Map PyTorch dtype to JAX dtype string")
            # Get output metadata from PyTorch tensor
            code.writeline("# Prepare output metadata from PyTorch tensor")
            code.writeline("# Map PyTorch dtype to JAX dtype")
            code.writeline("_torch_dtype_to_jax = {")
            code.writeline(
                " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"
@@ -349,21 +359,14 @@ class PallasKernel(SIMDKernel):
            )
            code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
            code.writeline("}")
            code.writeline(
                f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])"
            )
            code.writeline(f"out_shape = tuple({output_param}.shape)")
            code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")

            # Call pallas
            # Pass interpret=True on CPU, False otherwise (single call, no duplication)
            code.writeline("compiled = pl.pallas_call(")
            code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),")
            code.writeline(" out_shape=out_spec,")
            code.writeline(f" interpret={interpret_literal},")
            code.writeline(" grid=(1,),")
            code.writeline(")")

            jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params])
            code.writeline(f"res = compiled({jax_input_args})")
            call_args = ["out_shape", "out_dtype"] + [
                f"{inp}_jax" for inp in input_params
            ]
            call_arg_str = ", ".join(call_args)
            code.writeline(f"res = {jit_wrapper_name}({call_arg_str})")

            # Copy result back
            code.writeline("# Copy result back into the provided torch output tensor")
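To make the emitted structure above easier to follow, here is a rough, hand-written sketch of what these writelines would produce for a hypothetical elementwise kernel named `pallas_add` compiled for CPU (`interpret=True`); the kernel name and body are illustrative, not taken from the diff.

```python
# Hypothetical sketch of the generated Pallas code (names are illustrative).
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def pallas_add_kernel(in_ptr0, in_ptr1, out_ptr0):
    # In real generated code this body comes from the compute/store codegen above.
    out_ptr0[...] = in_ptr0[...] + in_ptr1[...]


@functools.partial(jax.jit, static_argnums=(0, 1))
def pallas_add_jit_wrapper(out_shape, out_dtype, *kernel_refs):
    # out_shape / out_dtype are static arguments so jit can specialize on them.
    out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)
    return pl.pallas_call(
        pallas_add_kernel,
        out_shape=out_spec,
        interpret=True,  # emitted as False when codegen targets GPU
        grid=(1,),
    )(*kernel_refs)


# The *_main host entry converts torch tensors via dlpack and then calls:
#   res = pallas_add_jit_wrapper(out_shape, out_dtype, in_ptr0_jax, in_ptr1_jax)
```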
@@ -971,6 +971,7 @@ class ScatterFallbackLine(WrapperLine):
        else:
            (x, index) = (t.codegen_reference() for t in node.inputs)
            src = node.constant_args[1]
        device = d.type if (d := node.get_device()) else V.graph.device_type
        self.wrapper._generate_scatter_fallback(
            x,
            [x, node.constant_args[0], index, src],
@@ -979,6 +980,7 @@ class ScatterFallbackLine(WrapperLine):
            node.src_is_tensor,
            node.kwargs["reduce"],
            node.codegen_kwargs(),
            device,
        )

    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
@@ -1632,6 +1634,7 @@ class PythonWrapperCodegen(CodeGen):
        src_is_tensor,
        reduce,
        kwargs,
        device,
    ):
        line = f"{python_kernel_name}({','.join(map(str, inputs))}"
        if python_kernel_name.startswith("aten.scatter_reduce"):
@ -267,16 +267,10 @@ def scaled_mm_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
return _scaled_mm_like_strategy("mk,kn->mn", mesh, op_schema)
|
||||
|
||||
|
||||
@register_op_strategy(
|
||||
aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
|
||||
)
|
||||
def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
# NOTE: currently we only support some simple strategies to support tensor parallelism
|
||||
# TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
|
||||
# as it involves: matmul, pointwise, reduction ops together.
|
||||
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
|
||||
def _scaled_dot_product_flash_attention_base_strategies(
|
||||
op_schema: OpSchema,
|
||||
) -> list[PlacementList]:
|
||||
"""Helper that returns list of base placement strategies (without CP)."""
|
||||
return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
|
||||
q_input_strategy = op_schema.args_schema[0]
|
||||
if not isinstance(q_input_strategy, OpStrategy):
|
||||
@ -349,37 +343,30 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate
|
||||
Shard(0), # v
|
||||
]
|
||||
)
|
||||
return single_mesh_dim_strategies
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
debug_attn_mask_sharding = Shard(2) if return_debug_mask else Replicate()
|
||||
single_mesh_dim_strategies.append(
|
||||
[
|
||||
Shard(2), # output
|
||||
Shard(2), # logsumexp
|
||||
None, # cum_seq_q
|
||||
None, # cum_seq_k
|
||||
None, # max_q
|
||||
None, # max_k
|
||||
Replicate(), # rng_state
|
||||
None, # unused
|
||||
debug_attn_mask_sharding, # debugattn
|
||||
Shard(2), # q
|
||||
Shard(2), # k
|
||||
Shard(2), # v
|
||||
]
|
||||
|
||||
@register_op_strategy(
|
||||
aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
|
||||
)
|
||||
def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
# NOTE: currently we only support some simple strategies to support tensor parallelism
|
||||
# TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
|
||||
# as it involves: matmul, pointwise, reduction ops together.
|
||||
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies(
|
||||
op_schema
|
||||
)
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=9
|
||||
)
|
||||
|
||||
|
||||
@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
|
||||
def scaled_dot_product_flash_attention_backward_strategy(
|
||||
def _scaled_dot_product_flash_attention_backward_base_strategies(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
# backward op does not need to validate the mesh since forward op has already done it
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
|
||||
) -> list[PlacementList]:
|
||||
"""Helper that returns list of base placement strategies (without CP)."""
|
||||
q_input_strategy = op_schema.args_schema[1]
|
||||
if not isinstance(q_input_strategy, OpStrategy):
|
||||
raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
|
||||
@ -444,24 +431,18 @@ def scaled_dot_product_flash_attention_backward_strategy(
|
||||
batch_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
|
||||
single_mesh_dim_strategies.append(batch_dim_sharding)
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
seq_dim_sharding: PlacementList = [
|
||||
Shard(2), # grad_q
|
||||
Shard(2), # grad_k
|
||||
Shard(2), # grad_v
|
||||
Shard(2), # grad_output
|
||||
Shard(2), # q
|
||||
Shard(2), # k
|
||||
Shard(2), # v
|
||||
Shard(2), # output
|
||||
Shard(2), # logsumexp
|
||||
]
|
||||
# accept replicate on the rest tensor inputs, potentially
|
||||
# cum_seq_q, cum_seq_k, philox_seed, philox_offset
|
||||
# at indices 6, 7, 12, 13, respectively
|
||||
seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
|
||||
single_mesh_dim_strategies.append(seq_dim_sharding)
|
||||
return single_mesh_dim_strategies
|
||||
|
||||
|
||||
@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
|
||||
def scaled_dot_product_flash_attention_backward_strategy(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
# backward op does not need to validate the mesh since forward op has already done it
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_flash_attention_backward_base_strategies(op_schema)
|
||||
)
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=3
|
||||
)
|
||||
@ -486,13 +467,10 @@ def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
)
|
||||
|
||||
|
||||
@register_op_strategy(
|
||||
aten._scaled_dot_product_efficient_attention.default,
|
||||
schema_info=RuntimeSchemaInfo(4),
|
||||
)
|
||||
def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
# NOTE: currently we only support some simple strategies to support tensor parallelism
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
def _scaled_dot_product_efficient_attention_base_strategies(
|
||||
op_schema: OpSchema,
|
||||
) -> list[PlacementList]:
|
||||
"""Helper that returns list of base placement strategies (without CP)."""
|
||||
q_input_strategy = op_schema.args_schema[0]
|
||||
if not isinstance(q_input_strategy, OpStrategy):
|
||||
raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
|
||||
@ -518,19 +496,6 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
|
||||
if has_attn_bias:
|
||||
all_replicate.append(Replicate()) # attn bias
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
single_mesh_dim_strategies.append(
|
||||
[
|
||||
Shard(2), # output
|
||||
Shard(2), # logsumexp
|
||||
None, # philox_seed
|
||||
None, # philox_offset
|
||||
Shard(2), # q
|
||||
Shard(2), # k
|
||||
Shard(2), # v
|
||||
]
|
||||
)
|
||||
|
||||
single_mesh_dim_strategies.append(all_replicate)
|
||||
|
||||
# second we can accept the sharding pattern of tensor parallelism, which
|
||||
@ -576,6 +541,19 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
|
||||
|
||||
single_mesh_dim_strategies.append(batch_sharding)
|
||||
|
||||
return single_mesh_dim_strategies
|
||||
|
||||
|
||||
@register_op_strategy(
|
||||
aten._scaled_dot_product_efficient_attention.default,
|
||||
schema_info=RuntimeSchemaInfo(4),
|
||||
)
|
||||
def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
# NOTE: currently we only support some simple strategies to support tensor parallelism
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_efficient_attention_base_strategies(op_schema)
|
||||
)
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh,
|
||||
op_schema,
|
||||
@ -584,13 +562,10 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
|
||||
)
|
||||
|
||||
|
||||
@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
|
||||
def scaled_dot_product_efficient_attention_backward_strategy(
|
||||
def _scaled_dot_product_efficient_attention_backward_base_strategies(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
# backward op does not need to validate the mesh since forward op has already done it
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
|
||||
) -> list[PlacementList]:
|
||||
"""Helper that returns list of base placement strategies (without CP)."""
|
||||
q_input_strategy = op_schema.args_schema[1]
|
||||
if not isinstance(q_input_strategy, OpStrategy):
|
||||
raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")
|
||||
@ -662,27 +637,18 @@ def scaled_dot_product_efficient_attention_backward_strategy(
|
||||
batch_dim_sharding.extend([Replicate(), Replicate()])
|
||||
single_mesh_dim_strategies.append(batch_dim_sharding)
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
seq_dim_sharding: PlacementList = [
|
||||
Shard(2), # grad_q
|
||||
Shard(2), # grad_k
|
||||
Shard(2), # grad_v
|
||||
Shard(1) if has_attn_bias else None, # grad_bias
|
||||
Shard(2), # grad_output
|
||||
Shard(2), # q
|
||||
Shard(2), # k
|
||||
Shard(2), # v
|
||||
Shard(2), # output
|
||||
Shard(2), # logsumexp
|
||||
]
|
||||
# accept replicate on the rest tensor inputs, potentially
|
||||
# cum_seq_q, cum_seq_k, philox_seed, philox_offset
|
||||
# at indices 6, 7, 12, 13, respectively
|
||||
if has_attn_bias:
|
||||
num_heads_dim_sharding.insert(8, Shard(1))
|
||||
seq_dim_sharding.extend([Replicate(), Replicate()])
|
||||
single_mesh_dim_strategies.append(seq_dim_sharding)
|
||||
return single_mesh_dim_strategies
|
||||
|
||||
|
||||
@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
|
||||
def scaled_dot_product_efficient_attention_backward_strategy(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
# backward op does not need to validate the mesh since forward op has already done it
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_efficient_attention_backward_base_strategies(op_schema)
|
||||
)
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh,
|
||||
op_schema,
|
||||
@ -691,13 +657,10 @@ def scaled_dot_product_efficient_attention_backward_strategy(
|
||||
)
|
||||
|
||||
|
||||
@register_op_strategy(
|
||||
aten._scaled_dot_product_cudnn_attention.default,
|
||||
schema_info=RuntimeSchemaInfo(4),
|
||||
)
|
||||
def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
|
||||
def _scaled_dot_product_cudnn_attention_base_strategies(
|
||||
op_schema: OpSchema,
|
||||
) -> list[PlacementList]:
|
||||
"""Helper that returns list of base placement strategies (without CP)."""
|
||||
(
|
||||
query_strategy, # query
|
||||
_, # key
|
||||
@ -785,39 +748,27 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate
|
||||
]
|
||||
single_mesh_dim_strategies.append(batch_dim_sharding)
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
cp_sharding = Shard(2) # seq dim
|
||||
logsumexp_sharding = cp_sharding if compute_log_sumexp else Replicate()
|
||||
debug_attn_mask_sharding = cp_sharding if return_debug_mask else None
|
||||
return single_mesh_dim_strategies
|
||||
|
||||
single_mesh_dim_strategies.append(
|
||||
[
|
||||
cp_sharding, # output
|
||||
logsumexp_sharding, # logsumexp
|
||||
None, # cum_seq_q
|
||||
None, # cum_seq_k
|
||||
None, # max_q
|
||||
None, # max_k
|
||||
None, # philox_seed
|
||||
None, # philox_offset
|
||||
debug_attn_mask_sharding, # debug_attn_mask
|
||||
cp_sharding, # q
|
||||
cp_sharding, # k
|
||||
cp_sharding, # v
|
||||
]
|
||||
|
||||
@register_op_strategy(
|
||||
aten._scaled_dot_product_cudnn_attention.default,
|
||||
schema_info=RuntimeSchemaInfo(4),
|
||||
)
|
||||
def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies(
|
||||
op_schema
|
||||
)
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=9
|
||||
)
|
||||
|
||||
|
||||
@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
|
||||
def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
|
||||
def _scaled_dot_product_cudnn_attention_backward_base_strategies(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
# backward op does not need to validate the mesh since forward op has already done it
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
|
||||
) -> list[PlacementList]:
|
||||
"""Helper that returns list of base placement strategies (without CP)."""
|
||||
if len(op_schema.args_schema) < 15:
|
||||
raise AssertionError(
|
||||
f"Expected at least 15 args_schema, got {len(op_schema.args_schema)}"
|
||||
@ -892,23 +843,7 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
|
||||
num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp
|
||||
single_mesh_dim_strategies.append(num_heads_dim_sharding)
|
||||
|
||||
# case 3: Context Parallelism which shards on the sequence dim
|
||||
context_parallel_sharding_out: PlacementList = [Shard(2)] * 3
|
||||
context_parallel_sharding_inp: PlacementList = [Shard(2)] * 6
|
||||
context_parallel_sharding_inp += [
|
||||
Replicate()
|
||||
] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor
|
||||
context_parallel_sharding_inp += [Shard(2) if has_attn_bias else None]
|
||||
context_parallel_sharding_inp += [None] * 6
|
||||
if has_scale:
|
||||
context_parallel_sharding_inp.append(None)
|
||||
|
||||
context_parallel_sharding = (
|
||||
context_parallel_sharding_out + context_parallel_sharding_inp
|
||||
)
|
||||
single_mesh_dim_strategies.append(context_parallel_sharding)
|
||||
|
||||
# case 4: we can accept the sharding pattern of batch parallelism, which
|
||||
# case 3: we can accept the sharding pattern of batch parallelism, which
|
||||
# shards on the batch dimension
|
||||
qkv_sharding = Shard(0)
|
||||
output_sharding = Shard(0)
|
||||
@ -929,6 +864,18 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
|
||||
batch_dim_sharding = batch_dim_sharding_out + batch_dim_sharding_inp
|
||||
single_mesh_dim_strategies.append(batch_dim_sharding)
|
||||
|
||||
return single_mesh_dim_strategies
|
||||
|
||||
|
||||
@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
|
||||
def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
# backward op does not need to validate the mesh since forward op has already done it
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema)
|
||||
)
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=3
|
||||
)
|
||||
|
||||
@@ -989,16 +989,31 @@ def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:

def _enable_cp_dtensor_dispatcher() -> None:
    """Enables the DTensor dispatcher to dispatch SDPA to CP."""
    # Enable custom op handlers for CP
    DTensor._op_dispatcher._custom_op_handlers = {
        **exitsing_custom_ops,
        **custom_ops,
    }
    # Register CP-specific sharding rules
    from ._sharding_rules import register_cp_sharding_rules

    register_cp_sharding_rules()


def _disable_cp_dtensor_dispatcher() -> None:
    """Disables the DTensor dispatcher to dispatch SDPA to CP."""
    # Restore original custom op handlers
    DTensor._op_dispatcher._custom_op_handlers = exitsing_custom_ops

    # TODO: unregister_cp_sharding_rules() would invalidate the entire DTensor
    # sharding propagation cache. Selectively invalidating the lru_cache is not
    # easy without rewriting the sharding propagation wrapper, so the
    # unregister_cp_sharding_rules() call is disabled for now.

    # from ._sharding_rules import unregister_cp_sharding_rules
    # unregister_cp_sharding_rules()


def _enable_context_parallel_dispatcher_impl(seq_dim: int, mesh: DeviceMesh) -> None:
    sdpa_cp = _ContextParallel(
@@ -1032,9 +1047,7 @@ def _disable_context_parallel_dispatcher_impl() -> None:
    _disable_cp_dtensor_dispatcher()


_compiled_create_block_mask = torch.compile(
    create_block_mask, dynamic=False, fullgraph=True
)
_compiled_create_block_mask = None


def _context_parallel_buffers(
@@ -1187,9 +1200,12 @@ def _create_cp_block_mask(
        f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. "
    )

    compiled_create_block_mask = torch.compile(
        create_block_mask, dynamic=False, fullgraph=True
    )
    global _compiled_create_block_mask
    if _compiled_create_block_mask is None:
        _compiled_create_block_mask = torch.compile(
            create_block_mask, dynamic=False, fullgraph=True
        )
    compiled_create_block_mask = _compiled_create_block_mask

    def _rewrite_mask_mod(
        mask_mod: _mask_mod_signature,
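For orientation, a minimal sketch of how these private hooks pair up is shown below; the import path is an assumption based on the surrounding module, and in practice the public context-parallel API is what toggles them rather than user code.

```python
# Minimal sketch under assumptions: the helpers live in the experimental
# attention module being diffed above, and are normally driven by the public
# CP API rather than called directly.
from torch.distributed.tensor.experimental._attention import (  # assumed path
    _disable_cp_dtensor_dispatcher,
    _enable_cp_dtensor_dispatcher,
)

_enable_cp_dtensor_dispatcher()  # installs CP op handlers + Shard(2) sharding rules
try:
    pass  # run sequence-sharded DTensor SDPA here
finally:
    # Restores the original handlers; the sharding rules stay registered,
    # as noted in the TODO in the diff above.
    _disable_cp_dtensor_dispatcher()
```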
@ -0,0 +1,399 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
"""
|
||||
Context Parallelism sharding rules for scaled_dot_product attention operators.
|
||||
|
||||
The sharding rules for CP cannot be embedded by default because Shard(2) is not
|
||||
a valid sharding for SDPA without CP enabled. This module provides utilities to
|
||||
dynamically install Shard(2) sharding rules when CP is activated.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
from torch.distributed.tensor._op_schema import (
|
||||
OpSchema,
|
||||
OpStrategy,
|
||||
PlacementList,
|
||||
RuntimeSchemaInfo,
|
||||
)
|
||||
from torch.distributed.tensor._ops.utils import (
|
||||
expand_to_full_mesh_op_strategy,
|
||||
register_op_strategy,
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import Replicate, Shard
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
SEQ_DIM = 2
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _op_strategy_context(op_overload, strategy_func, schema_info=None):
|
||||
"""
|
||||
Context manager for setting and clearing op strategies for Context Parallelism.
|
||||
|
||||
Args:
|
||||
op_overload: The operator overload to set or clear the strategy for.
|
||||
strategy_func: The strategy function to set for the operator overload.
|
||||
schema_info: Optional schema information for the operator overload.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
propagator = DTensor._op_dispatcher.sharding_propagator
|
||||
_origin_op_strategy_funcs = None
|
||||
_origin_op_strategy_schema = None
|
||||
try:
|
||||
# Save original strategy if exists
|
||||
if op_overload in propagator.op_strategy_funcs:
|
||||
_origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
|
||||
if op_overload in propagator.op_to_schema_info:
|
||||
_origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
|
||||
|
||||
# Register the new op strategy
|
||||
register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
|
||||
yield (_origin_op_strategy_funcs, _origin_op_strategy_schema)
|
||||
finally:
|
||||
# Restore original strategy
|
||||
if _origin_op_strategy_funcs is None:
|
||||
if op_overload in propagator.op_strategy_funcs:
|
||||
del propagator.op_strategy_funcs[op_overload]
|
||||
else:
|
||||
propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
|
||||
|
||||
if _origin_op_strategy_schema is None:
|
||||
if op_overload in propagator.op_to_schema_info:
|
||||
del propagator.op_to_schema_info[op_overload]
|
||||
else:
|
||||
propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
|
||||
|
||||
# Clear cache
|
||||
propagator.propagate_op_sharding.cache.cache_clear()
|
||||
|
||||
|
||||
# ==================== Flash Attention Strategies ====================
|
||||
|
||||
|
||||
def _scaled_dot_product_flash_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
"""
|
||||
Strategy for flash attention forward with Context Parallelism support.
|
||||
This includes the base strategies plus CP-specific sequence dimension sharding.
|
||||
"""
|
||||
# Import here to avoid circular dependency
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_scaled_dot_product_flash_attention_base_strategies,
|
||||
)
|
||||
|
||||
# Get the base strategies (without CP modifications)
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies(
|
||||
op_schema
|
||||
)
|
||||
|
||||
# Add Context Parallelism strategy: shards on the sequence dim
|
||||
return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
|
||||
debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else Replicate()
|
||||
|
||||
cp_strategy: PlacementList = [
|
||||
Shard(SEQ_DIM), # output
|
||||
Shard(SEQ_DIM), # logsumexp
|
||||
None, # cum_seq_q
|
||||
None, # cum_seq_k
|
||||
None, # max_q
|
||||
None, # max_k
|
||||
Replicate(), # rng_state
|
||||
None, # unused
|
||||
debug_attn_mask_sharding, # debugattn
|
||||
Shard(SEQ_DIM), # q
|
||||
Shard(SEQ_DIM), # k
|
||||
Shard(SEQ_DIM), # v
|
||||
]
|
||||
single_mesh_dim_strategies.append(cp_strategy)
|
||||
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=9
|
||||
)
|
||||
|
||||
|
||||
def _scaled_dot_product_flash_attention_backward_cp_strategy(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
"""
|
||||
Strategy for flash attention backward with Context Parallelism support.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_scaled_dot_product_flash_attention_backward_base_strategies,
|
||||
)
|
||||
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_flash_attention_backward_base_strategies(op_schema)
|
||||
)
|
||||
|
||||
tensor_input_indices = [
|
||||
i
|
||||
for i, arg_spec in enumerate(op_schema.args_schema)
|
||||
if isinstance(arg_spec, OpStrategy)
|
||||
]
|
||||
num_tensor_inputs = len(tensor_input_indices)
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
cp_strategy: PlacementList = [
|
||||
Shard(SEQ_DIM), # grad_q
|
||||
Shard(SEQ_DIM), # grad_k
|
||||
Shard(SEQ_DIM), # grad_v
|
||||
Shard(SEQ_DIM), # grad_output
|
||||
Shard(SEQ_DIM), # q
|
||||
Shard(SEQ_DIM), # k
|
||||
Shard(SEQ_DIM), # v
|
||||
Shard(SEQ_DIM), # output
|
||||
Shard(SEQ_DIM), # logsumexp
|
||||
]
|
||||
cp_strategy.extend([Replicate()] * (num_tensor_inputs - 6))
|
||||
single_mesh_dim_strategies.append(cp_strategy)
|
||||
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=3
|
||||
)
|
||||
|
||||
|
||||
# ==================== Efficient Attention Strategies ====================
|
||||
|
||||
|
||||
def _scaled_dot_product_efficient_attention_cp_strategy(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
"""
|
||||
Strategy for efficient attention forward with Context Parallelism support.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_scaled_dot_product_efficient_attention_base_strategies,
|
||||
)
|
||||
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_efficient_attention_base_strategies(op_schema)
|
||||
)
|
||||
|
||||
# Add Context Parallelism strategy
|
||||
has_attn_bias = op_schema.args_schema[3] is not None
|
||||
|
||||
cp_strategy: PlacementList = [
|
||||
Shard(SEQ_DIM), # output
|
||||
Shard(SEQ_DIM), # logsumexp
|
||||
None, # philox_seed
|
||||
None, # philox_offset
|
||||
Shard(SEQ_DIM), # q
|
||||
Shard(SEQ_DIM), # k
|
||||
Shard(SEQ_DIM), # v
|
||||
]
|
||||
if has_attn_bias:
|
||||
cp_strategy.append(Replicate()) # attn bias - not sharded for CP
|
||||
single_mesh_dim_strategies.append(cp_strategy)
|
||||
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=4
|
||||
)
|
||||
|
||||
|
||||
def _scaled_dot_product_efficient_attention_backward_cp_strategy(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
"""
|
||||
Strategy for efficient attention backward with Context Parallelism support.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_scaled_dot_product_efficient_attention_backward_base_strategies,
|
||||
)
|
||||
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_efficient_attention_backward_base_strategies(op_schema)
|
||||
)
|
||||
|
||||
has_attn_bias = op_schema.args_schema[4] is not None
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
cp_strategy: PlacementList = [
|
||||
Shard(SEQ_DIM), # grad_q
|
||||
Shard(SEQ_DIM), # grad_k
|
||||
Shard(SEQ_DIM), # grad_v
|
||||
Shard(1) if has_attn_bias else None, # grad_bias
|
||||
Shard(SEQ_DIM), # grad_output
|
||||
Shard(SEQ_DIM), # q
|
||||
Shard(SEQ_DIM), # k
|
||||
Shard(SEQ_DIM), # v
|
||||
Shard(SEQ_DIM), # output
|
||||
Shard(SEQ_DIM), # logsumexp
|
||||
]
|
||||
if has_attn_bias:
|
||||
cp_strategy.insert(8, Shard(1)) # attn_bias input
|
||||
cp_strategy.extend([Replicate(), Replicate()])
|
||||
single_mesh_dim_strategies.append(cp_strategy)
|
||||
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=4
|
||||
)
|
||||
|
||||
|
||||
# ==================== cuDNN Attention Strategies ====================
|
||||
|
||||
|
||||
def _scaled_dot_product_cudnn_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||
"""
|
||||
Strategy for cudnn attention forward with Context Parallelism support.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_scaled_dot_product_cudnn_attention_base_strategies,
|
||||
)
|
||||
|
||||
mesh = op_schema.get_mesh_from_args()
|
||||
single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies(
|
||||
op_schema
|
||||
)
|
||||
|
||||
(
|
||||
query_strategy,
|
||||
_,
|
||||
_,
|
||||
attn_bias_strategy,
|
||||
compute_log_sumexp,
|
||||
*rest_args,
|
||||
) = op_schema.args_schema
|
||||
return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
|
||||
has_attn_bias = attn_bias_strategy is not None
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
logsumexp_sharding = Shard(SEQ_DIM) if compute_log_sumexp else Replicate()
|
||||
debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else None
|
||||
|
||||
cp_strategy: PlacementList = [
|
||||
Shard(SEQ_DIM), # output
|
||||
logsumexp_sharding, # logsumexp
|
||||
None, # cum_seq_q
|
||||
None, # cum_seq_k
|
||||
None, # max_q
|
||||
None, # max_k
|
||||
None, # philox_seed
|
||||
None, # philox_offset
|
||||
debug_attn_mask_sharding, # debug_attn_mask
|
||||
Shard(SEQ_DIM), # q
|
||||
Shard(SEQ_DIM), # k
|
||||
Shard(SEQ_DIM), # v
|
||||
]
|
||||
if has_attn_bias:
|
||||
cp_strategy.append(Replicate()) # attn_bias - not sharded for CP
|
||||
single_mesh_dim_strategies.append(cp_strategy)
|
||||
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=9
|
||||
)
|
||||
|
||||
|
||||
def _scaled_dot_product_cudnn_attention_backward_cp_strategy(
|
||||
op_schema: OpSchema,
|
||||
) -> OpStrategy:
|
||||
"""
|
||||
Strategy for cudnn attention backward with Context Parallelism support.
|
||||
"""
|
||||
from torch.distributed.tensor._ops._matrix_ops import (
|
||||
_scaled_dot_product_cudnn_attention_backward_base_strategies,
|
||||
)
|
||||
|
||||
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||
single_mesh_dim_strategies = (
|
||||
_scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema)
|
||||
)
|
||||
|
||||
has_attn_bias = op_schema.args_schema[8] is not None
|
||||
has_scale = len(op_schema.args_schema) >= 16 and False
|
||||
|
||||
# Context Parallelism: shards on the sequence dim
|
||||
cp_sharding_gout: PlacementList = [Shard(SEQ_DIM)] * 3 # grad_q, grad_k, grad_v
|
||||
cp_sharding_ginp: PlacementList = [
|
||||
Shard(SEQ_DIM)
|
||||
] * 6 # grad_output, q, k, v, output, logsumexp
|
||||
cp_sharding_ginp += [Replicate()] * 2 # philox_seed, philox_offset
|
||||
cp_sharding_ginp += [Shard(SEQ_DIM) if has_attn_bias else None] # attn_bias
|
||||
cp_sharding_ginp += [
|
||||
None
|
||||
] * 6 # cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal
|
||||
if has_scale:
|
||||
cp_sharding_ginp.append(None)
|
||||
|
||||
cp_sharding = cp_sharding_gout + cp_sharding_ginp
|
||||
single_mesh_dim_strategies.append(cp_sharding)
|
||||
|
||||
return expand_to_full_mesh_op_strategy(
|
||||
mesh, op_schema, single_mesh_dim_strategies, input_index=3
|
||||
)
|
||||
|
||||
|
||||
# Store context managers and original strategies
|
||||
_cp_strategy_contexts = {}
|
||||
_original_strategies = {}
|
||||
|
||||
|
||||
def register_cp_sharding_rules():
|
||||
"""Register Context Parallelism sharding rules for all scaled_dot_product ops."""
|
||||
global _cp_strategy_contexts, _original_strategies
|
||||
|
||||
# If already registered, don't register again
|
||||
if _cp_strategy_contexts:
|
||||
return
|
||||
|
||||
# Define ops and their corresponding CP strategy functions
|
||||
cp_strategies = [
|
||||
(
|
||||
aten._scaled_dot_product_flash_attention.default,
|
||||
_scaled_dot_product_flash_attention_cp_strategy,
|
||||
RuntimeSchemaInfo(5),
|
||||
),
|
||||
(
|
||||
aten._scaled_dot_product_flash_attention_backward.default,
|
||||
_scaled_dot_product_flash_attention_backward_cp_strategy,
|
||||
None,
|
||||
),
|
||||
(
|
||||
aten._scaled_dot_product_efficient_attention.default,
|
||||
_scaled_dot_product_efficient_attention_cp_strategy,
|
||||
RuntimeSchemaInfo(4),
|
||||
),
|
||||
(
|
||||
aten._scaled_dot_product_efficient_attention_backward.default,
|
||||
_scaled_dot_product_efficient_attention_backward_cp_strategy,
|
||||
None,
|
||||
),
|
||||
(
|
||||
aten._scaled_dot_product_cudnn_attention.default,
|
||||
_scaled_dot_product_cudnn_attention_cp_strategy,
|
||||
RuntimeSchemaInfo(4),
|
||||
),
|
||||
(
|
||||
aten._scaled_dot_product_cudnn_attention_backward.default,
|
||||
_scaled_dot_product_cudnn_attention_backward_cp_strategy,
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
# Register each strategy
|
||||
for op_overload, strategy_func, schema_info in cp_strategies:
|
||||
ctx = _op_strategy_context(op_overload, strategy_func, schema_info)
|
||||
orig_funcs, orig_schema = ctx.__enter__()
|
||||
_cp_strategy_contexts[op_overload] = ctx
|
||||
_original_strategies[op_overload] = (orig_funcs, orig_schema)
|
||||
|
||||
|
||||
def unregister_cp_sharding_rules():
|
||||
"""Unregister Context Parallelism sharding rules and restore original strategies."""
|
||||
global _cp_strategy_contexts, _original_strategies
|
||||
|
||||
# Exit all context managers
|
||||
for ctx in _cp_strategy_contexts.values():
|
||||
ctx.__exit__(None, None, None)
|
||||
|
||||
_cp_strategy_contexts = {}
|
||||
_original_strategies = {}
|
||||
@@ -529,14 +529,11 @@ RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+


def replace_extern_shared(input_string):
    """
    Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
    See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Examples:
        "extern __shared__ char smemChar[];"
        => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];"
        => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
    """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
    https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
    """
    output_string = input_string
    output_string = RE_EXTERN_SHARED.sub(
@@ -1046,17 +1043,14 @@ RE_INCLUDE = re.compile(r"#include .*\n")


def extract_arguments(start, string):
    """
    Return the list of arguments in the upcoming function parameter closure.
    Example:
    """ Return the list of arguments in the upcoming function parameter closure.
    Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            [
                {'start': 1, 'end': 7},
                {'start': 8, 'end': 16},
                {'start': 17, 'end': 19},
                {'start': 20, 'end': 53}
            ]
        '[{'start': 1, 'end': 7},
        {'start': 8, 'end': 16},
        {'start': 17, 'end': 19},
        {'start': 20, 'end': 53}]'
    """

    arguments = []