Revert "Don't introduce new overload for SymInt (#83628)"

This reverts commit 9790d90e4b0288796ab44a6b4979db0a67580ba8.

Reverted https://github.com/pytorch/pytorch/pull/83628 on behalf of https://github.com/malfet due to Breaks internal builds, see D39076487
This commit is contained in:
PyTorch MergeBot
2022-08-27 01:23:17 +00:00
parent 38e5e4a85f
commit c7edcd6968
81 changed files with 729 additions and 766 deletions

View File

@ -238,8 +238,6 @@ def gen(
tags_yaml_path: str,
deprecated_yaml_path: str,
template_path: str,
*,
symint: bool = True,
) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
native_functions = parse_native_yaml(
@ -255,7 +253,6 @@ def gen(
None,
"python_variable_methods.cpp",
method=True,
symint=symint,
)
# NOTE: num_shards here must be synced with gatherTorchFunctions in
@ -269,7 +266,6 @@ def gen(
"python_torch_functions.cpp",
method=False,
num_shards=3,
symint=symint,
)
create_python_bindings(
@ -279,7 +275,6 @@ def gen(
"torch.nn",
"python_nn_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -289,7 +284,6 @@ def gen(
"torch.fft",
"python_fft_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -299,7 +293,6 @@ def gen(
"torch.linalg",
"python_linalg_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -309,7 +302,6 @@ def gen(
"torch.sparse",
"python_sparse_functions.cpp",
method=False,
symint=symint,
)
create_python_bindings(
@ -319,7 +311,6 @@ def gen(
"torch.special",
"python_special_functions.cpp",
method=False,
symint=symint,
)
# Currently, we only use `functions` to generate `return_types` bindings.
@ -363,7 +354,6 @@ def create_python_bindings(
filename: str,
*,
method: bool,
symint: bool = True,
) -> None:
"""Generates Python bindings to ATen functions"""
py_methods: List[str] = []
@ -375,9 +365,7 @@ def create_python_bindings(
for name in sorted(grouped.keys(), key=lambda x: str(x)):
overloads = grouped[name]
py_methods.append(
method_impl(name, module, overloads, method=method, symint=symint)
)
py_methods.append(method_impl(name, module, overloads, method=method))
py_method_defs.append(method_def(name, module, overloads, method=method))
py_forwards.extend(forward_decls(name, overloads, method=method))
ops_headers.append(f"#include <ATen/ops/{name.base}.h>")
@ -440,7 +428,6 @@ def create_python_bindings_sharded(
*,
method: bool,
num_shards: int,
symint: bool = True,
) -> None:
"""Generates Python bindings to ATen functions"""
grouped = group_filter_overloads(pairs, pred)
@ -457,9 +444,7 @@ def create_python_bindings_sharded(
return {
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
"py_forwards": list(forward_decls(name, fn_pairs, method=method)),
"py_methods": [
method_impl(name, module, fn_pairs, method=method, symint=symint)
],
"py_methods": [method_impl(name, module, fn_pairs, method=method)],
"py_method_defs": [method_def(name, module, fn_pairs, method=method)],
}
@ -788,7 +773,6 @@ def method_impl(
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
symint: bool = True,
) -> str:
"""
Generate a python binding for all overloads of an op.
@ -807,18 +791,14 @@ def method_impl(
traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
overloads, symint=symint
)
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
is_singleton = len(grouped_overloads) == 1
signatures: List[str] = []
dispatch: List[str] = []
for overload_index, overload in enumerate(grouped_overloads):
signature = overload.signature.signature_str(symint=symint)
signature = overload.signature.signature_str()
signatures.append(f"{cpp_string(str(signature))},")
dispatch_body = emit_dispatch_case(
overload, namedtuple_typenames, symint=symint
)
dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
dispatch.append(
PY_VARIABLE_CASE.substitute(
overload_index=overload_index, body=dispatch_body
@ -902,8 +882,6 @@ if (_r.isNone(${out_idx})) {
def emit_dispatch_case(
overload: PythonSignatureGroup,
namedtuple_typenames: Dict[str, str],
*,
symint: bool = True,
) -> str:
"""
Emit dispatch code for a single parsed signature. This corresponds to either
@ -916,19 +894,18 @@ def emit_dispatch_case(
return PY_VARIABLE_OUT.substitute(
out_idx=overload.signature.output_idx(),
call_dispatch=emit_single_dispatch(
overload.signature, overload.base, namedtuple_typenames, symint=symint
overload.signature, overload.base, namedtuple_typenames
),
call_dispatch_out=emit_single_dispatch(
overload.signature,
overload.outplace,
namedtuple_typenames,
symint=symint,
),
)
else:
# no-output version only
return emit_single_dispatch(
overload.signature, overload.base, namedtuple_typenames, symint=symint
overload.signature, overload.base, namedtuple_typenames
)
@ -1010,14 +987,14 @@ def method_def(
def group_overloads(
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Sequence[PythonSignatureGroup]:
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
# first group by signature ignoring out arguments
for overload in overloads:
sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
sig = overload.signature.signature_str(skip_outputs=True)
if overload.function.func.is_out_fn():
if sig in outplaces:
raise RuntimeError(
@ -1044,11 +1021,9 @@ def group_overloads(
and not overload.signature.deprecated
):
candidates.append(
overload.signature.signature_str(
skip_outputs=True, symint=symint
)
overload.signature.signature_str(skip_outputs=True)
)
out_sig = out.signature.signature_str(symint=symint)
out_sig = out.signature.signature_str()
raise RuntimeError(
f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
@ -1063,7 +1038,7 @@ def group_overloads(
)
for sig, base in bases.items()
]
return sort_overloads(grouped, symint=symint)
return sort_overloads(grouped)
# This function declares a partial order on declarations, and sorts them according
@ -1112,7 +1087,7 @@ def group_overloads(
def sort_overloads(
grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
grouped_overloads: Sequence[PythonSignatureGroup],
) -> Sequence[PythonSignatureGroup]:
# NB: Smaller here means lower priority
@ -1157,7 +1132,7 @@ def sort_overloads(
# First sort by signature
grouped_overloads = sorted(
grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
grouped_overloads, key=lambda x: x.signature.signature_str()
)
# Construct the relation graph
@ -1195,11 +1170,7 @@ def sort_overloads(
def emit_single_dispatch(
ps: PythonSignature,
f: NativeFunction,
namedtuple_typenames: Dict[str, str],
*,
symint: bool = True,
ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
) -> str:
"""
Emit dispatch code for a single native function.
@ -1218,10 +1189,7 @@ def emit_single_dispatch(
# dispatch lambda signature
name = cpp.name(f.func)
lambda_formals = ", ".join(
map(
lambda a: f"{a.type_str} {a.name}",
dispatch_lambda_args(ps, f, symint=symint),
)
map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))
)
lambda_return = dispatch_lambda_return_str(f)
@ -1230,8 +1198,8 @@ def emit_single_dispatch(
dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))
# from arg parser outputs to dispatch lambda arguments
parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
parser_outputs = arg_parser_output_exprs(ps, f)
lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
inits = "\n".join(lambda_arg_exprs.inits)
lambda_args = ", ".join(lambda_arg_exprs.exprs)