[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2024-06-29 12:48:06 +08:00
committed by PyTorch MergeBot
parent 58f346c874
commit 8a67daf283
123 changed files with 1274 additions and 1053 deletions

View File

@ -1,10 +1,13 @@
# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.
from __future__ import annotations
import argparse
import os
import pathlib
import sys
from dataclasses import dataclass
from typing import List, Literal, Sequence, Union
from pathlib import Path
from typing import Literal, Sequence, TYPE_CHECKING
import yaml
@ -15,10 +18,13 @@ from torchgen.api.unboxing import convert_arguments
from torchgen.context import method_with_native_function
from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
# Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
@dataclass(frozen=True)
class ComputeUnboxingFunctions:
@ -156,7 +162,7 @@ def gen_unboxing(
cpu_fm: FileManager,
selector: SelectiveBuilder,
) -> None:
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
return fn.root_name
selected_op_num: int = len(selector.operators)
@ -195,7 +201,7 @@ def gen_unboxing(
)
def main(args: List[str]) -> None:
def main(args: list[str]) -> None:
parser = argparse.ArgumentParser(description="Generate unboxing source files")
parser.add_argument(
"-s",
@ -272,7 +278,7 @@ def main(args: List[str]) -> None:
gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)
if options.output_dependencies:
depfile_path = pathlib.Path(options.output_dependencies).resolve()
depfile_path = Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name
depfile_stem = depfile_path.stem