mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Preferring dash over underscore in command-line options. Add `--command-arg-name` to the argument parser. The old arguments with underscores `--command_arg_name` are kept for backward compatibility.
Both dashes and underscores are used in the PyTorch codebase. Some argument parsers only have dashes or only have underscores in their arguments. For example, the `torchrun` utility for distributed training only accepts underscore arguments (e.g., `--master_port`). Dashes are more common in other command-line tools, and they appear to be the default choice in the Python standard library:
`argparse.BooleanOptionalAction`: 4a9dff0e5a/Lib/argparse.py (L893-L895)
```python
class BooleanOptionalAction(Action):
def __init__(...):
if option_string.startswith('--'):
option_string = '--no-' + option_string[2:]
_option_strings.append(option_string)
```
It adds `--no-argname`, not `--no_argname`. Also, typing `_` requires pressing the Shift (or Caps Lock) key, whereas typing `-` does not.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94505
Approved by: https://github.com/ezyang, https://github.com/seemethere
99 lines
3.7 KiB
Python
99 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
from torchgen.gen import parse_native_yaml, FileManager
|
|
import torchgen.model as model
|
|
|
|
def num_leading_spaces(line: str) -> int:
    """Return the number of leading whitespace characters in ``line``.

    Everything that ``str.lstrip()`` removes is counted, so tabs and other
    whitespace count as well as plain spaces.
    """
    stripped = line.lstrip()
    return len(line) - len(stripped)
|
|
def deindent(code: str) -> str:
    """Strip the common leading indentation from every line of ``code``.

    The margin is the smallest indentation found among lines that contain
    something other than whitespace. Blank and whitespace-only lines are
    ignored when computing it, because an empty line has zero indentation and
    would otherwise prevent any de-indentation of the surrounding code.

    Args:
        code: Text whose lines are separated by ``'\\n'``.

    Returns:
        The text with the margin removed from the start of each line. Lines
        shorter than the margin become empty.
    """
    lines = code.split('\n')
    # If every line is blank, fall back to considering all lines so that the
    # result for all-whitespace input is unchanged from the old behaviour.
    significant = [line for line in lines if line.strip()] or lines
    margin = min(map(num_leading_spaces, significant))
    return '\n'.join(line[margin:] for line in lines)
|
|
|
|
|
|
def gen_external(native_functions_path, tags_path, external_path):
    """Generate NNC external-call wrappers for eligible ATen out-variant ops.

    For every parsed native function that passes the filters below, this
    builds a C++ wrapper ``nnc_aten_<name>`` plus a matching
    ``RegisterNNCExternalFunction`` statement, then renders both lists through
    the template into ``external_functions_codegen.cpp``.

    A function is skipped unless it is an out variant, has at most one out
    parameter, has no keyword-only arguments, and all of its positional
    arguments (including ``self``) are plain Tensors.

    Args:
        native_functions_path: Path forwarded to ``parse_native_yaml`` as the
            native functions YAML.
        tags_path: Path forwarded to ``parse_native_yaml`` as the tags YAML.
        external_path: Template file handed to
            ``FileManager.write_with_template`` (the script's
            ``--template-path`` option).
    """
    parsed = parse_native_yaml(native_functions_path, tags_path)
    # parse_native_yaml may return a result object that bundles the function
    # list under ``native_functions`` instead of a bare sequence; iterating
    # such a bundle directly would yield its fields, not the functions.
    native_functions = getattr(parsed, 'native_functions', parsed)

    func_decls = []
    func_registrations = []
    for func in native_functions:
        schema = func.func
        name = schema.name.name.base
        args = schema.arguments

        # Only supports extern calls for functions with out variants
        if not schema.is_out_fn():
            continue

        # Doesn't currently support functions with more than one out parameter
        if len(args.out) > 1:
            continue

        # Doesn't currently support kwarg arguments
        if len(args.pre_tensor_options_kwarg_only) > 0 or len(args.post_tensor_options_kwarg_only) > 0:
            continue

        self_arg = [args.self_arg.argument] if args.self_arg is not None else []
        positional = list(args.pre_self_positional) + self_arg + list(args.post_self_positional)

        # Only signatures made up entirely of plain Tensor arguments are supported
        all_tensors = all(
            isinstance(arg.type, model.BaseType) and arg.type.name == model.BaseTy.Tensor
            for arg in positional
        )
        if not all_tensors:
            continue

        arg_names = [arg.name for arg in positional]
        # tensors[0] is bound to the output ``r``, so inputs start at index 1.
        tensor_decls = [
            f"const at::Tensor& {arg_name} = tensors[{idx}];"
            for idx, arg_name in enumerate(arg_names, start=1)
        ]
        # Joined outside the f-strings: backslashes are not allowed inside
        # f-string expressions on older Python versions.
        tensor_decl_block = '\n'.join(tensor_decls)
        call_args = ', '.join(['r'] + arg_names)

        # NOTE(review): whitespace inside the C++ templates below follows
        # conventional C++ indentation; it does not affect the compiled code.
        func_decl = f"""\
void nnc_aten_{name}(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {{
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
  at::Tensor& r = tensors[0];
  {tensor_decl_block}
  try {{
    at::{name}_out({call_args});
  }} catch (...) {{
  }}
}}"""
        func_registration = f"""\
const static RegisterNNCExternalFunction nnc_{name}(
    "nnc_aten_{name}",
    nnc_aten_{name});"""
        func_decls.append(func_decl)
        func_registrations.append(func_registration)

    fm = FileManager(install_dir='.', template_dir='.', dry_run=False)
    fm.write_with_template(
        'external_functions_codegen.cpp',
        external_path,
        lambda: {'external_registrations': func_registrations, 'external_functions': func_decls},
    )
|
|
|
|
|
|
def main() -> None:
    """Parse command-line options and run the external-function generator.

    Each multi-word option is registered under its dashed spelling first and
    its underscore spelling second, so both forms are accepted and map to the
    same destination (``native_functions``, ``template_path``).
    """
    parser = argparse.ArgumentParser(
        description='Generate NNC external function wrappers (external_functions_codegen.cpp)')
    parser.add_argument('--native-functions',
                        '--native_functions',
                        help='path to native_functions.yaml',
                        default='../../../../aten/src/ATen/native/native_functions.yaml')
    parser.add_argument('--tags',
                        help='path to tags.yaml',
                        default='../../../../aten/src/ATen/native/tags.yaml')
    parser.add_argument('--template-path',
                        '--template_path',
                        help='path to external_functions_codegen_template.cpp',
                        default='../../../../tools/jit/templates/external_functions_codegen_template.cpp')
    args = parser.parse_args()
    gen_external(args.native_functions, args.tags, args.template_path)
|
|
|
|
# Run the generator only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|