mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
default argument handling for mobile unboxing codegen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78353 Differential Revision: [D36702695](https://our.internmc.facebook.com/intern/diff/D36702695/) Approved by: https://github.com/priyaramani, https://github.com/larryliu0820
This commit is contained in:
committed by
PyTorch MergeBot
parent
85f308275e
commit
8ad305f375
@ -3,13 +3,14 @@ import argparse
|
||||
import os
|
||||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api import unboxing
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import CppSignatureGroup
|
||||
from torchgen.api.unboxing import convert_arguments
|
||||
from torchgen.context import method_with_native_function
|
||||
from torchgen.gen import parse_native_yaml, cpp_string, get_custom_build_selector
|
||||
from torchgen.model import NativeFunction, NativeFunctionsGroup, Variant
|
||||
from torchgen.model import NativeFunction, NativeFunctionsGroup, Variant, Argument
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import Target, FileManager, mapMaybe, make_file_manager
|
||||
from typing import Union, Sequence
|
||||
@ -103,12 +104,20 @@ class ComputeCodegenUnboxedKernels:
|
||||
connector = ",\n\t\t"
|
||||
args_code = []
|
||||
for arg in args:
|
||||
if not arg.default:
|
||||
# Using method=False faithful C++ API, so we should not see SelfArgument/TensorOptionsArgument
|
||||
assert isinstance(arg.argument, Argument)
|
||||
if not arg.argument.default:
|
||||
arg_cpp = "c10::IValue(c10::nullopt)"
|
||||
elif arg.default.startswith("{"):
|
||||
arg_cpp = f"c10::IntArrayRef({arg.default})"
|
||||
else:
|
||||
arg_cpp = f"c10::IValue({arg.default})"
|
||||
# The unboxing code uses the faithful C++ API to avoid the overhead
|
||||
# from wrapping/unwrapping TensorOptios.
|
||||
# However, we would look to include default args for schema parsing.
|
||||
# Default args only show up in the nonfaithful C++ API,
|
||||
arg_default = cpp.default_expr(arg.argument.default, arg.argument.type)
|
||||
if arg_default.startswith("{"):
|
||||
arg_cpp = f"c10::IntArrayRef({arg_default})"
|
||||
else:
|
||||
arg_cpp = f"c10::IValue({arg_default})"
|
||||
args_code.append(
|
||||
f"""c10::Argument("{arg.name}", nullptr, c10::nullopt, {arg_cpp})"""
|
||||
)
|
||||
|
Reference in New Issue
Block a user