mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[BE] typing for decorators - utils/flop_counter (#131580)"
This reverts commit 81c26ba5ae1edf95da8f6956ae4b5ad23c9833c6. Reverted https://github.com/pytorch/pytorch/pull/131580 on behalf of https://github.com/clee2000 due to breaking lint internally D60265575 ([comment](https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359))
This commit is contained in:
@ -1,19 +1,18 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: allow-untyped-decorators
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||
from .module_tracker import ModuleTracker
|
||||
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator, TypeVar, Callable
|
||||
from typing_extensions import ParamSpec
|
||||
from typing import List, Any, Dict, Optional, Union, Tuple, Iterator
|
||||
from collections import defaultdict
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from math import prod
|
||||
from functools import wraps
|
||||
import warnings
|
||||
|
||||
__all__ = ["FlopCounterMode", "register_flop_formula"]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
__all__ = ["FlopCounterMode", "register_flop_formula"]
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
@ -31,8 +30,8 @@ def shape_wrapper(f):
|
||||
return f(*args, out_shape=out_shape, **kwargs)
|
||||
return nf
|
||||
|
||||
def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
def register_flop_formula(targets, get_raw=False):
|
||||
def register_fun(flop_formula):
|
||||
if not get_raw:
|
||||
flop_formula = shape_wrapper(flop_formula)
|
||||
|
||||
|
Reference in New Issue
Block a user