default argument handling for mobile unboxing codegen

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78353

Differential Revision: [D36702695](https://our.internmc.facebook.com/intern/diff/D36702695/)

Approved by: https://github.com/priyaramani, https://github.com/larryliu0820
This commit is contained in:
Brian Hirsh
2022-05-26 13:29:20 -07:00
committed by PyTorch MergeBot
parent 85f308275e
commit 8ad305f375

View File

@ -3,13 +3,14 @@ import argparse
import os
import pathlib
from dataclasses import dataclass
from torchgen.api import cpp
from torchgen.api import unboxing
from torchgen.api.translate import translate
from torchgen.api.types import CppSignatureGroup
from torchgen.api.unboxing import convert_arguments
from torchgen.context import method_with_native_function
from torchgen.gen import parse_native_yaml, cpp_string, get_custom_build_selector
from torchgen.model import NativeFunction, NativeFunctionsGroup, Variant
from torchgen.model import NativeFunction, NativeFunctionsGroup, Variant, Argument
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import Target, FileManager, mapMaybe, make_file_manager
from typing import Union, Sequence
@ -103,12 +104,20 @@ class ComputeCodegenUnboxedKernels:
connector = ",\n\t\t"
args_code = []
for arg in args:
if not arg.default:
# Using method=False faithful C++ API, so we should not see SelfArgument/TensorOptionsArgument
assert isinstance(arg.argument, Argument)
if not arg.argument.default:
arg_cpp = "c10::IValue(c10::nullopt)"
elif arg.default.startswith("{"):
arg_cpp = f"c10::IntArrayRef({arg.default})"
else:
arg_cpp = f"c10::IValue({arg.default})"
# The unboxing code uses the faithful C++ API to avoid the overhead
# from wrapping/unwrapping TensorOptios.
# However, we would look to include default args for schema parsing.
# Default args only show up in the nonfaithful C++ API,
arg_default = cpp.default_expr(arg.argument.default, arg.argument.type)
if arg_default.startswith("{"):
arg_cpp = f"c10::IntArrayRef({arg_default})"
else:
arg_cpp = f"c10::IValue({arg_default})"
args_code.append(
f"""c10::Argument("{arg.name}", nullptr, c10::nullopt, {arg_cpp})"""
)