mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Preferring dash over underscore in command-line options. Add `--command-arg-name` to the argument parser. The old arguments with underscores `--command_arg_name` are kept for backward compatibility.
Both dashes and underscores are used in the PyTorch codebase. Some argument parsers only have dashes or only have underscores in arguments. For example, the `torchrun` utility for distributed training only accepts underscore arguments (e.g., `--master_port`). The dashes are more common in other command-line tools. And it looks to be the default choice in the Python standard library:
`argparse.BooleanOptionalAction`: 4a9dff0e5a/Lib/argparse.py (L893-L895)
```python
class BooleanOptionalAction(Action):
def __init__(...):
if option_string.startswith('--'):
option_string = '--no-' + option_string[2:]
_option_strings.append(option_string)
```
It adds `--no-argname`, not `--no_argname`. Also, typing `_` requires pressing the Shift (or Caps Lock) key, unlike `-`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94505
Approved by: https://github.com/ezyang, https://github.com/seemethere
784 lines
27 KiB
Python
784 lines
27 KiB
Python
import argparse
|
|
import os
|
|
import pathlib
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
|
|
|
|
import yaml
|
|
|
|
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
|
|
from torchgen import dest
|
|
from torchgen.api import cpp as aten_cpp
|
|
from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
|
|
from torchgen.context import method_with_native_function, with_native_function_and_index
|
|
from torchgen.executorch.api import et_cpp
|
|
from torchgen.executorch.api.custom_ops import (
|
|
ComputeNativeFunctionStub,
|
|
gen_custom_ops_registration,
|
|
)
|
|
from torchgen.executorch.api.types import ExecutorchCppSignature
|
|
from torchgen.executorch.api.unboxing import Unboxing
|
|
from torchgen.gen import (
|
|
get_custom_build_selector,
|
|
get_native_function_declarations,
|
|
get_native_function_schema_registrations,
|
|
LineLoader,
|
|
parse_native_yaml,
|
|
ParsedYaml,
|
|
)
|
|
from torchgen.model import (
|
|
BackendIndex,
|
|
BackendMetadata,
|
|
DispatchKey,
|
|
is_cuda_dispatch_key,
|
|
Location,
|
|
NativeFunction,
|
|
NativeFunctionsGroup,
|
|
OperatorName,
|
|
Variant,
|
|
)
|
|
from torchgen.selective_build.selector import SelectiveBuilder
|
|
from torchgen.utils import (
|
|
context,
|
|
FileManager,
|
|
make_file_manager,
|
|
mapMaybe,
|
|
NamespaceHelper,
|
|
)
|
|
|
|
|
|
def static_dispatch(
    sig: Union[CppSignature, ExecutorchCppSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
    native function exists, error out. A simplified version of register_dispatch_key.py
    Arguments:
        sig: A CppSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);"
    """
    # Nothing to emit when there are no backends or registration is manual.
    if not backend_indices or f.manual_kernel_registration:
        return ""

    matching = [idx for idx in backend_indices if idx.has_kernel(f)]
    body = None
    if len(matching) != 1:
        # Zero or multiple kernels bound to this operator: emit a runtime error.
        body = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(matching)}.");
"""
    else:
        metadata = matching[0].get_kernel(f)
        if metadata:
            arg_list = ", ".join(a.name for a in sig.arguments())
            # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch.
            body = f"return ::{metadata.cpp_namespace}::{metadata.kernel}({arg_list});"
    return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
{body}
}}
"""
|
|
|
|
|
|
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    # Backends available for static dispatch.
    static_dispatch_backend_indices: List[BackendIndex]

    # Selective-build filter deciding which operators are emitted.
    selector: SelectiveBuilder

    # Whether generated code calls into the full ATen library.
    use_aten_lib: bool

    # Predicate telling whether a NativeFunction is a custom op.
    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the C++ functional-API declaration for `f`, or None if skipped."""
        # Skip operators that selective build did not mark as root operators.
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return None
        # Only function variants get an entry in the functional API.
        if Variant.function not in f.variants:
            return None
        sig: Union[CppSignature, ExecutorchCppSignature]
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
        else:
            sig = ExecutorchCppSignature.from_native_function(f)
        # Custom ops (and everything in non-ATen mode) go through static dispatch.
        if not (self.use_aten_lib and not self.is_custom_op(f)):
            return static_dispatch(
                sig,
                f,
                backend_indices=self.static_dispatch_backend_indices,
            )
        # ATen-library ops simply forward to the at:: symbol.
        sep = ", "

        return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::{sig.name()}({sep.join(e.name for e in sig.arguments())});
}}
"""
|
|
|
|
|
|
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    # Selective-build filter deciding which operators get registered.
    selector: SelectiveBuilder

    # Whether the generated call targets ATen kernels or Executorch kernels.
    use_aten_lib: bool

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Render one `Operator(...)` registration entry (C++ source) for `f`.

        Returns an empty string for operators not selected as root operators.
        Raises a plain Exception for functions with neither returns nor out
        arguments, which this codegen cannot handle yet.
        """
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return ""
        sig: Union[CppSignature, ExecutorchCppSignature]
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        # Pick the signature and type generators matching the target library:
        # ATen mode uses the upstream cpp API, lean mode uses Executorch's.
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            argument_type_gen = aten_cpp.argumenttype_type
            return_type_gen = aten_cpp.returns_type
        else:
            sig = ExecutorchCppSignature.from_native_function(f)
            argument_type_gen = et_cpp.argumenttype_type
            return_type_gen = et_cpp.returns_type
        # parse arguments into C++ code
        binding_list, code_list = Unboxing(
            argument_type_gen=argument_type_gen
        ).convert_arguments(sig.arguments())

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "

        args_str = f"{arg_connector.join(e.name for e in binding_list)}"

        # Decide how the kernel's result is written back onto the EValue stack.
        if len(f.func.returns) == 0:
            if len(f.func.arguments.out) == 0:
                raise Exception(
                    f"Can't handle native function {f.func} with no returns and no out yet."
                )
            # Void kernel with out args: publish the out tensor's address.
            out = f.func.arguments.out[0]
            return_assignment = f"""stack[{len(binding_list)}] = &{out.name};"""
            ret_prefix = ""
        else:
            if len(f.func.arguments.out) == 0:
                # Value-returning kernel: capture the result and box it.
                return_assignment = (
                    f"""*stack[{len(binding_list)}] = EValue(result_);"""
                )
                ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "
            else:
                # Returns an out argument; nothing extra to write back.
                return_assignment = ""
                ret_prefix = ""

        return f"""
Operator(
    "{f.namespace}::{f.func.name}",
    [](EValue** stack) {{
        {code_connector.join(code_list)}

        EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
        {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({args_str});

        {return_assignment}
    }}
),
"""
|
|
|
|
|
|
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> None:
    """Write RegisterCodegenUnboxedKernels.cpp through the file manager."""

    # Shard key is the operator's root name (sharding is trivial: one shard).
    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    # ComputeCodegenUnboxedKernels is a frozen dataclass, so a single
    # instance can be reused for every native function.
    compute = ComputeCodegenUnboxedKernels(selector, use_aten_lib)
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {"unboxed_ops": [compute(fn)]},
        num_shards=1,
        sharded_keys={"unboxed_ops"},
    )
