mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use global variables to register the return_types namedtuples (#108832)
Fixes #69221. Builds on top of #107000, fixing the buck build issue linked [here](https://github.com/pytorch/pytorch/pull/107000#issuecomment-1708857375). Pull Request resolved: https://github.com/pytorch/pytorch/pull/108832 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
6065e7a97c
commit
00908475e6
@ -359,6 +359,9 @@ def gen(
|
||||
create_python_return_type_bindings(
|
||||
fm, functions, lambda fn: True, "python_return_types.cpp"
|
||||
)
|
||||
create_python_return_type_bindings_header(
|
||||
fm, functions, lambda fn: True, "python_return_types.h"
|
||||
)
|
||||
|
||||
valid_tags = parse_tags_yaml(tags_yaml_path)
|
||||
|
||||
@ -436,22 +439,24 @@ def create_python_return_type_bindings(
|
||||
) -> None:
|
||||
"""
|
||||
Generate function to initialize and return named tuple for native functions
|
||||
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
|
||||
which returns named tuple and registration invocations in `python_return_types.cpp`.
|
||||
"""
|
||||
py_return_types_definition: List[str] = []
|
||||
py_return_types_map: List[str] = []
|
||||
py_return_types_registrations: List[str] = []
|
||||
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
|
||||
for name in sorted(grouped.keys(), key=lambda x: str(x)):
|
||||
overloads = grouped[name]
|
||||
definitions, map_entries = generate_return_type_definition_and_map_entry(
|
||||
definitions, registrations = generate_return_type_definition_and_registrations(
|
||||
overloads
|
||||
)
|
||||
py_return_types_definition.append(
|
||||
"" if not definitions else "\n".join(definitions)
|
||||
)
|
||||
py_return_types_map.append("" if not map_entries else "\n".join(map_entries))
|
||||
py_return_types_registrations.append(
|
||||
"" if not registrations else "\n".join(registrations)
|
||||
)
|
||||
|
||||
fm.write_with_template(
|
||||
filename,
|
||||
@ -460,7 +465,39 @@ def create_python_return_type_bindings(
|
||||
"generated_comment": "@"
|
||||
+ f"generated from {fm.template_dir_for_comments()}/{filename}",
|
||||
"py_return_types": py_return_types_definition,
|
||||
"py_return_types_map": py_return_types_map,
|
||||
"py_return_types_registrations": py_return_types_registrations,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_python_return_type_bindings_header(
|
||||
fm: FileManager,
|
||||
pairs: Sequence[PythonSignatureNativeFunctionPair],
|
||||
pred: Callable[[NativeFunction], bool],
|
||||
filename: str,
|
||||
) -> None:
|
||||
"""
|
||||
Generate function to initialize and return named tuple for native functions
|
||||
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
|
||||
"""
|
||||
py_return_types_declarations: List[str] = []
|
||||
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
|
||||
for name in sorted(grouped.keys(), key=lambda x: str(x)):
|
||||
overloads = grouped[name]
|
||||
declarations = generate_return_type_declarations(overloads)
|
||||
py_return_types_declarations.append(
|
||||
"" if not declarations else "\n".join(declarations)
|
||||
)
|
||||
|
||||
fm.write_with_template(
|
||||
filename,
|
||||
filename,
|
||||
lambda: {
|
||||
"generated_comment": "@"
|
||||
+ f"generated from {fm.template_dir_for_comments()}/{filename}",
|
||||
"py_return_types_declarations": py_return_types_declarations,
|
||||
},
|
||||
)
|
||||
|
||||
@ -683,27 +720,25 @@ def emit_namedtuple_call(
|
||||
typenames[tn_key] = typename
|
||||
typedefs.append(
|
||||
f"""\
|
||||
static PyTypeObject* {typename} = get_namedtuple("{name}");"""
|
||||
static PyTypeObject* {typename} = generated::get_{name}_namedtuple();"""
|
||||
)
|
||||
|
||||
return typedefs, typenames
|
||||
|
||||
|
||||
def generate_return_type_definition_and_map_entry(
|
||||
def generate_return_type_definition_and_registrations(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Generate block of function in `python_return_types.cpp` to initialize
|
||||
and return named tuple for a native function which returns named tuple
|
||||
and relevant entry for the map in same file.
|
||||
and registration invocations in same file.
|
||||
"""
|
||||
typenames: Dict[
|
||||
str, str
|
||||
] = {} # map from unique name + field name lists to typedef name
|
||||
definitions: List[str] = [] # function defintion to register the typedef
|
||||
map_entries: List[
|
||||
str
|
||||
] = [] # C++ map entry of <function_name, function creates it namedtuple>
|
||||
definitions: List[str] = [] # function definition to register the typedef
|
||||
registrations: List[str] = [] # register call for the typedef
|
||||
|
||||
for overload in overloads:
|
||||
fieldnames = namedtuple_fieldnames(overload.function.func.returns)
|
||||
@ -735,9 +770,42 @@ PyTypeObject* get_{name}_namedtuple() {{
|
||||
}}
|
||||
"""
|
||||
)
|
||||
map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')
|
||||
registrations.append(
|
||||
f'addReturnType(return_types_module, "{name}", generated::get_{name}_namedtuple());'
|
||||
)
|
||||
|
||||
return definitions, map_entries
|
||||
return definitions, registrations
|
||||
|
||||
|
||||
def generate_return_type_declarations(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate block of function declarations in `python_return_types.h` to initialize
|
||||
and return named tuple for a native function.
|
||||
"""
|
||||
typenames: Dict[
|
||||
str, str
|
||||
] = {} # map from unique name + field name lists to typedef name
|
||||
declarations: List[str] = [] # function declaration to register the typedef
|
||||
|
||||
for overload in overloads:
|
||||
fieldnames = namedtuple_fieldnames(overload.function.func.returns)
|
||||
if not fieldnames:
|
||||
continue
|
||||
|
||||
name = cpp.name(overload.function.func) # use @with_native_function?
|
||||
tn_key = gen_namedtuple_typename_key(overload.function)
|
||||
typename = typenames.get(tn_key)
|
||||
|
||||
if typename is None:
|
||||
typename = (
|
||||
f'{name}NamedTuple{"" if not declarations else len(declarations)}'
|
||||
)
|
||||
typenames[tn_key] = typename
|
||||
declarations.append(f"PyTypeObject* get_{name}_namedtuple();")
|
||||
|
||||
return declarations
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
Reference in New Issue
Block a user