mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-14 06:07:55 +08:00

Compare commits: documentat… ... ciflow/ind… (14 commits)
| SHA1 |
|---|
| 8c6c066024 |
| 87ae43a0d6 |
| 2cff58cc2d |
| 46bd412746 |
| 5d07795b28 |
| 306914e071 |
| fb6807bf86 |
| f6a79b2a4a |
| 2fcf41dd8e |
| 31ccd8f13e |
| 1209123500 |
| 168fc7cfc5 |
| 639276822f |
| 8cbe13ad31 |
@@ -1308,8 +1308,319 @@ coverage_ignore_functions = [

Unchanged context: the "# torch.onnx.symbolic_opset7" entries ("max", "min"), the existing
"dim" and "export" entries, and the "# torch.onnx.verification" entries
("check_export_model_diff", "verify", "verify_aten_graph").

Added entries:

# torch.onnx.symbolic_opset8
"addmm", "bmm", "empty", "empty_like", "flatten", "full", "full_like", "gt", "lt",
"matmul", "mm", "ones", "ones_like", "prelu", "repeat", "zeros", "zeros_like",

# torch.onnx.symbolic_opset9
"abs", "acos", "adaptive_avg_pool1d", "adaptive_avg_pool2d", "adaptive_avg_pool3d",
"adaptive_max_pool1d", "adaptive_max_pool2d", "adaptive_max_pool3d", "add", "addcmul",
"addmm", "alias", "amax", "amin", "aminmax", "arange", "argmax", "argmin", "as_strided",
"as_tensor", "asin", "atan", "atan2", "avg_pool1d", "avg_pool2d", "avg_pool3d", "baddbmm",
"batch_norm", "bernoulli", "bitwise_not", "bitwise_or", "bmm", "broadcast_tensors",
"broadcast_to", "bucketize", "cat", "cdist", "ceil", "clamp", "clamp_max", "clamp_min",
"clone", "constant_pad_nd", "contiguous", "conv1d", "conv2d", "conv3d", "conv_tbc",
"conv_transpose1d", "conv_transpose2d", "conv_transpose3d", "convert_element_type",
"convolution", "cos", "cosine_similarity", "cross", "cumsum", "detach", "div", "dot",
"dropout", "elu", "embedding", "embedding_bag", "empty", "empty_like", "eq", "erf", "exp",
"expand", "expand_as", "eye", "fill", "flatten", "floor", "floor_divide", "floordiv",
"frobenius_norm", "full", "full_like", "gather", "ge", "gelu", "get_pool_ceil_padding",
"glu", "group_norm", "gru", "gt", "hann_window", "hardshrink", "hardsigmoid", "hardswish",
"hardtanh", "index", "index_add", "index_copy", "index_fill", "index_put", "index_select",
"instance_norm", "is_floating_point", "is_pinned", "isnan", "item", "kl_div", "layer_norm",
"le", "leaky_relu", "lerp", "lift", "linalg_cross", "linalg_matrix_norm", "linalg_norm",
"linalg_vector_norm", "linear", "linspace", "log", "log10", "log1p", "log2", "log_sigmoid",
"log_softmax", "logical_and", "logical_not", "logical_or", "logical_xor", "logit",
"logsumexp", "lstm", "lstm_cell", "lt", "masked_fill", "masked_fill_", "matmul", "max",
"max_pool1d", "max_pool1d_with_indices", "max_pool2d", "max_pool2d_with_indices",
"max_pool3d", "max_pool3d_with_indices", "maximum", "meshgrid", "min", "minimum", "mish",
"mm", "movedim", "mse_loss", "mul", "multinomial", "mv", "narrow", "native_layer_norm",
"ne", "neg", "new_empty", "new_full", "new_ones", "new_zeros", "nonzero", "nonzero_numpy",
"noop_complex_operators", "norm", "numel", "numpy_T", "one_hot", "ones", "ones_like",
"onnx_placeholder", "overload_by_arg_count", "pad", "pairwise_distance", "permute",
"pixel_shuffle", "pixel_unshuffle", "pow", "prelu", "prim_constant", "prim_constant_chunk",
"prim_constant_split", "prim_data", "prim_device", "prim_dtype", "prim_if", "prim_layout",
"prim_list_construct", "prim_list_unpack", "prim_loop", "prim_max", "prim_min",
"prim_shape", "prim_tolist", "prim_tuple_construct", "prim_type", "prim_unchecked_cast",
"prim_uninitialized", "rand", "rand_like", "randint", "randint_like", "randn",
"randn_like", "reciprocal", "reflection_pad", "relu", "relu6", "remainder", "repeat",
"repeat_interleave", "replication_pad", "reshape", "reshape_as", "rnn_relu", "rnn_tanh",
"roll", "rrelu", "rsqrt", "rsub", "scalar_tensor", "scatter", "scatter_add", "select",
"selu", "sigmoid", "sign", "silu", "sin", "size", "slice", "softmax", "softplus",
"softshrink", "sort", "split", "split_with_sizes", "sqrt", "square", "squeeze", "stack",
"std", "std_mean", "sub", "t", "take", "tan", "tanh", "tanhshrink", "tensor", "threshold",
"to", "topk", "transpose", "true_divide", "type_as", "unbind", "unfold", "unsafe_chunk",
"unsafe_split", "unsafe_split_with_sizes", "unsqueeze", "unsupported_complex_operators",
"unused", "upsample_bilinear2d", "upsample_linear1d", "upsample_nearest1d",
"upsample_nearest2d", "upsample_nearest3d", "upsample_trilinear3d", "var", "var_mean",
"view", "view_as", "where", "wrap_logical_op_with_cast_to",
"wrap_logical_op_with_negation", "zero", "zeros", "zeros_like",

# torch.onnx.utils
"disable_apex_o2_state_dict_hook", "export_to_pretty_string", "exporter_context",
"is_in_onnx_export", "model_signature", "register_custom_op_symbolic",
"select_model_mode_for_export", "setup_onnx_logging", "unconvertible_ops",
"unpack_quantized_tensor", "warn_on_static_input_change",

# torch.onnx.verification
(comment header added; its existing entries were already present)

@@ -1400,6 +1711,32 @@ coverage_ignore_functions = [

Unchanged context: "noop_context_fn", "set_checkpoint_early_stop", "set_device_states",
the "# torch.utils.cpp_backtrace" entry "get_cpp_backtrace", and the
"# torch.utils.cpp_extension" header.

Added entries:

# torch.utils.collect_env
"check_release_file", "get_cachingallocator_config", "get_clang_version",
"get_cmake_version", "get_conda_packages", "get_cpu_info",
"get_cuda_module_loading_config", "get_cudnn_version", "get_env_info", "get_gcc_version",
"get_gpu_info", "get_libc_version", "get_lsb_version", "get_mac_version",
"get_nvidia_driver_version", "get_nvidia_smi", "get_os", "get_pip_packages",
"get_platform", "get_pretty_env_info", "get_python_platform", "get_running_cuda_version",
"get_windows_version", "is_xnnpack_available", "pretty_str",

@@ -1463,6 +1800,52 @@ coverage_ignore_functions = [

Unchanged context: "apply_shuffle_seed", "apply_shuffle_settings", "get_all_graph_pipes",
and the "# torch.utils.hooks" entries "unserializable_hook" and "warn_if_has_hooks".

Added entries:

# torch.utils.flop_counter
"addmm_flop", "baddbmm_flop", "bmm_flop", "conv_backward_flop", "conv_flop",
"conv_flop_count", "convert_num_with_suffix", "get_shape", "get_suffix_str", "mm_flop",
"normalize_tuple", "register_flop_formula", "sdpa_backward_flop",
"sdpa_backward_flop_count", "sdpa_flop", "sdpa_flop_count", "shape_wrapper",
"transpose_shape",

# torch.utils.hipify.hipify_python
"add_dim3", "compute_stats", "extract_arguments", "file_add_header",
"file_specific_replacement", "find_bracket_group", "find_closure_group",
"find_parentheses_group", "fix_static_global_kernels", "get_hip_file_path",
"hip_header_magic", "hipify", "is_caffe2_gpu_file", "is_cusparse_file", "is_out_of_place",
"is_pytorch_file", "is_special_file", "match_extensions", "matched_files_iter", "openf",
"preprocess_file_and_save_result", "preprocessor", "processKernelLaunches",
"replace_extern_shared", "replace_math_functions", "str2bool",
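For readers unfamiliar with this knob: `coverage_ignore_functions` is read by Sphinx's `sphinx.ext.coverage` extension, and every name listed there is skipped when the coverage builder reports undocumented functions, so the additions above keep the newly tracked `torch.onnx.*` and `torch.utils.*` modules from flooding the coverage report. A minimal sketch of the mechanism (an illustrative `conf.py` fragment, not the full PyTorch configuration):

```python
# Illustrative Sphinx conf.py fragment (assumes a project using sphinx.ext.coverage).
# Run `sphinx-build -b coverage <srcdir> <outdir>`; names listed below are excluded
# from the "undocumented functions" report.
extensions = ["sphinx.ext.coverage"]

coverage_ignore_functions = [
    # torch.onnx.symbolic_opset8
    "addmm",
    "bmm",
    # torch.utils.collect_env
    "get_env_info",
]
```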
@@ -19,91 +19,6 @@

Unchanged context: the end of the preceding autosummary block ("swap_tensors" and its
closing ```) and the later "<!-- This module needs to be documented. Adding here in the
meantime for tracking purposes -->" block with its following ```{eval-rst} fence.

Removed: the hand-written doc sections for torch.utils.collect_env,
torch.utils.flop_counter, and torch.utils.hipify.hipify_python. Each section consisted of
a "# <module>" heading, an {eval-rst} automodule block, an {eval-rst} currentmodule block,
and an {eval-rst} autosummary block (":toctree: generated", ":nosignatures:") listing:

- torch.utils.collect_env: check_release_file, is_xnnpack_available, pretty_str
- torch.utils.flop_counter: baddbmm_flop, bmm_flop, conv_backward_flop, conv_flop,
  conv_flop_count, register_flop_formula, sdpa_backward_flop, sdpa_backward_flop_count,
  sdpa_flop, sdpa_flop_count, shape_wrapper
- torch.utils.hipify.hipify_python: compute_stats, extract_arguments, file_add_header,
  file_specific_replacement, find_bracket_group, find_closure_group,
  find_parentheses_group, fix_static_global_kernels, hip_header_magic, hipify,
  is_caffe2_gpu_file, is_cusparse_file, is_out_of_place, is_pytorch_file, is_special_file,
  openf, preprocess_file_and_save_result, preprocessor, processKernelLaunches,
  replace_extern_shared, replace_math_functions, str2bool

@@ -128,6 +43,7 @@ for tracking purposes -->
 .. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
 .. py:module:: torch.utils.bundled_inputs
 .. py:module:: torch.utils.checkpoint
+.. py:module:: torch.utils.collect_env
 .. py:module:: torch.utils.cpp_backtrace
 .. py:module:: torch.utils.cpp_extension
 .. py:module:: torch.utils.data.backward_compatibility

@@ -164,8 +80,10 @@ for tracking purposes -->
 .. py:module:: torch.utils.data.sampler
 .. py:module:: torch.utils.dlpack
 .. py:module:: torch.utils.file_baton
+.. py:module:: torch.utils.flop_counter
 .. py:module:: torch.utils.hipify.constants
 .. py:module:: torch.utils.hipify.cuda_to_hip_mappings
+.. py:module:: torch.utils.hipify.hipify_python
 .. py:module:: torch.utils.hipify.version
 .. py:module:: torch.utils.hooks
 .. py:module:: torch.utils.jit.log_extract
@@ -260,6 +260,7 @@ select = [
     "TRY401", # verbose-log-message
     "UP",
     "YTT",
+    "S101",
 ]
 
 [tool.ruff.lint.pyupgrade]

@@ -339,6 +340,39 @@ keep-runtime-typing = true
 "tools/linter/**" = [
     "LOG015" # please fix
 ]
+"benchmarks/**" = [
+    "S101"
+]
+"test/**" = [
+    "S101"
+]
+"torchgen/**" = [
+    "S101"
+]
+"torch/**" = [
+    "S101"
+]
+"tools/**" = [
+    "S101"
+]
+"setup.py" = [
+    "S101"
+]
+"functorch/**" = [
+    "S101"
+]
+"docs/**" = [
+    "S101"
+]
+"android/**" = [
+    "S101"
+]
+".github/**" = [
+    "S101"
+]
+".ci/**" = [
+    "S101"
+]
 
 [tool.codespell]
 ignore-words = "tools/linter/dictionary.txt"
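For context, `S101` is Ruff's flake8-bandit rule "use of assert detected": enabling it in `select` makes bare `assert` statements a lint finding in library code (asserts are stripped when Python runs with `-O`), while the new per-file-ignores entries keep `assert` allowed in tests, benchmarks, tooling, docs, and CI trees. A small sketch of what the rule pushes library code toward (hypothetical function, not code from this change set):

```python
# Hypothetical example of the S101 rule's intent (not code from this PR).
def set_dropout(p: float) -> None:
    # Flagged by S101; asserts disappear under `python -O`:
    #   assert 0.0 <= p <= 1.0
    # S101-clean equivalent that always runs:
    if not (0.0 <= p <= 1.0):
        raise ValueError(f"dropout probability must be in [0, 1], got {p}")
```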
@@ -813,6 +813,50 @@ class TestSharding(DTensorTestBase):
             ),
         )
 
+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
+        "Does not support flash nor efficient attention",
+    )
+    def test_attention_shard_without_cp(self) -> None:
+        """Test that sharding on sequence dimension without CP enabled is not supported."""
+        from torch.distributed.tensor import distribute_tensor, Replicate, Shard
+
+        B = 2
+        nheads = 4
+        seq_len = 256
+        dim = 32
+
+        device_mesh = init_device_mesh(
+            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
+        )
+
+        # Create q, k, v tensors with shape (B, nheads, seq_len, dim)
+        q = torch.randn(
+            B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
+        )
+        k = torch.randn(
+            B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
+        )
+        v = torch.randn(
+            B, nheads, seq_len, dim, device=self.device_type, dtype=torch.bfloat16
+        )
+        q_dt = distribute_tensor(q, device_mesh, [Shard(2)])
+        k_dt = distribute_tensor(k, device_mesh, [Shard(2)])
+        v_dt = distribute_tensor(v, device_mesh, [Shard(2)])
+
+        # Run SDPA with sequence-sharded tensors WITHOUT enabling CP
+        # Without CP enabled, DTensor should select a different strategy
+        # (not sequence-sharded) because Shard(2) strategy is only available with CP
+        out = F.scaled_dot_product_attention(q_dt, k_dt, v_dt)
+
+        # Verify the output is NOT sharded on sequence dimension (dim 2)
+        # This proves that CP sharding rules were not used
+        self.assertNotEqual(out.placements[0], Shard(2))
+        # The output should be replicated or sharded on batch head dimensions.
+        self.assertIn(out.placements[0], [Replicate(), Shard(0), Shard(1)])
+
 
 RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
     RingAttentionTest,
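In this test, `Shard(2)` shards the `(B, nheads, seq_len, dim)` tensors along the sequence dimension, which is exactly the placement context parallelism uses; the assertions check that, with CP disabled, the sharding propagator does not pick that placement for SDPA. A stripped-down sketch of the same setup outside the test harness (assumes a 2-rank CUDA run launched with `torchrun`; sizes mirror the test):

```python
# Minimal sketch (assumes `torchrun --nproc-per-node=2` on CUDA; mirrors the test above).
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("cp",))
q = torch.randn(2, 4, 256, 32, device="cuda", dtype=torch.bfloat16)
q_dt = distribute_tensor(q, mesh, [Shard(2)])  # shard the seq_len dimension
out = F.scaled_dot_product_attention(q_dt, q_dt, q_dt)
print(out.placements)  # with CP disabled, this is not Shard(2)
```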
@@ -7522,6 +7522,38 @@ class AOTInductorTestsTemplate:
         eager_outputs = model(*example_inputs)
         torch.testing.assert_close(eager_outputs, compiled_outputs)
 
+    @requires_gpu
+    def test_mixed_device_1(self):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("Mixed-device test requires GPU")
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Buffers are on CPU
+                self.register_buffer(
+                    "index", torch.tensor([1, 4, 1, 7], device="cpu", dtype=torch.int64)
+                )
+                self.register_buffer(
+                    "src", torch.ones(4, device="cpu", dtype=torch.int64)
+                )
+
+            def forward(self, matrix, vector):
+                # Inputs are on CUDA
+                # 1. Operation on CPU tensors
+                z = torch.zeros((vector.shape[0],), device="cpu", dtype=torch.int64)
+                scatter_result = z.scatter_add(0, self.index, self.src)
+
+                # 2. Move result to CUDA and continue on CUDA
+                v = vector + scatter_result.to(vector.dtype).to(GPU_TYPE)
+                return torch.matmul(matrix, v)
+
+        example_inputs = (
+            torch.randn(10, 10, device=self.device),
+            torch.randn(10, device=self.device),
+        )
+        self.check_model(Model(), example_inputs, move_model_to_device=False)
+
 
 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)

@@ -218,6 +218,7 @@ def check_model(
     dynamic_shapes=None,
     atol=None,
     rtol=None,
+    move_model_to_device=True,
 ):
     with (
         torch.no_grad(),

@@ -229,7 +230,7 @@ def check_model(
         ),
     ):
         torch.manual_seed(0)
-        if not isinstance(model, types.FunctionType):
+        if not isinstance(model, types.FunctionType) and move_model_to_device:
             model = model.to(self.device)
 
         # For non mixed device inputs with default "cpu",set the device manually.
@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: pt2"]
 import functools
+import re
 import sys
 import unittest
 

@@ -230,6 +231,33 @@ class PallasTestsMixin:
         self.assertIn("import jax.numpy as jnp", code)
         self.assertIn("from jax.experimental import pallas as pl", code)
 
+    def test_jax_jit_wrapper_is_emitted(self):
+        """Ensure generated Pallas code wraps pl.pallas_call in jax.jit."""
+
+        key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
+
+        @torch.compile(backend="inductor", options={key: "pallas"})
+        def pallas_fn(a, b):
+            return a + b
+
+        _, (code,) = run_and_get_code(
+            pallas_fn,
+            torch.randn(32, device=self.DEVICE),
+            torch.randn(32, device=self.DEVICE),
+        )
+
+        kernel_match = re.search(r"def (pallas_[A-Za-z0-9_]+)_kernel", code)
+        self.assertIsNotNone(kernel_match)
+        kernel_name = kernel_match.group(1)
+        wrapper_name = f"{kernel_name}_jit_wrapper"
+        self.assertIn(wrapper_name, code)
+        start = code.index(f"def {wrapper_name}")
+        end = code.index(f"def {kernel_name}_main", start)
+        wrapper_block = code[start:end]
+
+        self.assertIn("jax.jit", code)
+        self.assertNotIn("torch.", wrapper_block)
+
     def test_2d_tensor(self):
         """Test with 2D tensors (though current implementation flattens)."""
@@ -221,7 +221,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
             """
         )
 
-        self.add_device_include(self.device)
+        for device in V.graph.device_types:
+            if device != "meta":
+                self.add_device_include(device)
 
         if V.graph.aot_mode:
             if config.aot_inductor.dynamic_linkage:

@@ -1423,11 +1425,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
         src_is_tensor,
         reduce,
         kwargs,
+        device,
     ):
         reduce = self._get_scatter_reduce_enum(reduce)
 
         # call the ABI shim function instead of the ATen one
-        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
+        self.add_device_include(device)
+        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
         # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
         cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
         inputs_wrapped = [str(x) for x in inputs]

@@ -708,11 +708,14 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
         src_is_tensor,
         reduce,
         kwargs,
+        device,
     ):
         reduce = self._get_scatter_reduce_enum(reduce)
 
         # call the ABI shim function instead of the ATen one
-        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
+        self.add_device_include(device)
+        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
+
         # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
         cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
         self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
@@ -287,6 +287,7 @@ class PallasKernel(SIMDKernel):
         code = IndentedBuffer()
         code.splice(
             """
+            import functools
             import torch
             import jax
             import jax.numpy as jnp

@@ -301,6 +302,9 @@ class PallasKernel(SIMDKernel):
         kernel_params = [a.name for a in arg_defs]
 
         kernel_name = name or "<KERNEL_NAME>"
+        interpret_literal = (
+            "True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
+        )
         code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
         with code.indent():
             # Emit compute (CSE) and store lines; they reference *_ptr[...] directly

@@ -309,16 +313,22 @@ class PallasKernel(SIMDKernel):
             for line in self.stores._lines:
                 code.writeline(str(line))
 
+        jit_wrapper_name = f"{kernel_name}_jit_wrapper"
+        code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))")
+        code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):")
+        with code.indent():
+            code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)")
+            code.writeline("return pl.pallas_call(")
+            code.writeline(f"    {kernel_name}_kernel,")
+            code.writeline("    out_shape=out_spec,")
+            code.writeline(f"    interpret={interpret_literal},")
+            code.writeline("    grid=(1,),")
+            code.writeline(")(*kernel_refs)")
+
         # Host entry: convert torch tensors <-> jax, call pallas_call and copy back
         main_name = f"{kernel_name}_main"
         code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
         with code.indent():
-            # Determine interpret statically based on codegen device
-            interpret_literal = (
-                "True"
-                if V.graph.get_current_device_or_throw().type == "cpu"
-                else "False"
-            )
             # Identify inputs (in_ptr*) and output (out_ptr*)
             input_params = [
                 p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))

@@ -337,9 +347,9 @@ class PallasKernel(SIMDKernel):
             for inp in input_params:
                 code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
 
-            # Get output spec from PyTorch tensor
-            code.writeline("# Prepare output spec from PyTorch tensor")
-            code.writeline("# Map PyTorch dtype to JAX dtype string")
+            # Get output metadata from PyTorch tensor
+            code.writeline("# Prepare output metadata from PyTorch tensor")
+            code.writeline("# Map PyTorch dtype to JAX dtype")
             code.writeline("_torch_dtype_to_jax = {")
             code.writeline(
                 " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"

@@ -349,21 +359,14 @@ class PallasKernel(SIMDKernel):
             )
             code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
             code.writeline("}")
-            code.writeline(
-                f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])"
-            )
-
-            # Call pallas
-            # Pass interpret=True on CPU, False otherwise (single call, no duplication)
-            code.writeline("compiled = pl.pallas_call(")
-            code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),")
-            code.writeline(" out_shape=out_spec,")
-            code.writeline(f" interpret={interpret_literal},")
-            code.writeline(" grid=(1,),")
-            code.writeline(")")
-
-            jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params])
-            code.writeline(f"res = compiled({jax_input_args})")
-
+            code.writeline(f"out_shape = tuple({output_param}.shape)")
+            code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")
+
+            call_args = ["out_shape", "out_dtype"] + [
+                f"{inp}_jax" for inp in input_params
+            ]
+            call_arg_str = ", ".join(call_args)
+            code.writeline(f"res = {jit_wrapper_name}({call_arg_str})")
+
             # Copy result back
             code.writeline("# Copy result back into the provided torch output tensor")
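Assembled by hand from the `writeline` calls above, the generated module now has roughly the following shape; the kernel name and parameter list are placeholders and the kernel body is elided, so treat this as a structural sketch rather than verbatim Inductor output:

```python
# Structural sketch of the generated Pallas wrapper (hand-assembled; not verbatim output).
import functools
import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def pallas_fused_0_kernel(in_ptr0, out_ptr0):  # placeholder name and params
    ...  # compute/store lines emitted per kernel body


@functools.partial(jax.jit, static_argnums=(0, 1))
def pallas_fused_0_jit_wrapper(out_shape, out_dtype, *kernel_refs):
    out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)
    return pl.pallas_call(
        pallas_fused_0_kernel,
        out_shape=out_spec,
        interpret=False,  # emitted as "True" when codegen targets CPU
        grid=(1,),
    )(*kernel_refs)


def pallas_fused_0_main(in_ptr0, out_ptr0, stream=None):
    in_ptr0_jax = jax.dlpack.from_dlpack(in_ptr0)
    _torch_dtype_to_jax = {torch.float32: jnp.float32}  # full mapping is emitted inline
    out_shape = tuple(out_ptr0.shape)
    out_dtype = _torch_dtype_to_jax[out_ptr0.dtype]
    res = pallas_fused_0_jit_wrapper(out_shape, out_dtype, in_ptr0_jax)
    # ...result is copied back into the provided torch output tensor
```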
@@ -971,6 +971,7 @@ class ScatterFallbackLine(WrapperLine):
         else:
             (x, index) = (t.codegen_reference() for t in node.inputs)
             src = node.constant_args[1]
+        device = d.type if (d := node.get_device()) else V.graph.device_type
         self.wrapper._generate_scatter_fallback(
             x,
             [x, node.constant_args[0], index, src],

@@ -979,6 +980,7 @@ class ScatterFallbackLine(WrapperLine):
             node.src_is_tensor,
             node.kwargs["reduce"],
             node.codegen_kwargs(),
+            device,
         )
 
     def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:

@@ -1632,6 +1634,7 @@ class PythonWrapperCodegen(CodeGen):
         src_is_tensor,
         reduce,
         kwargs,
+        device,
     ):
         line = f"{python_kernel_name}({','.join(map(str, inputs))}"
         if python_kernel_name.startswith("aten.scatter_reduce"):
@@ -267,16 +267,10 @@ def scaled_mm_strategy(op_schema: OpSchema) -> OpStrategy:
     return _scaled_mm_like_strategy("mk,kn->mn", mesh, op_schema)
 
 
-@register_op_strategy(
-    aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
-)
-def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy:
-    # NOTE: currently we only support some simple strategies to support tensor parallelism
-    # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
-    # as it involves: matmul, pointwise, reduction ops together.
-    mesh = op_schema.get_mesh_from_args()
-
+def _scaled_dot_product_flash_attention_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
     return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
     q_input_strategy = op_schema.args_schema[0]
     if not isinstance(q_input_strategy, OpStrategy):

@@ -349,37 +343,30 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate
             Shard(0),  # v
         ]
     )
-    # Context Parallelism: shards on the sequence dim
-    debug_attn_mask_sharding = Shard(2) if return_debug_mask else Replicate()
-    single_mesh_dim_strategies.append(
-        [
-            Shard(2),  # output
-            Shard(2),  # logsumexp
-            None,  # cum_seq_q
-            None,  # cum_seq_k
-            None,  # max_q
-            None,  # max_k
-            Replicate(),  # rng_state
-            None,  # unused
-            debug_attn_mask_sharding,  # debugattn
-            Shard(2),  # q
-            Shard(2),  # k
-            Shard(2),  # v
-        ]
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(
+    aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
+)
+def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy:
+    # NOTE: currently we only support some simple strategies to support tensor parallelism
+    # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
+    # as it involves: matmul, pointwise, reduction ops together.
+    mesh = op_schema.get_mesh_from_args()
+    single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies(
+        op_schema
     )
     return expand_to_full_mesh_op_strategy(
         mesh, op_schema, single_mesh_dim_strategies, input_index=9
     )
 
 
-@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
-def scaled_dot_product_flash_attention_backward_strategy(
+def _scaled_dot_product_flash_attention_backward_base_strategies(
     op_schema: OpSchema,
-) -> OpStrategy:
-    # backward op does not need to validate the mesh since forward op has already done it
-    mesh = op_schema.get_mesh_from_args(validate=False)
-
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
     q_input_strategy = op_schema.args_schema[1]
     if not isinstance(q_input_strategy, OpStrategy):
         raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")

@@ -444,24 +431,18 @@ def scaled_dot_product_flash_attention_backward_strategy(
     batch_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
     single_mesh_dim_strategies.append(batch_dim_sharding)
 
-    # Context Parallelism: shards on the sequence dim
-    seq_dim_sharding: PlacementList = [
-        Shard(2),  # grad_q
-        Shard(2),  # grad_k
-        Shard(2),  # grad_v
-        Shard(2),  # grad_output
-        Shard(2),  # q
-        Shard(2),  # k
-        Shard(2),  # v
-        Shard(2),  # output
-        Shard(2),  # logsumexp
-    ]
-    # accept replicate on the rest tensor inputs, potentially
-    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
-    # at indices 6, 7, 12, 13, respectively
-    seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
-    single_mesh_dim_strategies.append(seq_dim_sharding)
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
+def scaled_dot_product_flash_attention_backward_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_flash_attention_backward_base_strategies(op_schema)
+    )
     return expand_to_full_mesh_op_strategy(
         mesh, op_schema, single_mesh_dim_strategies, input_index=3
     )

@@ -486,13 +467,10 @@ def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy:
     )
 
 
-@register_op_strategy(
-    aten._scaled_dot_product_efficient_attention.default,
-    schema_info=RuntimeSchemaInfo(4),
-)
-def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy:
-    # NOTE: currently we only support some simple strategies to support tensor parallelism
-    mesh = op_schema.get_mesh_from_args()
+def _scaled_dot_product_efficient_attention_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
     q_input_strategy = op_schema.args_schema[0]
     if not isinstance(q_input_strategy, OpStrategy):
         raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")

@@ -518,19 +496,6 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
     if has_attn_bias:
         all_replicate.append(Replicate())  # attn bias
 
-    # Context Parallelism: shards on the sequence dim
-    single_mesh_dim_strategies.append(
-        [
-            Shard(2),  # output
-            Shard(2),  # logsumexp
-            None,  # philox_seed
-            None,  # philox_offset
-            Shard(2),  # q
-            Shard(2),  # k
-            Shard(2),  # v
-        ]
-    )
-
     single_mesh_dim_strategies.append(all_replicate)
 
     # second we can accept the sharding pattern of tensor parallelism, which

@@ -576,6 +541,19 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
 
     single_mesh_dim_strategies.append(batch_sharding)
 
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(
+    aten._scaled_dot_product_efficient_attention.default,
+    schema_info=RuntimeSchemaInfo(4),
+)
+def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy:
+    # NOTE: currently we only support some simple strategies to support tensor parallelism
+    mesh = op_schema.get_mesh_from_args()
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_efficient_attention_base_strategies(op_schema)
+    )
     return expand_to_full_mesh_op_strategy(
         mesh,
         op_schema,

@@ -584,13 +562,10 @@ def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpSt
     )
 
 
-@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
-def scaled_dot_product_efficient_attention_backward_strategy(
+def _scaled_dot_product_efficient_attention_backward_base_strategies(
     op_schema: OpSchema,
-) -> OpStrategy:
-    # backward op does not need to validate the mesh since forward op has already done it
-    mesh = op_schema.get_mesh_from_args(validate=False)
-
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
     q_input_strategy = op_schema.args_schema[1]
     if not isinstance(q_input_strategy, OpStrategy):
         raise AssertionError(f"Expected OpStrategy, got {type(q_input_strategy)}")

@@ -662,27 +637,18 @@ def scaled_dot_product_efficient_attention_backward_strategy(
         batch_dim_sharding.extend([Replicate(), Replicate()])
     single_mesh_dim_strategies.append(batch_dim_sharding)
 
-    # Context Parallelism: shards on the sequence dim
-    seq_dim_sharding: PlacementList = [
-        Shard(2),  # grad_q
-        Shard(2),  # grad_k
-        Shard(2),  # grad_v
-        Shard(1) if has_attn_bias else None,  # grad_bias
-        Shard(2),  # grad_output
-        Shard(2),  # q
-        Shard(2),  # k
-        Shard(2),  # v
-        Shard(2),  # output
-        Shard(2),  # logsumexp
-    ]
-    # accept replicate on the rest tensor inputs, potentially
-    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
-    # at indices 6, 7, 12, 13, respectively
-    if has_attn_bias:
-        num_heads_dim_sharding.insert(8, Shard(1))
-        seq_dim_sharding.extend([Replicate(), Replicate()])
-    single_mesh_dim_strategies.append(seq_dim_sharding)
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
+def scaled_dot_product_efficient_attention_backward_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_efficient_attention_backward_base_strategies(op_schema)
+    )
     return expand_to_full_mesh_op_strategy(
         mesh,
         op_schema,

@@ -691,13 +657,10 @@ def scaled_dot_product_efficient_attention_backward_strategy(
     )
 
 
-@register_op_strategy(
-    aten._scaled_dot_product_cudnn_attention.default,
-    schema_info=RuntimeSchemaInfo(4),
-)
-def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
-    mesh = op_schema.get_mesh_from_args()
-
+def _scaled_dot_product_cudnn_attention_base_strategies(
+    op_schema: OpSchema,
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
     (
         query_strategy,  # query
         _,  # key

@@ -785,39 +748,27 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate
     ]
     single_mesh_dim_strategies.append(batch_dim_sharding)
 
-    # Context Parallelism: shards on the sequence dim
-    cp_sharding = Shard(2)  # seq dim
-    logsumexp_sharding = cp_sharding if compute_log_sumexp else Replicate()
-    debug_attn_mask_sharding = cp_sharding if return_debug_mask else None
-
-    single_mesh_dim_strategies.append(
-        [
-            cp_sharding,  # output
-            logsumexp_sharding,  # logsumexp
-            None,  # cum_seq_q
-            None,  # cum_seq_k
-            None,  # max_q
-            None,  # max_k
-            None,  # philox_seed
-            None,  # philox_offset
-            debug_attn_mask_sharding,  # debug_attn_mask
-            cp_sharding,  # q
-            cp_sharding,  # k
-            cp_sharding,  # v
-        ]
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(
+    aten._scaled_dot_product_cudnn_attention.default,
+    schema_info=RuntimeSchemaInfo(4),
+)
+def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+    single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies(
+        op_schema
     )
     return expand_to_full_mesh_op_strategy(
         mesh, op_schema, single_mesh_dim_strategies, input_index=9
     )
 
 
-@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
-def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
+def _scaled_dot_product_cudnn_attention_backward_base_strategies(
     op_schema: OpSchema,
-) -> OpStrategy:
-    # backward op does not need to validate the mesh since forward op has already done it
-    mesh = op_schema.get_mesh_from_args(validate=False)
-
+) -> list[PlacementList]:
+    """Helper that returns list of base placement strategies (without CP)."""
     if len(op_schema.args_schema) < 15:
         raise AssertionError(
             f"Expected at least 15 args_schema, got {len(op_schema.args_schema)}"

@@ -892,23 +843,7 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
     num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp
     single_mesh_dim_strategies.append(num_heads_dim_sharding)
 
-    # case 3: Context Parallelism which shards on the sequence dim
-    context_parallel_sharding_out: PlacementList = [Shard(2)] * 3
-    context_parallel_sharding_inp: PlacementList = [Shard(2)] * 6
-    context_parallel_sharding_inp += [
-        Replicate()
-    ] * 2  # philox_seed, philox_offset is casted to Replicate() in DTensor
-    context_parallel_sharding_inp += [Shard(2) if has_attn_bias else None]
-    context_parallel_sharding_inp += [None] * 6
-    if has_scale:
-        context_parallel_sharding_inp.append(None)
-
-    context_parallel_sharding = (
-        context_parallel_sharding_out + context_parallel_sharding_inp
-    )
-    single_mesh_dim_strategies.append(context_parallel_sharding)
-
-    # case 4: we can accept the sharding pattern of batch parallelism, which
+    # case 3: we can accept the sharding pattern of batch parallelism, which
     # shards on the batch dimension
     qkv_sharding = Shard(0)
     output_sharding = Shard(0)

@@ -929,6 +864,18 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
     batch_dim_sharding = batch_dim_sharding_out + batch_dim_sharding_inp
     single_mesh_dim_strategies.append(batch_dim_sharding)
 
+    return single_mesh_dim_strategies
+
+
+@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default)
+def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
+    op_schema: OpSchema,
+) -> OpStrategy:
+    # backward op does not need to validate the mesh since forward op has already done it
+    mesh = op_schema.get_mesh_from_args(validate=False)
+    single_mesh_dim_strategies = (
+        _scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema)
+    )
     return expand_to_full_mesh_op_strategy(
         mesh, op_schema, single_mesh_dim_strategies, input_index=3
     )
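The net effect of the hunks above: each registered SDPA strategy is split into a `_*_base_strategies` helper that returns the tensor-parallel/replicate/batch placement rows plus a thin `@register_op_strategy` wrapper, and the `Shard(2)` (sequence-dim) rows are deleted from this file. A CP-aware strategy is then built by composition, which is what the new `_sharding_rules` module further down does; a simplified sketch of that composition (mirroring the flash-attention forward case, with the debug-mask entry fixed to `Replicate()` for brevity):

```python
# Simplified sketch of how a CP-aware strategy composes the base helper
# (mirrors _scaled_dot_product_flash_attention_cp_strategy in the new module below).
from torch.distributed.tensor._ops._matrix_ops import (
    _scaled_dot_product_flash_attention_base_strategies,
)
from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
from torch.distributed.tensor.placement_types import Replicate, Shard


def flash_attention_cp_strategy(op_schema):
    mesh = op_schema.get_mesh_from_args()
    # Rows shared with the non-CP path (TP, full replication, batch sharding).
    rows = _scaled_dot_product_flash_attention_base_strategies(op_schema)
    # One extra row that shards output/logsumexp/q/k/v on the sequence dim (2).
    rows.append(
        [Shard(2), Shard(2), None, None, None, None,
         Replicate(), None, Replicate(), Shard(2), Shard(2), Shard(2)]
    )
    return expand_to_full_mesh_op_strategy(mesh, op_schema, rows, input_index=9)
```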
@@ -989,16 +989,31 @@ def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:
 
 def _enable_cp_dtensor_dispatcher() -> None:
     """Enables DTensor dispatcher to dispatch SDPA to CP."""
+    # Enable custom op handlers for CP
     DTensor._op_dispatcher._custom_op_handlers = {
         **exitsing_custom_ops,
         **custom_ops,
     }
+    # Register CP-specific sharding rules
+    from ._sharding_rules import register_cp_sharding_rules
+
+    register_cp_sharding_rules()
+
 
 def _disable_cp_dtensor_dispatcher() -> None:
     """Disables DTensor dispatcher to dispatch SDPA to CP."""
+    # Restore original custom op handlers
     DTensor._op_dispatcher._custom_op_handlers = exitsing_custom_ops
+
+    # TODO: unregister_cp_sharding_rules() will cause all DTensor sharding
+    # propagation cache being invalidated. It is not easy to achieve
+    # selectively invalidating lru cache without rewriting the sharding
+    # propagation wrapper. Disable unregister_cp_sharding_rules() call
+    # for now.
+
+    # from ._sharding_rules import unregister_cp_sharding_rules
+    # unregister_cp_sharding_rules()
+
 
 def _enable_context_parallel_dispatcher_impl(seq_dim: int, mesh: DeviceMesh) -> None:
     sdpa_cp = _ContextParallel(

@@ -1032,9 +1047,7 @@ def _disable_context_parallel_dispatcher_impl() -> None:
     _disable_cp_dtensor_dispatcher()
 
 
-_compiled_create_block_mask = torch.compile(
-    create_block_mask, dynamic=False, fullgraph=True
-)
+_compiled_create_block_mask = None
 
 
 def _context_parallel_buffers(

@@ -1187,9 +1200,12 @@ def _create_cp_block_mask(
             f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. "
         )
 
-    compiled_create_block_mask = torch.compile(
-        create_block_mask, dynamic=False, fullgraph=True
-    )
+    global _compiled_create_block_mask
+    if _compiled_create_block_mask is None:
+        _compiled_create_block_mask = torch.compile(
+            create_block_mask, dynamic=False, fullgraph=True
+        )
+    compiled_create_block_mask = _compiled_create_block_mask
 
 def _rewrite_mask_mod(
     mask_mod: _mask_mod_signature,
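The block-mask change above replaces an import-time `torch.compile` call with a lazily initialized module-level cache, so the compile cost is paid only on first use. A generic sketch of the same pattern (standalone illustration, not the torch.distributed code itself):

```python
# Generic lazy-compile cache (standalone illustration of the pattern used above).
import torch

_compiled_double = None


def compiled_double():
    global _compiled_double
    if _compiled_double is None:
        # Compile once, on first use, instead of at import time.
        _compiled_double = torch.compile(lambda x: x * 2, dynamic=False, fullgraph=True)
    return _compiled_double


print(compiled_double()(torch.arange(4)))  # tensor([0, 2, 4, 6])
```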
@@ -0,0 +1,399 @@ (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates
"""
Context Parallelism sharding rules for scaled_dot_product attention operators.

The sharding rules for CP cannot be embedded by default because Shard(2) is not
a valid sharding for SDPA without CP enabled. This module provides utilities to
dynamically install Shard(2) sharding rules when CP is activated.
"""

from contextlib import contextmanager

import torch
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpStrategy,
    PlacementList,
    RuntimeSchemaInfo,
)
from torch.distributed.tensor._ops.utils import (
    expand_to_full_mesh_op_strategy,
    register_op_strategy,
)
from torch.distributed.tensor.placement_types import Replicate, Shard


aten = torch.ops.aten

SEQ_DIM = 2


@contextmanager
def _op_strategy_context(op_overload, strategy_func, schema_info=None):
    """
    Context manager for setting and clearing op strategies for Context Parallelism.

    Args:
        op_overload: The operator overload to set or clear the strategy for.
        strategy_func: The strategy function to set for the operator overload.
        schema_info: Optional schema information for the operator overload.

    Yields:
        None
    """
    from torch.distributed.tensor import DTensor

    propagator = DTensor._op_dispatcher.sharding_propagator
    _origin_op_strategy_funcs = None
    _origin_op_strategy_schema = None
    try:
        # Save original strategy if exists
        if op_overload in propagator.op_strategy_funcs:
            _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
        if op_overload in propagator.op_to_schema_info:
            _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]

        # Register the new op strategy
        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
        yield (_origin_op_strategy_funcs, _origin_op_strategy_schema)
    finally:
        # Restore original strategy
        if _origin_op_strategy_funcs is None:
            if op_overload in propagator.op_strategy_funcs:
                del propagator.op_strategy_funcs[op_overload]
        else:
            propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs

        if _origin_op_strategy_schema is None:
            if op_overload in propagator.op_to_schema_info:
                del propagator.op_to_schema_info[op_overload]
        else:
            propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema

        # Clear cache
        propagator.propagate_op_sharding.cache.cache_clear()


# ==================== Flash Attention Strategies ====================


def _scaled_dot_product_flash_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for flash attention forward with Context Parallelism support.
    This includes the base strategies plus CP-specific sequence dimension sharding.
    """
    # Import here to avoid circular dependency
    from torch.distributed.tensor._ops._matrix_ops import (
        _scaled_dot_product_flash_attention_base_strategies,
    )

    # Get the base strategies (without CP modifications)
    mesh = op_schema.get_mesh_from_args()
    single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies(
        op_schema
    )

    # Add Context Parallelism strategy: shards on the sequence dim
    return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
    debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else Replicate()

    cp_strategy: PlacementList = [
        Shard(SEQ_DIM),  # output
        Shard(SEQ_DIM),  # logsumexp
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        Replicate(),  # rng_state
        None,  # unused
        debug_attn_mask_sharding,  # debugattn
        Shard(SEQ_DIM),  # q
        Shard(SEQ_DIM),  # k
        Shard(SEQ_DIM),  # v
    ]
    single_mesh_dim_strategies.append(cp_strategy)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=9
    )


def _scaled_dot_product_flash_attention_backward_cp_strategy(
    op_schema: OpSchema,
) -> OpStrategy:
    """
    Strategy for flash attention backward with Context Parallelism support.
    """
    from torch.distributed.tensor._ops._matrix_ops import (
        _scaled_dot_product_flash_attention_backward_base_strategies,
    )

    mesh = op_schema.get_mesh_from_args(validate=False)
    single_mesh_dim_strategies = (
        _scaled_dot_product_flash_attention_backward_base_strategies(op_schema)
    )

    tensor_input_indices = [
        i
        for i, arg_spec in enumerate(op_schema.args_schema)
        if isinstance(arg_spec, OpStrategy)
    ]
    num_tensor_inputs = len(tensor_input_indices)

    # Context Parallelism: shards on the sequence dim
    cp_strategy: PlacementList = [
        Shard(SEQ_DIM),  # grad_q
        Shard(SEQ_DIM),  # grad_k
        Shard(SEQ_DIM),  # grad_v
        Shard(SEQ_DIM),  # grad_output
        Shard(SEQ_DIM),  # q
        Shard(SEQ_DIM),  # k
        Shard(SEQ_DIM),  # v
|
||||||
|
Shard(SEQ_DIM), # output
|
||||||
|
Shard(SEQ_DIM), # logsumexp
|
||||||
|
]
|
||||||
|
cp_strategy.extend([Replicate()] * (num_tensor_inputs - 6))
|
||||||
|
single_mesh_dim_strategies.append(cp_strategy)
|
||||||
|
|
||||||
|
return expand_to_full_mesh_op_strategy(
|
||||||
|
mesh, op_schema, single_mesh_dim_strategies, input_index=3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Efficient Attention Strategies ====================
|
||||||
|
|
||||||
|
|
||||||
|
def _scaled_dot_product_efficient_attention_cp_strategy(
|
||||||
|
op_schema: OpSchema,
|
||||||
|
) -> OpStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for efficient attention forward with Context Parallelism support.
|
||||||
|
"""
|
||||||
|
from torch.distributed.tensor._ops._matrix_ops import (
|
||||||
|
_scaled_dot_product_efficient_attention_base_strategies,
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = op_schema.get_mesh_from_args()
|
||||||
|
single_mesh_dim_strategies = (
|
||||||
|
_scaled_dot_product_efficient_attention_base_strategies(op_schema)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add Context Parallelism strategy
|
||||||
|
has_attn_bias = op_schema.args_schema[3] is not None
|
||||||
|
|
||||||
|
cp_strategy: PlacementList = [
|
||||||
|
Shard(SEQ_DIM), # output
|
||||||
|
Shard(SEQ_DIM), # logsumexp
|
||||||
|
None, # philox_seed
|
||||||
|
None, # philox_offset
|
||||||
|
Shard(SEQ_DIM), # q
|
||||||
|
Shard(SEQ_DIM), # k
|
||||||
|
Shard(SEQ_DIM), # v
|
||||||
|
]
|
||||||
|
if has_attn_bias:
|
||||||
|
cp_strategy.append(Replicate()) # attn bias - not sharded for CP
|
||||||
|
single_mesh_dim_strategies.append(cp_strategy)
|
||||||
|
|
||||||
|
return expand_to_full_mesh_op_strategy(
|
||||||
|
mesh, op_schema, single_mesh_dim_strategies, input_index=4
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _scaled_dot_product_efficient_attention_backward_cp_strategy(
|
||||||
|
op_schema: OpSchema,
|
||||||
|
) -> OpStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for efficient attention backward with Context Parallelism support.
|
||||||
|
"""
|
||||||
|
from torch.distributed.tensor._ops._matrix_ops import (
|
||||||
|
_scaled_dot_product_efficient_attention_backward_base_strategies,
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||||
|
single_mesh_dim_strategies = (
|
||||||
|
_scaled_dot_product_efficient_attention_backward_base_strategies(op_schema)
|
||||||
|
)
|
||||||
|
|
||||||
|
has_attn_bias = op_schema.args_schema[4] is not None
|
||||||
|
|
||||||
|
# Context Parallelism: shards on the sequence dim
|
||||||
|
cp_strategy: PlacementList = [
|
||||||
|
Shard(SEQ_DIM), # grad_q
|
||||||
|
Shard(SEQ_DIM), # grad_k
|
||||||
|
Shard(SEQ_DIM), # grad_v
|
||||||
|
Shard(1) if has_attn_bias else None, # grad_bias
|
||||||
|
Shard(SEQ_DIM), # grad_output
|
||||||
|
Shard(SEQ_DIM), # q
|
||||||
|
Shard(SEQ_DIM), # k
|
||||||
|
Shard(SEQ_DIM), # v
|
||||||
|
Shard(SEQ_DIM), # output
|
||||||
|
Shard(SEQ_DIM), # logsumexp
|
||||||
|
]
|
||||||
|
if has_attn_bias:
|
||||||
|
cp_strategy.insert(8, Shard(1)) # attn_bias input
|
||||||
|
cp_strategy.extend([Replicate(), Replicate()])
|
||||||
|
single_mesh_dim_strategies.append(cp_strategy)
|
||||||
|
|
||||||
|
return expand_to_full_mesh_op_strategy(
|
||||||
|
mesh, op_schema, single_mesh_dim_strategies, input_index=4
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== cuDNN Attention Strategies ====================
|
||||||
|
|
||||||
|
|
||||||
|
def _scaled_dot_product_cudnn_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for cudnn attention forward with Context Parallelism support.
|
||||||
|
"""
|
||||||
|
from torch.distributed.tensor._ops._matrix_ops import (
|
||||||
|
_scaled_dot_product_cudnn_attention_base_strategies,
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = op_schema.get_mesh_from_args()
|
||||||
|
single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies(
|
||||||
|
op_schema
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
query_strategy,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
attn_bias_strategy,
|
||||||
|
compute_log_sumexp,
|
||||||
|
*rest_args,
|
||||||
|
) = op_schema.args_schema
|
||||||
|
return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2]
|
||||||
|
has_attn_bias = attn_bias_strategy is not None
|
||||||
|
|
||||||
|
# Context Parallelism: shards on the sequence dim
|
||||||
|
logsumexp_sharding = Shard(SEQ_DIM) if compute_log_sumexp else Replicate()
|
||||||
|
debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else None
|
||||||
|
|
||||||
|
cp_strategy: PlacementList = [
|
||||||
|
Shard(SEQ_DIM), # output
|
||||||
|
logsumexp_sharding, # logsumexp
|
||||||
|
None, # cum_seq_q
|
||||||
|
None, # cum_seq_k
|
||||||
|
None, # max_q
|
||||||
|
None, # max_k
|
||||||
|
None, # philox_seed
|
||||||
|
None, # philox_offset
|
||||||
|
debug_attn_mask_sharding, # debug_attn_mask
|
||||||
|
Shard(SEQ_DIM), # q
|
||||||
|
Shard(SEQ_DIM), # k
|
||||||
|
Shard(SEQ_DIM), # v
|
||||||
|
]
|
||||||
|
if has_attn_bias:
|
||||||
|
cp_strategy.append(Replicate()) # attn_bias - not sharded for CP
|
||||||
|
single_mesh_dim_strategies.append(cp_strategy)
|
||||||
|
|
||||||
|
return expand_to_full_mesh_op_strategy(
|
||||||
|
mesh, op_schema, single_mesh_dim_strategies, input_index=9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _scaled_dot_product_cudnn_attention_backward_cp_strategy(
|
||||||
|
op_schema: OpSchema,
|
||||||
|
) -> OpStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for cudnn attention backward with Context Parallelism support.
|
||||||
|
"""
|
||||||
|
from torch.distributed.tensor._ops._matrix_ops import (
|
||||||
|
_scaled_dot_product_cudnn_attention_backward_base_strategies,
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = op_schema.get_mesh_from_args(validate=False)
|
||||||
|
single_mesh_dim_strategies = (
|
||||||
|
_scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema)
|
||||||
|
)
|
||||||
|
|
||||||
|
has_attn_bias = op_schema.args_schema[8] is not None
|
||||||
|
has_scale = len(op_schema.args_schema) >= 16 and False
|
||||||
|
|
||||||
|
# Context Parallelism: shards on the sequence dim
|
||||||
|
cp_sharding_gout: PlacementList = [Shard(SEQ_DIM)] * 3 # grad_q, grad_k, grad_v
|
||||||
|
cp_sharding_ginp: PlacementList = [
|
||||||
|
Shard(SEQ_DIM)
|
||||||
|
] * 6 # grad_output, q, k, v, output, logsumexp
|
||||||
|
cp_sharding_ginp += [Replicate()] * 2 # philox_seed, philox_offset
|
||||||
|
cp_sharding_ginp += [Shard(SEQ_DIM) if has_attn_bias else None] # attn_bias
|
||||||
|
cp_sharding_ginp += [
|
||||||
|
None
|
||||||
|
] * 6 # cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal
|
||||||
|
if has_scale:
|
||||||
|
cp_sharding_ginp.append(None)
|
||||||
|
|
||||||
|
cp_sharding = cp_sharding_gout + cp_sharding_ginp
|
||||||
|
single_mesh_dim_strategies.append(cp_sharding)
|
||||||
|
|
||||||
|
return expand_to_full_mesh_op_strategy(
|
||||||
|
mesh, op_schema, single_mesh_dim_strategies, input_index=3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Store context managers and original strategies
|
||||||
|
_cp_strategy_contexts = {}
|
||||||
|
_original_strategies = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_cp_sharding_rules():
|
||||||
|
"""Register Context Parallelism sharding rules for all scaled_dot_product ops."""
|
||||||
|
global _cp_strategy_contexts, _original_strategies
|
||||||
|
|
||||||
|
# If already registered, don't register again
|
||||||
|
if _cp_strategy_contexts:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Define ops and their corresponding CP strategy functions
|
||||||
|
cp_strategies = [
|
||||||
|
(
|
||||||
|
aten._scaled_dot_product_flash_attention.default,
|
||||||
|
_scaled_dot_product_flash_attention_cp_strategy,
|
||||||
|
RuntimeSchemaInfo(5),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
aten._scaled_dot_product_flash_attention_backward.default,
|
||||||
|
_scaled_dot_product_flash_attention_backward_cp_strategy,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
aten._scaled_dot_product_efficient_attention.default,
|
||||||
|
_scaled_dot_product_efficient_attention_cp_strategy,
|
||||||
|
RuntimeSchemaInfo(4),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
aten._scaled_dot_product_efficient_attention_backward.default,
|
||||||
|
_scaled_dot_product_efficient_attention_backward_cp_strategy,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
aten._scaled_dot_product_cudnn_attention.default,
|
||||||
|
_scaled_dot_product_cudnn_attention_cp_strategy,
|
||||||
|
RuntimeSchemaInfo(4),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
aten._scaled_dot_product_cudnn_attention_backward.default,
|
||||||
|
_scaled_dot_product_cudnn_attention_backward_cp_strategy,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Register each strategy
|
||||||
|
for op_overload, strategy_func, schema_info in cp_strategies:
|
||||||
|
ctx = _op_strategy_context(op_overload, strategy_func, schema_info)
|
||||||
|
orig_funcs, orig_schema = ctx.__enter__()
|
||||||
|
_cp_strategy_contexts[op_overload] = ctx
|
||||||
|
_original_strategies[op_overload] = (orig_funcs, orig_schema)
|
||||||
|
|
||||||
|
|
||||||
|
def unregister_cp_sharding_rules():
|
||||||
|
"""Unregister Context Parallelism sharding rules and restore original strategies."""
|
||||||
|
global _cp_strategy_contexts, _original_strategies
|
||||||
|
|
||||||
|
# Exit all context managers
|
||||||
|
for ctx in _cp_strategy_contexts.values():
|
||||||
|
ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
|
_cp_strategy_contexts = {}
|
||||||
|
_original_strategies = {}
|
||||||
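A minimal usage sketch of the registration helpers defined in the new module above: install the CP strategies while context parallelism is active and restore the originals afterwards. Only register_cp_sharding_rules and unregister_cp_sharding_rules come from the patch; the wrapper name, the import of them into scope, and the commented call are illustrative assumptions.

from contextlib import contextmanager


@contextmanager
def _cp_sharding_rules_enabled():
    # Install the Shard(SEQ_DIM) strategies for the SDPA ops, then restore the
    # original DTensor op strategies on exit, mirroring register/unregister above.
    register_cp_sharding_rules()
    try:
        yield
    finally:
        unregister_cp_sharding_rules()


# with _cp_sharding_rules_enabled():
#     out = scaled_dot_product_attention(q, k, v)  # q/k/v as DTensors sharded on dim 2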
@@ -529,14 +529,11 @@ RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+
 
 
 def replace_extern_shared(input_string):
-    """
-    Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
-    See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
-    Examples:
-        "extern __shared__ char smemChar[];"
-        => "HIP_DYNAMIC_SHARED( char, smemChar)"
-        "extern __shared__ unsigned char smem[];"
-        => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
+    """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
+    https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
+    Example:
+        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
+        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
     """
     output_string = input_string
     output_string = RE_EXTERN_SHARED.sub(

@@ -1046,17 +1043,14 @@ RE_INCLUDE = re.compile(r"#include .*\n")
 
 
 def extract_arguments(start, string):
-    """
-    Return the list of arguments in the upcoming function parameter closure.
-    Example:
+    """ Return the list of arguments in the upcoming function parameter closure.
+    Example:
     string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
     arguments (output):
-        [
-            {'start': 1, 'end': 7},
-            {'start': 8, 'end': 16},
-            {'start': 17, 'end': 19},
-            {'start': 20, 'end': 53}
-        ]
+        '[{'start': 1, 'end': 7},
+        {'start': 8, 'end': 16},
+        {'start': 17, 'end': 19},
+        {'start': 20, 'end': 53}]'
     """
 
     arguments = []
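To make the rewritten docstring concrete, here is a self-contained approximation of the substitution it documents. RE_EXTERN_SHARED_SIMPLE is a simplified stand-in for the (truncated) RE_EXTERN_SHARED pattern shown in the hunk header, not the actual hipify regex.

import re

RE_EXTERN_SHARED_SIMPLE = re.compile(
    r"extern\s+__shared__\s+([\w\s]+?)\s+(\w+)\s*\[\s*\]\s*;"
)


def replace_extern_shared_simple(input_string):
    # "extern __shared__ char smemChar[];" -> "HIP_DYNAMIC_SHARED( char, smemChar)"
    return RE_EXTERN_SHARED_SIMPLE.sub(
        lambda m: f"HIP_DYNAMIC_SHARED( {m.group(1)}, {m.group(2)})", input_string
    )


print(replace_extern_shared_simple("extern __shared__ unsigned char smem[];"))
# -> HIP_DYNAMIC_SHARED( unsigned char, smem)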