|
|
|
|
|
|
@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    """Return the NativeFunctions.h declaration line(s) for one Executorch
    operator, or an empty list when the backend has no kernel for it."""
    # Executorch operates on ungrouped NativeFunctions only.
    assert isinstance(g, NativeFunction)
    sig = ExecutorchCppSignature.from_native_function(f=g)
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    # External backends get internal linkage; in-tree ones are exported.
    if backend_index.external:
        linkage = "static"
    else:
        linkage = "TORCH_API"
    return [f"{linkage} {sig.decl(name=metadata.kernel)};"]
|
|
|
|
|
|
def gen_functions_declarations(
    *,
    native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None,
) -> str:
    """
    Generates namespace separated C++ function API inline declaration/definitions.
    Native functions are grouped by namespaces and the generated code is wrapped inside
    namespace blocks.

    E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol
    in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when
    the other `custom_2::foo.out` is available.
    """
    # Bucket the operators by namespace so each gets its own namespace block.
    by_namespace = defaultdict(list)
    for fn in native_functions:
        by_namespace[fn.namespace].append(fn)

    # A single frozen ComputeFunction instance serves every namespace.
    compute = ComputeFunction(
        static_dispatch_backend_indices=static_dispatch_idx,
        selector=selector,
        use_aten_lib=use_aten_lib,
        is_custom_op=lambda f: custom_ops_native_functions is not None
        and f in custom_ops_native_functions,
    )
    newline = "\n"
    chunks = []
    for namespace, fns in by_namespace.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        declarations = list(mapMaybe(compute, fns))
        chunks.append(
            f"""
{ns_helper.prologue}
{newline.join(declarations)}
{ns_helper.epilogue}
"""
        )
    return "".join(chunks)
