mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Lintrunner: Run mypy-strict on torchgen (#82576)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82576 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
9f0fccadd7
commit
afafd16671
@ -1,25 +1,28 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
from typing import Sequence, Union
|
||||
from typing import Sequence, TypeVar, Union
|
||||
|
||||
from libfb.py.log import set_simple_logging
|
||||
from libfb.py.log import set_simple_logging # type: ignore[import]
|
||||
|
||||
from torchgen import gen
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
|
||||
from torchgen.static_runtime import generator
|
||||
from torchgen.static_runtime import config, generator
|
||||
|
||||
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
|
||||
# lists each of which groups ops that share the base name. For example, `mean` and
|
||||
# `mean.dim` are grouped together by this function.
|
||||
|
||||
NativeGroupT = TypeVar(
|
||||
"NativeGroupT",
|
||||
bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup],
|
||||
)
|
||||
|
||||
|
||||
def group_functions_by_op_name(
|
||||
grouped_native_functions: Sequence[
|
||||
Union[NativeFunctionsGroup, NativeFunctionsViewGroup]
|
||||
]
|
||||
) -> Sequence[Sequence[Union[NativeFunctionsGroup, NativeFunctionsViewGroup]]]:
|
||||
grouped_native_functions: Sequence[NativeGroupT],
|
||||
) -> Sequence[Sequence[NativeGroupT]]:
|
||||
if not grouped_native_functions:
|
||||
return []
|
||||
groups = []
|
||||
@ -34,9 +37,7 @@ def group_functions_by_op_name(
|
||||
for k, group in (
|
||||
itertools.groupby(
|
||||
eligible_ops,
|
||||
key=lambda g: g.functional.func.name.name.base
|
||||
if isinstance(g, NativeFunctionsGroup)
|
||||
else g.view.root_name,
|
||||
key=lambda g: config.func_name_base_str(g),
|
||||
)
|
||||
)
|
||||
]
|
||||
|
Reference in New Issue
Block a user