Shard RegisterDispatchKey (#144364)
Should fix https://github.com/pytorch/pytorch/issues/143952.

Testing: built PyTorch on Raspberry Pi 5; this seemed to alleviate the high peak memory requirement. (I did increase shard counts for other generated files along the way, but I need to go back and figure out how much of that was strictly necessary vs. needing to use -j1 or -j2.)

Differential Revision: [D67925496](https://our.internmc.facebook.com/intern/diff/D67925496/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144364
Approved by: https://github.com/Skylion007, https://github.com/bdhirsh
ghstack dependencies: #144363
Committed by: PyTorch MergeBot
Parent: 4143312e67
Commit: b46d00c1b7
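The change applies the sharding scheme already used for RegisterFunctionalization_{0..3}.cpp (visible in the BUILD.bazel hunk below) to Register{DispatchKey}.cpp: each operator group is assigned to one of a fixed number of shards by its root name, and the generator emits one smaller translation unit per shard instead of a single huge one, lowering the compiler's peak memory. A minimal, self-contained sketch of that kind of key-based sharding — illustrative only, not the torchgen implementation:

```python
import hashlib
from collections import defaultdict

def shard_by_key(items, key_fn, num_shards):
    """Deterministically bucket items by a stable hash of key_fn(item),
    so a given operator always lands in the same generated shard."""
    buckets = defaultdict(list)
    for item in items:
        digest = hashlib.sha1(key_fn(item).encode("utf-8")).hexdigest()
        buckets[int(digest, 16) % num_shards].append(item)
    return [buckets[i] for i in range(num_shards)]

# Illustrative only: split a handful of operator root names into 4 CPU shards.
ops = ["add", "mul", "conv2d", "relu", "softmax", "matmul"]
for i, shard in enumerate(shard_by_key(ops, key_fn=lambda name: name, num_shards=4)):
    print(f"RegisterCPU_{i}.cpp <- {shard}")
```

Assigning by a stable hash of the key (rather than by position) keeps an operator in the same shard as the operator list changes, which keeps incremental rebuilds localized.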
BUILD.bazel (43 changed lines)
@@ -38,26 +38,29 @@ aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + ["aten/s

 generated_cpu_cpp = [
     "aten/src/ATen/RegisterBackendSelect.cpp",
-    "aten/src/ATen/RegisterCPU.cpp",
+    "aten/src/ATen/RegisterCPU_0.cpp",
+    "aten/src/ATen/RegisterCPU_1.cpp",
+    "aten/src/ATen/RegisterCPU_2.cpp",
+    "aten/src/ATen/RegisterCPU_3.cpp",
     "aten/src/ATen/RegisterFunctionalization_0.cpp",
     "aten/src/ATen/RegisterFunctionalization_1.cpp",
     "aten/src/ATen/RegisterFunctionalization_2.cpp",
     "aten/src/ATen/RegisterFunctionalization_3.cpp",
     # "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
-    "aten/src/ATen/RegisterMkldnnCPU.cpp",
-    "aten/src/ATen/RegisterNestedTensorCPU.cpp",
-    "aten/src/ATen/RegisterQuantizedCPU.cpp",
-    "aten/src/ATen/RegisterSparseCPU.cpp",
-    "aten/src/ATen/RegisterSparseCsrCPU.cpp",
-    "aten/src/ATen/RegisterZeroTensor.cpp",
-    "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
-    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
-    "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
-    "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
-    "aten/src/ATen/RegisterMeta.cpp",
-    "aten/src/ATen/RegisterSparseMeta.cpp",
-    "aten/src/ATen/RegisterQuantizedMeta.cpp",
-    "aten/src/ATen/RegisterNestedTensorMeta.cpp",
+    "aten/src/ATen/RegisterMkldnnCPU_0.cpp",
+    "aten/src/ATen/RegisterNestedTensorCPU_0.cpp",
+    "aten/src/ATen/RegisterQuantizedCPU_0.cpp",
+    "aten/src/ATen/RegisterSparseCPU_0.cpp",
+    "aten/src/ATen/RegisterSparseCsrCPU_0.cpp",
+    "aten/src/ATen/RegisterZeroTensor_0.cpp",
+    "aten/src/ATen/RegisterCompositeImplicitAutograd_0.cpp",
+    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
+    "aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp",
+    "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
+    "aten/src/ATen/RegisterMeta_0.cpp",
+    "aten/src/ATen/RegisterSparseMeta_0.cpp",
+    "aten/src/ATen/RegisterQuantizedMeta_0.cpp",
+    "aten/src/ATen/RegisterNestedTensorMeta_0.cpp",
     "aten/src/ATen/RegisterSchema.cpp",
     "aten/src/ATen/CPUFunctions.h",
     "aten/src/ATen/CPUFunctions_inl.h",
@@ -97,11 +100,11 @@ generated_cpu_cpp = [
 generated_cuda_cpp = [
     "aten/src/ATen/CUDAFunctions.h",
     "aten/src/ATen/CUDAFunctions_inl.h",
-    "aten/src/ATen/RegisterCUDA.cpp",
-    "aten/src/ATen/RegisterNestedTensorCUDA.cpp",
-    "aten/src/ATen/RegisterQuantizedCUDA.cpp",
-    "aten/src/ATen/RegisterSparseCUDA.cpp",
-    "aten/src/ATen/RegisterSparseCsrCUDA.cpp",
+    "aten/src/ATen/RegisterCUDA_0.cpp",
+    "aten/src/ATen/RegisterNestedTensorCUDA_0.cpp",
+    "aten/src/ATen/RegisterQuantizedCUDA_0.cpp",
+    "aten/src/ATen/RegisterSparseCUDA_0.cpp",
+    "aten/src/ATen/RegisterSparseCsrCUDA_0.cpp",
 ]

 generate_aten(
@@ -353,10 +353,10 @@ def get_aten_generated_files(enabled_backends):
     # and is intentionally omitted from here
     src_files = [
         "RegisterBackendSelect.cpp",
-        "RegisterCompositeImplicitAutograd.cpp",
-        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
-        "RegisterCompositeExplicitAutograd.cpp",
-        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
+        "RegisterCompositeImplicitAutograd_0.cpp",
+        "RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
+        "RegisterCompositeExplicitAutograd_0.cpp",
+        "RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
        "CompositeViewCopyKernels.cpp",
         "RegisterSchema.cpp",
         "Declarations.yaml",
@@ -409,20 +409,22 @@ def get_aten_generated_files(enabled_backends):

 def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
     return [
-        ":{}[{}]".format(aten_rule_name, "Register" + backend + ".cpp")
-        for backend in enabled_backends
-    ]
+        ":{}[{}]".format(aten_rule_name, "Register" + backend + "_0.cpp")
+        for backend in enabled_backends if backend != "CPU"
+    ] + ([
+        ":{}[RegisterCPU_{}.cpp]".format(aten_rule_name, x) for x in range(4)
+    ] if "CPU" in enabled_backends else [])

 def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
     return [
         ":{}[{}]".format(aten_rule_name, f)
-        for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
+        for f in ["RegisterCompositeImplicitAutograd_0.cpp", "RegisterCompositeImplicitAutogradNestedTensor_0.cpp", "RegisterCompositeExplicitAutograd_0.cpp", "RegisterCompositeExplicitAutogradNonFunctional_0.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
     ] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)

 def get_aten_derived_type_srcs(enabled_backends):
     return [
-        "Register" + derived_type + ".cpp"
-        for derived_type in enabled_backends
+        "Register" + derived_type + "_0.cpp"
+        for derived_type in enabled_backends if derived_type != "CPU"
     ] + [
         derived_type + "Functions.h"
         for derived_type in enabled_backends
@@ -431,7 +433,9 @@ def get_aten_derived_type_srcs(enabled_backends):
         derived_type + "Functions_inl.h"
         for derived_type in enabled_backends
         if derived_type in PT_BACKEND_HEADERS or derived_type in get_static_dispatch_backend()
-    ]
+    ] + ([
+        "RegisterCPU_{}.cpp".format(x) for x in range(4)
+    ] if "CPU" in enabled_backends else [])

 def gen_aten_files(
     name,
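Concretely, after this change the Buck/Bazel helpers above resolve every non-CPU backend to its single `_0` shard and expand CPU to four shards. A small Python rendering of that selection logic (the functions above are Starlark; this standalone mirror, with a hypothetical rule name, is only to show the resulting labels):

```python
def derived_type_src_rules(aten_rule_name, enabled_backends):
    # Non-CPU backends get one shard; CPU is split into four.
    rules = [
        ":{}[Register{}_0.cpp]".format(aten_rule_name, backend)
        for backend in enabled_backends
        if backend != "CPU"
    ]
    if "CPU" in enabled_backends:
        rules += [":{}[RegisterCPU_{}.cpp]".format(aten_rule_name, x) for x in range(4)]
    return rules

print(derived_type_src_rules("aten_gen", ["CPU", "CUDA"]))
# [':aten_gen[RegisterCUDA_0.cpp]', ':aten_gen[RegisterCPU_0.cpp]',
#  ':aten_gen[RegisterCPU_1.cpp]', ':aten_gen[RegisterCPU_2.cpp]',
#  ':aten_gen[RegisterCPU_3.cpp]']
```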
build.bzl (43 changed lines)
@@ -202,31 +202,34 @@ GENERATED_H_CUDA = [
 ]

 GENERATED_CPP_CUDA = [
-    "RegisterCUDA.cpp",
-    "RegisterNestedTensorCUDA.cpp",
-    "RegisterSparseCUDA.cpp",
-    "RegisterSparseCsrCUDA.cpp",
-    "RegisterQuantizedCUDA.cpp",
+    "RegisterCUDA_0.cpp",
+    "RegisterNestedTensorCUDA_0.cpp",
+    "RegisterSparseCUDA_0.cpp",
+    "RegisterSparseCsrCUDA_0.cpp",
+    "RegisterQuantizedCUDA_0.cpp",
 ]

 GENERATED_CPP = [
     "Functions.cpp",
     "RegisterBackendSelect.cpp",
-    "RegisterCPU.cpp",
-    "RegisterQuantizedCPU.cpp",
-    "RegisterNestedTensorCPU.cpp",
-    "RegisterSparseCPU.cpp",
-    "RegisterSparseCsrCPU.cpp",
-    "RegisterMkldnnCPU.cpp",
-    "RegisterCompositeImplicitAutograd.cpp",
-    "RegisterCompositeImplicitAutogradNestedTensor.cpp",
-    "RegisterZeroTensor.cpp",
-    "RegisterMeta.cpp",
-    "RegisterQuantizedMeta.cpp",
-    "RegisterNestedTensorMeta.cpp",
-    "RegisterSparseMeta.cpp",
-    "RegisterCompositeExplicitAutograd.cpp",
-    "RegisterCompositeExplicitAutogradNonFunctional.cpp",
+    "RegisterCPU_0.cpp",
+    "RegisterCPU_1.cpp",
+    "RegisterCPU_2.cpp",
+    "RegisterCPU_3.cpp",
+    "RegisterQuantizedCPU_0.cpp",
+    "RegisterNestedTensorCPU_0.cpp",
+    "RegisterSparseCPU_0.cpp",
+    "RegisterSparseCsrCPU_0.cpp",
+    "RegisterMkldnnCPU_0.cpp",
+    "RegisterCompositeImplicitAutograd_0.cpp",
+    "RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
+    "RegisterZeroTensor_0.cpp",
+    "RegisterMeta_0.cpp",
+    "RegisterQuantizedMeta_0.cpp",
+    "RegisterNestedTensorMeta_0.cpp",
+    "RegisterSparseMeta_0.cpp",
+    "RegisterCompositeExplicitAutograd_0.cpp",
+    "RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
     "CompositeViewCopyKernels.cpp",
     "RegisterSchema.cpp",
     "RegisterFunctionalization_0.cpp",
@@ -8,7 +8,7 @@ import os
 from collections import defaultdict, namedtuple, OrderedDict
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
+from typing import Any, Callable, Dict, Literal, TYPE_CHECKING, TypeVar

 import yaml

@@ -2305,9 +2305,30 @@ def gen_source_files(
             dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
         )

-        dispatch_definitions = get_native_function_definitions(
-            fm=fm,
-            grouped_native_functions=grouped_native_functions,
+        register_dispatch_key_base_env = {
+            "extra_cuda_headers": extra_cuda_headers
+            if is_cuda_dispatch_key(dispatch_key)
+            else "",
+            "external_backend_headers": "",
+            "dispatch_headers": dest.gen_registration_headers(
+                backend_index, per_operator_headers, rocm
+            ),
+            # ops_headers *could* be sharded, but doesn't seem necessary?
+            "ops_headers": operator_headers(),
+            "dispatch_helpers": (
+                dest.gen_registration_helpers(backend_index)
+                if gen_dispatch_helpers
+                else []
+            ),
+        }
+
+        def register_dispatch_key_env_callable(
+            gnf: NativeFunction | NativeFunctionsGroup,
+        ) -> Dict[str, list[str]]:
+            return {
+                "dispatch_definitions": get_native_function_definitions(
+                    fm=fm,  # noqa: F821
+                    grouped_native_functions=[gnf],
             dispatch_key=dispatch_key,
             backend_idx=backend_index,
             selector=selector,
@@ -2316,23 +2337,17 @@ def gen_source_files(
             skip_dispatcher_op_registration=skip_dispatcher_op_registration,
             gen_dispatch_helpers=gen_dispatch_helpers,
         )
-        fm.write_with_template(
+        }
+
+        fm.write_sharded_with_template(
             f"Register{dispatch_key}.cpp",
             "RegisterDispatchKey.cpp",
-            lambda: {
-                "extra_cuda_headers": extra_cuda_headers
-                if is_cuda_dispatch_key(dispatch_key)
-                else "",
-                "external_backend_headers": "",
-                "dispatch_headers": dest.gen_registration_headers(
-                    backend_index, per_operator_headers, rocm
-                ),
-                "ops_headers": operator_headers(),
-                "dispatch_helpers": dest.gen_registration_helpers(backend_index)
-                if gen_dispatch_helpers
-                else [],
-                "dispatch_definitions": dispatch_definitions,
-            },
+            grouped_native_functions,
+            key_fn=lambda x: x.root_name,
+            env_callable=register_dispatch_key_env_callable,
+            num_shards=4 if dispatch_key == DispatchKey.CPU else 1,
+            base_env=register_dispatch_key_base_env,
+            sharded_keys={"dispatch_definitions"},
         )

         for g in structured_native_functions:
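The split in the generator is between what is shared and what is per-operator: the header and helper boilerplate goes into `register_dispatch_key_base_env`, which every shard receives verbatim, while `dispatch_definitions` is produced per operator group by `register_dispatch_key_env_callable` and accumulated into whichever shard that group maps to. In the call above, only DispatchKey.CPU gets more than one shard, so other backends still produce a single Register{DispatchKey}_0.cpp. A simplified, self-contained sketch of how a base env, a per-item callable, and `sharded_keys` could combine (toy code, not the FileManager internals):

```python
from typing import Any, Callable, Iterable

def render_shards(
    items: Iterable[str],
    *,
    key_fn: Callable[[str], str],
    env_callable: Callable[[str], dict[str, list[str]]],
    num_shards: int,
    base_env: dict[str, Any],
    sharded_keys: set[str],
) -> list[dict[str, Any]]:
    # Each shard starts as a copy of the shared base env, plus an empty
    # accumulator for every sharded key.
    shards: list[dict[str, Any]] = [
        {**base_env, **{k: [] for k in sharded_keys}} for _ in range(num_shards)
    ]
    for item in items:
        # Toy assignment via built-in hash(); real codegen would need a hash
        # that is stable across runs to keep output reproducible.
        shard = shards[hash(key_fn(item)) % num_shards]
        for k, v in env_callable(item).items():
            shard[k].extend(v)
    return shards

# Illustrative call shaped like the generator's arguments.
envs = render_shards(
    ["add", "mul", "conv2d"],
    key_fn=lambda op: op,
    env_callable=lambda op: {"dispatch_definitions": [f"// definitions for {op}"]},
    num_shards=2,
    base_env={"dispatch_headers": ["#include <ATen/Tensor.h>"]},
    sharded_keys={"dispatch_definitions"},
)
for i, env in enumerate(envs):
    print(i, env["dispatch_definitions"])
```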
@@ -209,6 +209,29 @@ class FileManager:
         num_shards: int,
         base_env: dict[str, Any] | None = None,
         sharded_keys: set[str],
     ) -> None:
+        self.write_sharded_with_template(
+            filename,
+            filename,
+            items,
+            key_fn=key_fn,
+            env_callable=env_callable,
+            num_shards=num_shards,
+            base_env=base_env,
+            sharded_keys=sharded_keys,
+        )
+
+    def write_sharded_with_template(
+        self,
+        filename: str,
+        template_fn: str,
+        items: Iterable[T],
+        *,
+        key_fn: Callable[[T], str],
+        env_callable: Callable[[T], dict[str, list[str]]],
+        num_shards: int,
+        base_env: dict[str, Any] | None = None,
+        sharded_keys: set[str],
+    ) -> None:
         everything: dict[str, Any] = {"shard_id": "Everything"}
         shards: list[dict[str, Any]] = [
@@ -256,7 +279,9 @@ class FileManager:
         for shard in all_shards:
             shard_id = shard["shard_id"]
             self.write_with_template(
-                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
+                f"{base_filename}{shard_id}{extension}",
+                template_fn,
+                lambda: shard,
             )

         # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
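The existing `write_sharded` keeps its behavior by delegating with `filename` passed as both the output name and the template name, while the new `write_sharded_with_template` lets the two differ — which the Register{DispatchKey}.cpp case needs, since the outputs are named after the dispatch key but all rendered from the shared RegisterDispatchKey.cpp template. A toy illustration of that wrapper relationship (not torchgen's FileManager; names and the round-robin split are illustrative):

```python
class ShardedWriter:
    """Toy stand-in for the pattern above; illustrative only."""

    def write_sharded(self, filename, items, **kwargs):
        # Old entry point: the template shares the output file's name.
        self.write_sharded_with_template(filename, filename, items, **kwargs)

    def write_sharded_with_template(self, filename, template_fn, items, *, num_shards):
        base, _, ext = filename.rpartition(".")
        for shard_id, shard_items in enumerate(self._split(items, num_shards)):
            out = f"{base}_{shard_id}.{ext}"
            print(f"{out}: rendered from {template_fn} with {len(shard_items)} items")

    @staticmethod
    def _split(items, num_shards):
        # Round-robin keeps the toy short; a stable key hash (as sketched
        # earlier) is what keeps shard contents reproducible.
        buckets = [[] for _ in range(num_shards)]
        for i, item in enumerate(items):
            buckets[i % num_shards].append(item)
        return buckets

ShardedWriter().write_sharded_with_template(
    "RegisterCPU.cpp",
    "RegisterDispatchKey.cpp",
    ["add", "mul", "conv2d", "relu"],
    num_shards=4,
)
```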