|
|
|
|
|
|
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    custom_ops_native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    """Generate Functions.h and NativeFunctions.h, plus
    CustomOpsNativeFunctions.h when custom operators are present."""
    aten_headers = ["#include <ATen/Functions.h>"]
    if custom_ops_native_functions:
        # Custom ops get their own NativeFunctions-style header, rendered
        # from the shared NativeFunctions.h template.
        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
            lambda: {
                "nativeFunctions_declarations": get_native_function_declarations(
                    grouped_native_functions=custom_ops_native_functions,
                    backend_indices=backend_indices,
                    native_function_decl_gen=dest.compute_native_function_declaration,
                ),
            },
        )
        aten_headers.append('#include "CustomOpsNativeFunctions.h"')

    # ATen mode pulls in the ATen headers; lean mode only needs the
    # generated NativeFunctions.h.
    if use_aten_lib:
        extra_headers = aten_headers
    else:
        extra_headers = ['#include "NativeFunctions.h"']

    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": extra_headers,
            "Functions_declarations": gen_functions_declarations(
                native_functions=native_functions,
                static_dispatch_idx=static_dispatch_idx,
                selector=selector,
                use_aten_lib=use_aten_lib,
                custom_ops_native_functions=custom_ops_native_functions,
            ),
        },
    )

    # ATen mode reuses the upstream declaration generator; lean mode uses the
    # Executorch-specific one defined in this file.
    if use_aten_lib:
        decl_gen = dest.compute_native_function_declaration
    else:
        decl_gen = compute_native_function_declaration

    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "nativeFunctions_declarations": get_native_function_declarations(
                grouped_native_functions=native_functions,
                backend_indices=backend_indices,
                native_function_decl_gen=decl_gen,
            ),
        },
    )
|
|
|
|
|
|
def gen_custom_ops(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    rocm: bool,
) -> None:
    """Write the custom-operator registration files for the CPU dispatch key:
    Register*CustomOps.cpp, Register*Stub.cpp and RegisterSchema.cpp."""
    dispatch_key = DispatchKey.CPU
    backend_index = backend_indices[dispatch_key]
    (
        anonymous_definition,
        static_init_dispatch_registrations,
    ) = gen_custom_ops_registration(
        native_functions=native_functions,
        selector=selector,
        backend_index=backend_index,
        rocm=rocm,
    )

    def custom_ops_env() -> Dict[str, object]:
        # Real registrations that dispatch to the user-provided kernels.
        return {
            "ops_headers": '#include "CustomOpsNativeFunctions.h"',
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": anonymous_definition,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        }

    cpu_fm.write_with_template(
        f"Register{dispatch_key}CustomOps.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        custom_ops_env,
    )

    def stub_env() -> Dict[str, object]:
        # Stub registrations backed by generated no-op native functions.
        return {
            "ops_headers": "",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                mapMaybe(ComputeNativeFunctionStub(), native_functions)
            ),
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        }

    cpu_fm.write_with_template(
        f"Register{dispatch_key}Stub.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        stub_env,
    )

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions,
        schema_selector=selector,
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "schema_registrations": schema_registrations,
            "aten_schema_registrations": aten_schema_registrations,
        },
    )
