Use global variables to register the return_types namedtuples (#108832)

Fixes #69221. Builds on top of #107000, fixing the buck build issue linked [here](https://github.com/pytorch/pytorch/pull/107000#issuecomment-1708857375).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108832
Approved by: https://github.com/zou3519
This commit is contained in:
Andrei Gheorghe
2023-09-13 17:42:46 +00:00
committed by PyTorch MergeBot
parent 6065e7a97c
commit 00908475e6
15 changed files with 119 additions and 55 deletions

View File

@ -359,6 +359,9 @@ def gen(
create_python_return_type_bindings(
fm, functions, lambda fn: True, "python_return_types.cpp"
)
create_python_return_type_bindings_header(
fm, functions, lambda fn: True, "python_return_types.h"
)
valid_tags = parse_tags_yaml(tags_yaml_path)
@ -436,22 +439,24 @@ def create_python_return_type_bindings(
) -> None:
"""
Generate function to initialize and return named tuple for native functions
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
which returns named tuple and registration invocations in `python_return_types.cpp`.
"""
py_return_types_definition: List[str] = []
py_return_types_map: List[str] = []
py_return_types_registrations: List[str] = []
grouped = group_filter_overloads(pairs, pred)
for name in sorted(grouped.keys(), key=lambda x: str(x)):
overloads = grouped[name]
definitions, map_entries = generate_return_type_definition_and_map_entry(
definitions, registrations = generate_return_type_definition_and_registrations(
overloads
)
py_return_types_definition.append(
"" if not definitions else "\n".join(definitions)
)
py_return_types_map.append("" if not map_entries else "\n".join(map_entries))
py_return_types_registrations.append(
"" if not registrations else "\n".join(registrations)
)
fm.write_with_template(
filename,
@ -460,7 +465,39 @@ def create_python_return_type_bindings(
"generated_comment": "@"
+ f"generated from {fm.template_dir_for_comments()}/{filename}",
"py_return_types": py_return_types_definition,
"py_return_types_map": py_return_types_map,
"py_return_types_registrations": py_return_types_registrations,
},
)
def create_python_return_type_bindings_header(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
filename: str,
) -> None:
"""
Generate function to initialize and return named tuple for native functions
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
"""
py_return_types_declarations: List[str] = []
grouped = group_filter_overloads(pairs, pred)
for name in sorted(grouped.keys(), key=lambda x: str(x)):
overloads = grouped[name]
declarations = generate_return_type_declarations(overloads)
py_return_types_declarations.append(
"" if not declarations else "\n".join(declarations)
)
fm.write_with_template(
filename,
filename,
lambda: {
"generated_comment": "@"
+ f"generated from {fm.template_dir_for_comments()}/{filename}",
"py_return_types_declarations": py_return_types_declarations,
},
)
@ -683,27 +720,25 @@ def emit_namedtuple_call(
typenames[tn_key] = typename
typedefs.append(
f"""\
static PyTypeObject* {typename} = get_namedtuple("{name}");"""
static PyTypeObject* {typename} = generated::get_{name}_namedtuple();"""
)
return typedefs, typenames
def generate_return_type_definition_and_map_entry(
def generate_return_type_definition_and_registrations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
"""
Generate block of function in `python_return_types.cpp` to initialize
and return named tuple for a native function which returns named tuple
and relevant entry for the map in same file.
and registration invocations in same file.
"""
typenames: Dict[
str, str
] = {} # map from unique name + field name lists to typedef name
definitions: List[str] = [] # function defintion to register the typedef
map_entries: List[
str
] = [] # C++ map entry of <function_name, function creates it namedtuple>
definitions: List[str] = [] # function definition to register the typedef
registrations: List[str] = [] # register call for the typedef
for overload in overloads:
fieldnames = namedtuple_fieldnames(overload.function.func.returns)
@ -735,9 +770,42 @@ PyTypeObject* get_{name}_namedtuple() {{
}}
"""
)
map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')
registrations.append(
f'addReturnType(return_types_module, "{name}", generated::get_{name}_namedtuple());'
)
return definitions, map_entries
return definitions, registrations
def generate_return_type_declarations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> List[str]:
"""
Generate block of function declarations in `python_return_types.h` to initialize
and return named tuple for a native function.
"""
typenames: Dict[
str, str
] = {} # map from unique name + field name lists to typedef name
declarations: List[str] = [] # function declaration to register the typedef
for overload in overloads:
fieldnames = namedtuple_fieldnames(overload.function.func.returns)
if not fieldnames:
continue
name = cpp.name(overload.function.func) # use @with_native_function?
tn_key = gen_namedtuple_typename_key(overload.function)
typename = typenames.get(tn_key)
if typename is None:
typename = (
f'{name}NamedTuple{"" if not declarations else len(declarations)}'
)
typenames[tn_key] = typename
declarations.append(f"PyTypeObject* get_{name}_namedtuple();")
return declarations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #