mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Revert "Don't introduce new overload for SymInt (#83628)"
This reverts commit 9790d90e4b0288796ab44a6b4979db0a67580ba8. Reverted https://github.com/pytorch/pytorch/pull/83628 on behalf of https://github.com/malfet due to Breaks internal builds, see D39076487
This commit is contained in:
@ -238,8 +238,6 @@ def gen(
|
||||
tags_yaml_path: str,
|
||||
deprecated_yaml_path: str,
|
||||
template_path: str,
|
||||
*,
|
||||
symint: bool = True,
|
||||
) -> None:
|
||||
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
|
||||
native_functions = parse_native_yaml(
|
||||
@ -255,7 +253,6 @@ def gen(
|
||||
None,
|
||||
"python_variable_methods.cpp",
|
||||
method=True,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
# NOTE: num_shards here must be synced with gatherTorchFunctions in
|
||||
@ -269,7 +266,6 @@ def gen(
|
||||
"python_torch_functions.cpp",
|
||||
method=False,
|
||||
num_shards=3,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
@ -279,7 +275,6 @@ def gen(
|
||||
"torch.nn",
|
||||
"python_nn_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
@ -289,7 +284,6 @@ def gen(
|
||||
"torch.fft",
|
||||
"python_fft_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
@ -299,7 +293,6 @@ def gen(
|
||||
"torch.linalg",
|
||||
"python_linalg_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
@ -309,7 +302,6 @@ def gen(
|
||||
"torch.sparse",
|
||||
"python_sparse_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
create_python_bindings(
|
||||
@ -319,7 +311,6 @@ def gen(
|
||||
"torch.special",
|
||||
"python_special_functions.cpp",
|
||||
method=False,
|
||||
symint=symint,
|
||||
)
|
||||
|
||||
# Currently, we only use `functions` to generate `return_types` bindings.
|
||||
@ -363,7 +354,6 @@ def create_python_bindings(
|
||||
filename: str,
|
||||
*,
|
||||
method: bool,
|
||||
symint: bool = True,
|
||||
) -> None:
|
||||
"""Generates Python bindings to ATen functions"""
|
||||
py_methods: List[str] = []
|
||||
@ -375,9 +365,7 @@ def create_python_bindings(
|
||||
|
||||
for name in sorted(grouped.keys(), key=lambda x: str(x)):
|
||||
overloads = grouped[name]
|
||||
py_methods.append(
|
||||
method_impl(name, module, overloads, method=method, symint=symint)
|
||||
)
|
||||
py_methods.append(method_impl(name, module, overloads, method=method))
|
||||
py_method_defs.append(method_def(name, module, overloads, method=method))
|
||||
py_forwards.extend(forward_decls(name, overloads, method=method))
|
||||
ops_headers.append(f"#include <ATen/ops/{name.base}.h>")
|
||||
@ -440,7 +428,6 @@ def create_python_bindings_sharded(
|
||||
*,
|
||||
method: bool,
|
||||
num_shards: int,
|
||||
symint: bool = True,
|
||||
) -> None:
|
||||
"""Generates Python bindings to ATen functions"""
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
@ -457,9 +444,7 @@ def create_python_bindings_sharded(
|
||||
return {
|
||||
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
|
||||
"py_forwards": list(forward_decls(name, fn_pairs, method=method)),
|
||||
"py_methods": [
|
||||
method_impl(name, module, fn_pairs, method=method, symint=symint)
|
||||
],
|
||||
"py_methods": [method_impl(name, module, fn_pairs, method=method)],
|
||||
"py_method_defs": [method_def(name, module, fn_pairs, method=method)],
|
||||
}
|
||||
|
||||
@ -788,7 +773,6 @@ def method_impl(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
*,
|
||||
method: bool,
|
||||
symint: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a python binding for all overloads of an op.
|
||||
@ -807,18 +791,14 @@ def method_impl(
|
||||
|
||||
traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"
|
||||
|
||||
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
|
||||
overloads, symint=symint
|
||||
)
|
||||
grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
|
||||
is_singleton = len(grouped_overloads) == 1
|
||||
signatures: List[str] = []
|
||||
dispatch: List[str] = []
|
||||
for overload_index, overload in enumerate(grouped_overloads):
|
||||
signature = overload.signature.signature_str(symint=symint)
|
||||
signature = overload.signature.signature_str()
|
||||
signatures.append(f"{cpp_string(str(signature))},")
|
||||
dispatch_body = emit_dispatch_case(
|
||||
overload, namedtuple_typenames, symint=symint
|
||||
)
|
||||
dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
|
||||
dispatch.append(
|
||||
PY_VARIABLE_CASE.substitute(
|
||||
overload_index=overload_index, body=dispatch_body
|
||||
@ -902,8 +882,6 @@ if (_r.isNone(${out_idx})) {
|
||||
def emit_dispatch_case(
|
||||
overload: PythonSignatureGroup,
|
||||
namedtuple_typenames: Dict[str, str],
|
||||
*,
|
||||
symint: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Emit dispatch code for a single parsed signature. This corresponds to either
|
||||
@ -916,19 +894,18 @@ def emit_dispatch_case(
|
||||
return PY_VARIABLE_OUT.substitute(
|
||||
out_idx=overload.signature.output_idx(),
|
||||
call_dispatch=emit_single_dispatch(
|
||||
overload.signature, overload.base, namedtuple_typenames, symint=symint
|
||||
overload.signature, overload.base, namedtuple_typenames
|
||||
),
|
||||
call_dispatch_out=emit_single_dispatch(
|
||||
overload.signature,
|
||||
overload.outplace,
|
||||
namedtuple_typenames,
|
||||
symint=symint,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# no-output version only
|
||||
return emit_single_dispatch(
|
||||
overload.signature, overload.base, namedtuple_typenames, symint=symint
|
||||
overload.signature, overload.base, namedtuple_typenames
|
||||
)
|
||||
|
||||
|
||||
@ -1010,14 +987,14 @@ def method_def(
|
||||
|
||||
|
||||
def group_overloads(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
) -> Sequence[PythonSignatureGroup]:
|
||||
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
|
||||
# first group by signature ignoring out arguments
|
||||
for overload in overloads:
|
||||
sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
|
||||
sig = overload.signature.signature_str(skip_outputs=True)
|
||||
if overload.function.func.is_out_fn():
|
||||
if sig in outplaces:
|
||||
raise RuntimeError(
|
||||
@ -1044,11 +1021,9 @@ def group_overloads(
|
||||
and not overload.signature.deprecated
|
||||
):
|
||||
candidates.append(
|
||||
overload.signature.signature_str(
|
||||
skip_outputs=True, symint=symint
|
||||
)
|
||||
overload.signature.signature_str(skip_outputs=True)
|
||||
)
|
||||
out_sig = out.signature.signature_str(symint=symint)
|
||||
out_sig = out.signature.signature_str()
|
||||
raise RuntimeError(
|
||||
f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
|
||||
f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
|
||||
@ -1063,7 +1038,7 @@ def group_overloads(
|
||||
)
|
||||
for sig, base in bases.items()
|
||||
]
|
||||
return sort_overloads(grouped, symint=symint)
|
||||
return sort_overloads(grouped)
|
||||
|
||||
|
||||
# This function declares a partial order on declarations, and sorts them according
|
||||
@ -1112,7 +1087,7 @@ def group_overloads(
|
||||
|
||||
|
||||
def sort_overloads(
|
||||
grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
|
||||
grouped_overloads: Sequence[PythonSignatureGroup],
|
||||
) -> Sequence[PythonSignatureGroup]:
|
||||
# NB: Smaller here means lower priority
|
||||
|
||||
@ -1157,7 +1132,7 @@ def sort_overloads(
|
||||
|
||||
# First sort by signature
|
||||
grouped_overloads = sorted(
|
||||
grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
|
||||
grouped_overloads, key=lambda x: x.signature.signature_str()
|
||||
)
|
||||
|
||||
# Construct the relation graph
|
||||
@ -1195,11 +1170,7 @@ def sort_overloads(
|
||||
|
||||
|
||||
def emit_single_dispatch(
|
||||
ps: PythonSignature,
|
||||
f: NativeFunction,
|
||||
namedtuple_typenames: Dict[str, str],
|
||||
*,
|
||||
symint: bool = True,
|
||||
ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
|
||||
) -> str:
|
||||
"""
|
||||
Emit dispatch code for a single native function.
|
||||
@ -1218,10 +1189,7 @@ def emit_single_dispatch(
|
||||
# dispatch lambda signature
|
||||
name = cpp.name(f.func)
|
||||
lambda_formals = ", ".join(
|
||||
map(
|
||||
lambda a: f"{a.type_str} {a.name}",
|
||||
dispatch_lambda_args(ps, f, symint=symint),
|
||||
)
|
||||
map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))
|
||||
)
|
||||
lambda_return = dispatch_lambda_return_str(f)
|
||||
|
||||
@ -1230,8 +1198,8 @@ def emit_single_dispatch(
|
||||
dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))
|
||||
|
||||
# from arg parser outputs to dispatch lambda arguments
|
||||
parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
|
||||
lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
|
||||
parser_outputs = arg_parser_output_exprs(ps, f)
|
||||
lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
|
||||
inits = "\n".join(lambda_arg_exprs.inits)
|
||||
lambda_args = ", ".join(lambda_arg_exprs.exprs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user