mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Pytorch Edge] Extend Tracer to Custom Classes (#67004)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67004 New version because the other one was impossible to rebase Trace custom classes Test Plan: CI. Reviewed By: dhruvbird Differential Revision: D31818978 fbshipit-source-id: daa22ccb153e32685bcca43a303ba9e21042d052
This commit is contained in:
committed by
Facebook GitHub Bot
parent
34ee5b11ff
commit
6c22b96082
@ -52,7 +52,7 @@ def get_selected_kernel_dtypes_code(
|
||||
# dtypes are selected (i.e. both cases).
|
||||
#
|
||||
body = "return true;"
|
||||
if selective_builder.include_all_operators is False and selective_builder.include_all_kernel_dtypes is False:
|
||||
if selective_builder.include_all_operators is False and selective_builder.include_all_non_op_selectives is False:
|
||||
body_parts = []
|
||||
for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
|
||||
conditions = list(map(lambda x: 'scalar_type == at::ScalarType::' + x, dtypes))
|
||||
@ -76,10 +76,16 @@ def write_selected_mobile_ops(
|
||||
selective_builder: SelectiveBuilder,
|
||||
) -> None:
|
||||
root_ops = extract_root_operators(selective_builder)
|
||||
custom_classes = selective_builder.custom_classes
|
||||
with open(output_file_path, "wb") as out_file:
|
||||
body_parts = [selected_mobile_ops_preamble]
|
||||
# This condition checks if we are in selective build.
|
||||
# if these lists are not defined the corresponding selective build macros trivially return the item in question was selected
|
||||
if not selective_builder.include_all_operators:
|
||||
body_parts.append("#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n")
|
||||
# This condition checks if we are in tracing based selective build
|
||||
if selective_builder.include_all_non_op_selectives is False:
|
||||
body_parts.append("#define TORCH_CUSTOM_CLASS_ALLOWLIST " + (";".join(sorted(custom_classes))) + ";\n\n")
|
||||
|
||||
body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
|
||||
header_contents = "".join(body_parts)
|
||||
|
||||
Reference in New Issue
Block a user