Shard RegisterDispatchKey (#144364)

Should fix https://github.com/pytorch/pytorch/issues/143952.

Testing: built PyTorch on a Raspberry Pi 5; this seemed to alleviate the high peak-memory requirement. (I did increase the shard counts for other generated files along the way, but I still need to go back and figure out how much of that was strictly necessary versus just needing to build with -j1 or -j2.)

Differential Revision: [D67925496](https://our.internmc.facebook.com/intern/diff/D67925496/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144364
Approved by: https://github.com/Skylion007, https://github.com/bdhirsh
ghstack dependencies: #144363
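
For intuition, here is a minimal, self-contained sketch of the sharding idea (not the actual torchgen code): instead of emitting one enormous RegisterCPU.cpp, each operator is assigned to one of N shard files by a stable hash of its name, so no single translation unit gets large enough to spike compiler memory.

```python
# Minimal sketch of the sharding idea, NOT the torchgen implementation:
# bucket generated registrations into RegisterCPU_{0..N-1}.cpp by a stable
# hash of each operator's name so no single .cpp file grows unboundedly.
import hashlib


def shard_index(key: str, num_shards: int) -> int:
    # hashlib instead of hash() so the assignment is stable across runs.
    return int(hashlib.sha1(key.encode("utf-8")).hexdigest(), 16) % num_shards


def shard(items: list[str], num_shards: int) -> list[list[str]]:
    buckets: list[list[str]] = [[] for _ in range(num_shards)]
    for item in items:
        buckets[shard_index(item, num_shards)].append(item)
    return buckets


if __name__ == "__main__":
    ops = ["add", "mul", "conv2d", "matmul", "relu", "softmax"]
    for i, bucket in enumerate(shard(ops, num_shards=4)):
        print(f"RegisterCPU_{i}.cpp <- {bucket}")
```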
commit b46d00c1b7 (parent 4143312e67)
Author: Scott Wolchok
Date: 2025-01-09 15:00:21 -08:00
Committed by: PyTorch MergeBot
5 changed files with 129 additions and 79 deletions


@@ -38,26 +38,29 @@ aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + ["aten/s
generated_cpu_cpp = [
"aten/src/ATen/RegisterBackendSelect.cpp",
"aten/src/ATen/RegisterCPU.cpp",
"aten/src/ATen/RegisterCPU_0.cpp",
"aten/src/ATen/RegisterCPU_1.cpp",
"aten/src/ATen/RegisterCPU_2.cpp",
"aten/src/ATen/RegisterCPU_3.cpp",
"aten/src/ATen/RegisterFunctionalization_0.cpp",
"aten/src/ATen/RegisterFunctionalization_1.cpp",
"aten/src/ATen/RegisterFunctionalization_2.cpp",
"aten/src/ATen/RegisterFunctionalization_3.cpp",
# "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
"aten/src/ATen/RegisterMkldnnCPU.cpp",
"aten/src/ATen/RegisterNestedTensorCPU.cpp",
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterSparseMeta.cpp",
"aten/src/ATen/RegisterQuantizedMeta.cpp",
"aten/src/ATen/RegisterNestedTensorMeta.cpp",
"aten/src/ATen/RegisterMkldnnCPU_0.cpp",
"aten/src/ATen/RegisterNestedTensorCPU_0.cpp",
"aten/src/ATen/RegisterQuantizedCPU_0.cpp",
"aten/src/ATen/RegisterSparseCPU_0.cpp",
"aten/src/ATen/RegisterSparseCsrCPU_0.cpp",
"aten/src/ATen/RegisterZeroTensor_0.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd_0.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
"aten/src/ATen/RegisterMeta_0.cpp",
"aten/src/ATen/RegisterSparseMeta_0.cpp",
"aten/src/ATen/RegisterQuantizedMeta_0.cpp",
"aten/src/ATen/RegisterNestedTensorMeta_0.cpp",
"aten/src/ATen/RegisterSchema.cpp",
"aten/src/ATen/CPUFunctions.h",
"aten/src/ATen/CPUFunctions_inl.h",
@@ -97,11 +100,11 @@ generated_cpu_cpp = [
generated_cuda_cpp = [
"aten/src/ATen/CUDAFunctions.h",
"aten/src/ATen/CUDAFunctions_inl.h",
"aten/src/ATen/RegisterCUDA.cpp",
"aten/src/ATen/RegisterNestedTensorCUDA.cpp",
"aten/src/ATen/RegisterQuantizedCUDA.cpp",
"aten/src/ATen/RegisterSparseCUDA.cpp",
"aten/src/ATen/RegisterSparseCsrCUDA.cpp",
"aten/src/ATen/RegisterCUDA_0.cpp",
"aten/src/ATen/RegisterNestedTensorCUDA_0.cpp",
"aten/src/ATen/RegisterQuantizedCUDA_0.cpp",
"aten/src/ATen/RegisterSparseCUDA_0.cpp",
"aten/src/ATen/RegisterSparseCsrCUDA_0.cpp",
]
generate_aten(


@@ -353,10 +353,10 @@ def get_aten_generated_files(enabled_backends):
# and is intentionally omitted from here
src_files = [
"RegisterBackendSelect.cpp",
"RegisterCompositeImplicitAutograd.cpp",
"RegisterCompositeImplicitAutogradNestedTensor.cpp",
"RegisterCompositeExplicitAutograd.cpp",
"RegisterCompositeExplicitAutogradNonFunctional.cpp",
"RegisterCompositeImplicitAutograd_0.cpp",
"RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
"RegisterCompositeExplicitAutograd_0.cpp",
"RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
"CompositeViewCopyKernels.cpp",
"RegisterSchema.cpp",
"Declarations.yaml",
@@ -409,20 +409,22 @@ def get_aten_generated_files(enabled_backends):
def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
return [
":{}[{}]".format(aten_rule_name, "Register" + backend + ".cpp")
for backend in enabled_backends
]
":{}[{}]".format(aten_rule_name, "Register" + backend + "_0.cpp")
for backend in enabled_backends if backend != "CPU"
] + ([
":{}[RegisterCPU_{}.cpp]".format(aten_rule_name, x) for x in range(4)
] if "CPU" in enabled_backends else [])
def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
return [
":{}[{}]".format(aten_rule_name, f)
for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
for f in ["RegisterCompositeImplicitAutograd_0.cpp", "RegisterCompositeImplicitAutogradNestedTensor_0.cpp", "RegisterCompositeExplicitAutograd_0.cpp", "RegisterCompositeExplicitAutogradNonFunctional_0.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)
def get_aten_derived_type_srcs(enabled_backends):
return [
"Register" + derived_type + ".cpp"
for derived_type in enabled_backends
"Register" + derived_type + "_0.cpp"
for derived_type in enabled_backends if derived_type != "CPU"
] + [
derived_type + "Functions.h"
for derived_type in enabled_backends
@@ -431,7 +433,9 @@ def get_aten_derived_type_srcs(enabled_backends):
derived_type + "Functions_inl.h"
for derived_type in enabled_backends
if derived_type in PT_BACKEND_HEADERS or derived_type in get_static_dispatch_backend()
]
] + ([
"RegisterCPU_{}.cpp".format(x) for x in range(4)
] if "CPU" in enabled_backends else [])
def gen_aten_files(
name,


@@ -202,31 +202,34 @@ GENERATED_H_CUDA = [
]
GENERATED_CPP_CUDA = [
"RegisterCUDA.cpp",
"RegisterNestedTensorCUDA.cpp",
"RegisterSparseCUDA.cpp",
"RegisterSparseCsrCUDA.cpp",
"RegisterQuantizedCUDA.cpp",
"RegisterCUDA_0.cpp",
"RegisterNestedTensorCUDA_0.cpp",
"RegisterSparseCUDA_0.cpp",
"RegisterSparseCsrCUDA_0.cpp",
"RegisterQuantizedCUDA_0.cpp",
]
GENERATED_CPP = [
"Functions.cpp",
"RegisterBackendSelect.cpp",
"RegisterCPU.cpp",
"RegisterQuantizedCPU.cpp",
"RegisterNestedTensorCPU.cpp",
"RegisterSparseCPU.cpp",
"RegisterSparseCsrCPU.cpp",
"RegisterMkldnnCPU.cpp",
"RegisterCompositeImplicitAutograd.cpp",
"RegisterCompositeImplicitAutogradNestedTensor.cpp",
"RegisterZeroTensor.cpp",
"RegisterMeta.cpp",
"RegisterQuantizedMeta.cpp",
"RegisterNestedTensorMeta.cpp",
"RegisterSparseMeta.cpp",
"RegisterCompositeExplicitAutograd.cpp",
"RegisterCompositeExplicitAutogradNonFunctional.cpp",
"RegisterCPU_0.cpp",
"RegisterCPU_1.cpp",
"RegisterCPU_2.cpp",
"RegisterCPU_3.cpp",
"RegisterQuantizedCPU_0.cpp",
"RegisterNestedTensorCPU_0.cpp",
"RegisterSparseCPU_0.cpp",
"RegisterSparseCsrCPU_0.cpp",
"RegisterMkldnnCPU_0.cpp",
"RegisterCompositeImplicitAutograd_0.cpp",
"RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
"RegisterZeroTensor_0.cpp",
"RegisterMeta_0.cpp",
"RegisterQuantizedMeta_0.cpp",
"RegisterNestedTensorMeta_0.cpp",
"RegisterSparseMeta_0.cpp",
"RegisterCompositeExplicitAutograd_0.cpp",
"RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
"CompositeViewCopyKernels.cpp",
"RegisterSchema.cpp",
"RegisterFunctionalization_0.cpp",


@@ -8,7 +8,7 @@ import os
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Dict, Literal, TYPE_CHECKING, TypeVar
import yaml
@@ -2305,34 +2305,49 @@ def gen_source_files(
dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
)
dispatch_definitions = get_native_function_definitions(
fm=fm,
grouped_native_functions=grouped_native_functions,
dispatch_key=dispatch_key,
backend_idx=backend_index,
selector=selector,
rocm=rocm,
symint=True,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
gen_dispatch_helpers=gen_dispatch_helpers,
)
fm.write_with_template(
register_dispatch_key_base_env = {
"extra_cuda_headers": extra_cuda_headers
if is_cuda_dispatch_key(dispatch_key)
else "",
"external_backend_headers": "",
"dispatch_headers": dest.gen_registration_headers(
backend_index, per_operator_headers, rocm
),
# ops_headers *could* be sharded, but doesn't seem necessary?
"ops_headers": operator_headers(),
"dispatch_helpers": (
dest.gen_registration_helpers(backend_index)
if gen_dispatch_helpers
else []
),
}
def register_dispatch_key_env_callable(
gnf: NativeFunction | NativeFunctionsGroup,
) -> Dict[str, list[str]]:
return {
"dispatch_definitions": get_native_function_definitions(
fm=fm, # noqa: F821
grouped_native_functions=[gnf],
dispatch_key=dispatch_key,
backend_idx=backend_index,
selector=selector,
rocm=rocm,
symint=True,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
gen_dispatch_helpers=gen_dispatch_helpers,
)
}
fm.write_sharded_with_template(
f"Register{dispatch_key}.cpp",
"RegisterDispatchKey.cpp",
lambda: {
"extra_cuda_headers": extra_cuda_headers
if is_cuda_dispatch_key(dispatch_key)
else "",
"external_backend_headers": "",
"dispatch_headers": dest.gen_registration_headers(
backend_index, per_operator_headers, rocm
),
"ops_headers": operator_headers(),
"dispatch_helpers": dest.gen_registration_helpers(backend_index)
if gen_dispatch_helpers
else [],
"dispatch_definitions": dispatch_definitions,
},
grouped_native_functions,
key_fn=lambda x: x.root_name,
env_callable=register_dispatch_key_env_callable,
num_shards=4 if dispatch_key == DispatchKey.CPU else 1,
base_env=register_dispatch_key_base_env,
sharded_keys={"dispatch_definitions"},
)
for g in structured_native_functions:


@@ -209,6 +209,29 @@ class FileManager:
num_shards: int,
base_env: dict[str, Any] | None = None,
sharded_keys: set[str],
) -> None:
self.write_sharded_with_template(
filename,
filename,
items,
key_fn=key_fn,
env_callable=env_callable,
num_shards=num_shards,
base_env=base_env,
sharded_keys=sharded_keys,
)
def write_sharded_with_template(
self,
filename: str,
template_fn: str,
items: Iterable[T],
*,
key_fn: Callable[[T], str],
env_callable: Callable[[T], dict[str, list[str]]],
num_shards: int,
base_env: dict[str, Any] | None = None,
sharded_keys: set[str],
) -> None:
everything: dict[str, Any] = {"shard_id": "Everything"}
shards: list[dict[str, Any]] = [
@@ -256,7 +279,9 @@ class FileManager:
for shard in all_shards:
shard_id = shard["shard_id"]
self.write_with_template(
f"{base_filename}{shard_id}{extension}", filename, lambda: shard
f"{base_filename}{shard_id}{extension}",
template_fn,
lambda: shard,
)
# filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
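
To make the new parameters concrete, here is a small self-contained sketch of how I read the base_env / sharded_keys contract from the diff above (an illustration, not the real FileManager code; the operator names and header strings are made up): base_env values are shared by every shard, while values returned by env_callable for keys listed in sharded_keys are accumulated into the single shard selected by a stable hash of key_fn(item).

```python
# Sketch of the env merging implied by write_sharded_with_template (illustrative
# only): base_env is copied into every shard; env_callable results for keys in
# sharded_keys are appended to the shard chosen by hashing key_fn(item).
import hashlib
from typing import Any, Callable, Iterable, TypeVar

T = TypeVar("T")


def build_shard_envs(
    items: Iterable[T],
    *,
    key_fn: Callable[[T], str],
    env_callable: Callable[[T], dict[str, list[str]]],
    num_shards: int,
    base_env: dict[str, Any],
    sharded_keys: set[str],
) -> list[dict[str, Any]]:
    shards = [dict(base_env, shard_id=str(i)) for i in range(num_shards)]
    for s in shards:
        for k in sharded_keys:
            s[k] = list(s.get(k, []))  # fresh list so shards don't alias base_env
    for item in items:
        idx = int(hashlib.sha1(key_fn(item).encode("utf-8")).hexdigest(), 16) % num_shards
        for k, v in env_callable(item).items():
            assert k in sharded_keys, f"{k} must be declared in sharded_keys"
            shards[idx][k].extend(v)
    return shards


# Toy usage mirroring the RegisterCPU call in gen.py (values are hypothetical):
envs = build_shard_envs(
    ["add", "mul", "conv2d", "matmul"],
    key_fn=lambda op: op,
    env_callable=lambda op: {"dispatch_definitions": [f"// registration for {op}"]},
    num_shards=4,
    base_env={"dispatch_headers": ["#include <ATen/Tensor.h>"]},
    sharded_keys={"dispatch_definitions"},
)
for i, env in enumerate(envs):
    print(f"RegisterCPU_{i}.cpp:", env["dispatch_definitions"])
```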