|
|
|
|
|
|
def translate_native_yaml(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: Optional[str],
    use_aten_lib: bool,
    out_file: TextIO,
) -> None:
    """Translates Executorch DSL dialect to use the same syntax as
    native_functions.yaml. The major difference is that Executorch DSL dialect
    supports "op" key, where it refers to the operator name in native_functions.yaml.

    For example, a functions.yaml may have the following entry:

    - op: add.out
      ...

    It needs to be translated to the following:

    - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
      ...

    We go in aten_yaml_path and find the operator schema for "add.out" and add it
    to the original functions.yaml. We also add required field "variants", where for
    Executorch it will always be "function".

    For ATen mode we don't have to do the translation because native_yaml_path is
    the same as native_functions.yaml.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
        out_file: The IO object that we are writing into.
    Returns:
        None
    """
    if use_aten_lib:
        # ATen mode: the input is already in native_functions.yaml syntax,
        # so just copy the ATen yaml through verbatim.
        with open(aten_yaml_path, "r") as aten_yaml:
            out_file.writelines(aten_yaml.readlines())
        return
    # Build a map from "namespace::op name" to the full schema string so an
    # Executorch "op:" entry can be expanded into a "func:" entry below.
    aten_parsed_yaml = parse_native_yaml(
        aten_yaml_path,
        tags_yaml_path,
        None,
        skip_native_fns_gen=False,
    )
    aten_native_functions = aten_parsed_yaml.native_functions
    schema_dict = {
        f"{f.namespace}::{f.func.name}": str(f.func) for f in aten_native_functions
    }
    # Nothing to translate when functions.yaml is absent or empty.
    # NOTE(review): the ATen yaml above is parsed even when we return here —
    # presumably for parse_native_yaml's side effects; confirm before reordering.
    if (
        not native_yaml_path
        or not os.path.exists(native_yaml_path)
        or os.stat(native_yaml_path).st_size == 0
    ):
        return
    with open(native_yaml_path, "r") as native_yaml:
        native_es = yaml.load(native_yaml, Loader=LineLoader)
    if not native_es:
        return
    for e in native_es:
        # LineLoader attaches __line__ to each entry; pop it for error context.
        assert isinstance(e.get("__line__"), int), e
        loc = Location(native_yaml_path, e.pop("__line__"))
        with context(lambda: f"in {loc}:\n "):
            # Executorch entries default to the "function" variant.
            if "variants" not in e:
                e["variants"] = "function"
            # Entries already carrying a full "func" schema pass through as-is.
            if "func" in e:
                continue
            assert isinstance(e.get("op"), str), e
            opname = e.pop("op")
            # Bare op names default to the aten namespace.
            if "::" not in opname:
                opname = "aten::" + opname
            assert opname in schema_dict
            e["func"] = schema_dict.get(opname)
    yaml.dump(native_es, out_file, width=1000)
|
|
|
|
|
|
def convert_backend_indices(
    bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
) -> Dict[DispatchKey, BackendIndex]:
    """Wrap raw per-backend kernel maps into BackendIndex objects, one per
    dispatch key. Missing keys default to an empty Undefined index."""
    indices: Dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            index={},
        )
    )
    for dispatch_key, metadata_map in bs.items():
        indices[dispatch_key] = BackendIndex(
            dispatch_key=dispatch_key,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(dispatch_key),
            index=metadata_map,
        )
    return indices
|
|
|
|
|
|
def parse_yaml(
    path: Optional[str],
    tags_yaml_path: str,
    function_filter: Callable[[NativeFunction], bool],
    skip_native_fns_gen: bool = False,
) -> Tuple[
    List[NativeFunction], Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
]:
    """Parse a native-functions-style yaml file into the functions that pass
    `function_filter`, together with per-backend kernel maps restricted to
    those functions.

    Args:
        path: yaml file to parse; a missing or empty file yields ([], {}).
        tags_yaml_path: Path to tags.yaml, required by codegen yaml parsing.
        function_filter: Predicate selecting which NativeFunctions to keep.
        skip_native_fns_gen: Forwarded to parse_native_yaml.
    Returns:
        (selected native functions, {dispatch key: {op name: metadata}}).
    """
    if path and os.path.exists(path) and os.stat(path).st_size > 0:
        parsed_yaml = parse_native_yaml(
            path,
            tags_yaml_path,
            None,
            skip_native_fns_gen=skip_native_fns_gen,
        )
        native_functions = list(filter(function_filter, parsed_yaml.native_functions))
        # Use a set: map_index below tests membership once per kernel entry,
        # which was O(len(native_functions)) per test with a list.
        op_names = {f.func.name for f in native_functions}

        def map_index(
            m: Dict[OperatorName, BackendMetadata]
        ) -> Dict[OperatorName, BackendMetadata]:
            # Keep only kernels for operators that survived the filter.
            return {op: m[op] for op in m if op in op_names}

        backend_indices = {
            k: map_index(b.index) for k, b in parsed_yaml.backend_indices.items()
        }
        return native_functions, backend_indices
    else:
        return [], {}
