mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make forced eager fallback optional in codegen (#75274)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75274 - default to generating forced fallback for TS backend (where it is used for tests/debugging, but false otherwise Test Plan: Imported from OSS Reviewed By: bdhirsh Differential Revision: D35411211 Pulled By: wconstab fbshipit-source-id: ccff2f65aa5d8e1aa670d210ce51805985df55ce (cherry picked from commit 55b48cc02497686f4e25ed3c6dcf9b6b77d49136)
This commit is contained in:
committed by
PyTorch MergeBot
parent
5c964e38b0
commit
a8e45b5969
@ -204,6 +204,7 @@ class GenLazyNativeFuncDefinition:
|
||||
class_method_name: str
|
||||
backend_index: BackendIndex
|
||||
tensor_class: str
|
||||
gen_forced_fallback_code: bool
|
||||
|
||||
@method_with_native_function
|
||||
def __call__(self, func: NativeFunction) -> List[str]:
|
||||
@ -215,7 +216,9 @@ class GenLazyNativeFuncDefinition:
|
||||
value_args = schema.filtered_args(values=True, scalars=False)
|
||||
returns_length = len(schema.returns)
|
||||
|
||||
fallback_str = gen_fallback_code(schema, overload_name=func.func.name.overload_name)
|
||||
fallback_str = ""
|
||||
if self.gen_forced_fallback_code:
|
||||
fallback_str = gen_fallback_code(schema, overload_name=func.func.name.overload_name)
|
||||
|
||||
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
|
||||
assert len(value_types_names) > 0, "Code below assumes there is at least one tensor arg"
|
||||
|
@ -190,7 +190,8 @@ def run_gen_lazy_tensor(aten_path: str, source_yaml: str, output_dir: str,
|
||||
# per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
|
||||
# it must match how ATen was built
|
||||
per_operator_headers: bool = False,
|
||||
backend_name: str = default_args.backend_name) -> None:
|
||||
backend_name: str = default_args.backend_name,
|
||||
gen_forced_fallback_code: bool = False) -> None:
|
||||
|
||||
template_dir = os.path.join(aten_path, "templates")
|
||||
|
||||
@ -319,8 +320,7 @@ def run_gen_lazy_tensor(aten_path: str, source_yaml: str, output_dir: str,
|
||||
"torch/csrc/lazy/core/shape.h",
|
||||
f"{output_dir}/{backend_key}NativeFunctions.h",
|
||||
f"{output_dir}/LazyIr.h",
|
||||
"torch/csrc/lazy/ts_backend/ts_eager_fallback.h",
|
||||
]],
|
||||
] + (["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"] if gen_forced_fallback_code else [])],
|
||||
'native_functions_include': '',
|
||||
'namespace_prologue': ns_helper.prologue,
|
||||
'namespace_epilogue': ns_helper.epilogue,
|
||||
@ -328,7 +328,8 @@ def run_gen_lazy_tensor(aten_path: str, source_yaml: str, output_dir: str,
|
||||
list(concat_map_codegen(
|
||||
dest.GenLazyNativeFuncDefinition(f'{backend_key}NativeFunctions',
|
||||
backend_indices[backend_key],
|
||||
tensor_class),
|
||||
tensor_class,
|
||||
gen_forced_fallback_code),
|
||||
grouped_native_functions,
|
||||
codegenInplaceVariant=True
|
||||
)),
|
||||
|
@ -208,7 +208,8 @@ def main() -> None:
|
||||
node_base_hdr=ts_node_base,
|
||||
build_in_tree=True,
|
||||
lazy_ir_cls=TSLazyIR,
|
||||
per_operator_headers=options.per_operator_headers)
|
||||
per_operator_headers=options.per_operator_headers,
|
||||
gen_forced_fallback_code=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user