Revert "[BE] typing for decorators - utils/flop_counter (#131580)"

This reverts commit 81c26ba5ae1edf95da8f6956ae4b5ad23c9833c6.

Reverted https://github.com/pytorch/pytorch/pull/131580 on behalf of https://github.com/clee2000 due to breaking lint internally D60265575 ([comment](https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359))
This commit is contained in:
PyTorch MergeBot
2024-07-28 03:29:31 +00:00
parent 2c4023d65f
commit 5ced63a005

View File

@ -1,19 +1,18 @@
# mypy: allow-untyped-defs
# mypy: allow-untyped-decorators
import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from .module_tracker import ModuleTracker
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator, TypeVar, Callable
from typing_extensions import ParamSpec
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from math import prod
from functools import wraps
import warnings
__all__ = ["FlopCounterMode", "register_flop_formula"]
_T = TypeVar("_T")
_P = ParamSpec("_P")
__all__ = ["FlopCounterMode", "register_flop_formula"]
aten = torch.ops.aten
@ -31,8 +30,8 @@ def shape_wrapper(f):
return f(*args, out_shape=out_shape, **kwargs)
return nf
def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
def register_flop_formula(targets, get_raw=False):
def register_fun(flop_formula):
if not get_raw:
flop_formula = shape_wrapper(flop_formula)