|
|
|
|
|
|
def parse_yaml_files(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: Optional[str],
    custom_ops_yaml_path: Optional[str],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> Tuple[ParsedYaml, Optional[ParsedYaml]]:
    """Parses functions.yaml and custom_ops.yaml files.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If
            the path does not exist in the filesystem, it is ignored.
        selector: For selective build.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
    Returns:
        A tuple with two elements:
        [0]: The parsed results of concatenating the contents of
             `native_yaml_path` and `custom_ops_yaml_path`.
        [1]: The parsed results of the contents of `custom_ops_yaml_path`, if
             present. If not present, None.
    """
    import tempfile

    # Only keep operators that selective build selected.
    def function_filter(f: NativeFunction) -> bool:
        return selector.is_native_function_selected(f)

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Translate the Executorch DSL into native_functions.yaml syntax
        # inside a scratch directory, then parse the translated file.
        translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
        with open(translated_yaml_path, "w") as translated:
            translate_native_yaml(
                tags_yaml_path,
                aten_yaml_path,
                native_yaml_path,
                use_aten_lib,
                translated,
            )
        translated_functions, translated_backend_indices = parse_yaml(
            translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib
        )
        custom_ops_functions, custom_ops_backend_indices = parse_yaml(
            custom_ops_yaml_path, tags_yaml_path, function_filter, True
        )

        combined_functions = translated_functions + custom_ops_functions
        combined_backend_indices: Dict[
            DispatchKey, Dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        combined_backend_indices.update(translated_backend_indices)

        # Layer the custom-ops kernels on top of the translated ones,
        # dispatch key by dispatch key.
        for dk, custom_index in custom_ops_backend_indices.items():
            if dk in combined_backend_indices:
                combined_backend_indices[dk] = {
                    **combined_backend_indices[dk],
                    **custom_index,
                }
            else:
                combined_backend_indices[dk] = custom_index

        combined_yaml = ParsedYaml(
            combined_functions, convert_backend_indices(combined_backend_indices)
        )
        custom_ops_parsed_yaml = ParsedYaml(
            custom_ops_functions, convert_backend_indices(custom_ops_backend_indices)
        )
        return combined_yaml, custom_ops_parsed_yaml
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: parse arguments, parse the yaml inputs, and generate
    the requested headers/sources (plus custom-ops files) via FileManager."""
    parser = argparse.ArgumentParser(description="Generate operator source files")
    # Although we don't refer to --source-path directly, make_file_manager()
    # expects it to point to a directory that contains a templates/ subdirectory
    # containing the file templates.
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for kernel templates",
    )
    # Dashed spellings are preferred; the underscore aliases are kept for
    # backward compatibility (argparse maps both to the same dest).
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--custom-ops-yaml-path",
        "--custom_ops_yaml_path",
        help="path to the custom_ops.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--aten-yaml-path",
        "--aten_yaml_path",
        help="path to native_functions.yaml file.",
    )
    # Note that make_file_manager() also looks at --install-dir.
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/generated",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    # Although we don't refer to --dry-run directly, make_file_manager() looks
    # for it.
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in codegen system.",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--use-aten-lib",
        "--use_aten_lib",
        action="store_true",
        help="a boolean flag to indicate whether we use ATen kernels or not, in the future this flag will be per "
        "operator",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources"],
        default=["headers", "sources"],
        help="Generate only a subset of files",
    )
    options = parser.parse_args()
    assert options.tags_path, "tags.yaml is required by codegen yaml parsing."

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
        aten_yaml_path=options.aten_yaml_path,
        tags_yaml_path=options.tags_path,
        native_yaml_path=options.functions_yaml_path,
        custom_ops_yaml_path=options.custom_ops_yaml_path,
        selector=selector,
        use_aten_lib=options.use_aten_lib,
    )
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )
    custom_ops_native_functions = (
        custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else []
    )

    cpu_fm = make_file_manager(options=options)

    # Static dispatch targets the CPU backend index only.
    static_dispatch_idx: List[BackendIndex] = [backend_indices[DispatchKey.CPU]]

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            custom_ops_native_functions=custom_ops_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            use_aten_lib=options.use_aten_lib,
        )

    if "sources" in options.generate:
        gen_unboxing(
            native_functions=native_functions,
            cpu_fm=cpu_fm,
            selector=selector,
            use_aten_lib=options.use_aten_lib,
        )
        if custom_ops_native_functions:
            gen_custom_ops(
                native_functions=custom_ops_native_functions,
                selector=selector,
                backend_indices=backend_indices,
                cpu_fm=cpu_fm,
                rocm=options.rocm,
            )

    # Optionally emit a dependency (.d) file listing everything written.
    if options.output_dependencies:
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        for fm, prefix in [
            (cpu_fm, ""),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
|
|
|
|
|
|
# Run the codegen CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|