Lintrunner: Run mypy-strict on torchgen (#82576)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82576
Approved by: https://github.com/ezyang
This commit is contained in:
Peter Bell
2022-08-01 17:19:18 +01:00
committed by PyTorch MergeBot
parent 9f0fccadd7
commit afafd16671
4 changed files with 16 additions and 14 deletions

View File

@ -1,25 +1,28 @@
import argparse
import itertools
import os
from typing import Sequence, Union
from typing import Sequence, TypeVar, Union
from libfb.py.log import set_simple_logging
from libfb.py.log import set_simple_logging # type: ignore[import]
from torchgen import gen
from torchgen.context import native_function_manager
from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
from torchgen.static_runtime import generator
from torchgen.static_runtime import config, generator
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
# lists each of which groups ops that share the base name. For example, `mean` and
# `mean.dim` are grouped together by this function.
NativeGroupT = TypeVar(
"NativeGroupT",
bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup],
)
def group_functions_by_op_name(
grouped_native_functions: Sequence[
Union[NativeFunctionsGroup, NativeFunctionsViewGroup]
]
) -> Sequence[Sequence[Union[NativeFunctionsGroup, NativeFunctionsViewGroup]]]:
grouped_native_functions: Sequence[NativeGroupT],
) -> Sequence[Sequence[NativeGroupT]]:
if not grouped_native_functions:
return []
groups = []
@ -34,9 +37,7 @@ def group_functions_by_op_name(
for k, group in (
itertools.groupby(
eligible_ops,
key=lambda g: g.functional.func.name.name.base
if isinstance(g, NativeFunctionsGroup)
else g.view.root_name,
key=lambda g: config.func_name_base_str(g),
)
)
]