Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Pyrefly suppressions 6/n (#164877)
Adds suppressions so pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions

before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
after: INFO 0 errors (5,064 ignored)

Only four directories left to enable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877
Approved by: https://github.com/oulgen
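For readers unfamiliar with the format: each suppression added in the diff below is a line comment of the form "# pyrefly: ignore # <error-code>" placed directly above the statement that fails the check, silencing only the named error code on the following line. A minimal sketch of the pattern (the function below is hypothetical and not from this PR; only the comment syntax mirrors what the PR adds):

import os

def build_root() -> str:
    root = os.environ.get("BUILD_ROOT")
    # The checker sees `root` as Optional[str]; returning it from a function
    # annotated `-> str` is reported as bad-return, which the next line suppresses.
    # pyrefly: ignore # bad-return
    return root

Step 3's "clean up unused suppressions" presumably refers to removing any such comments that no longer match a reported error after the check is re-run.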
committed by PyTorch MergeBot
parent ad7b2bebc6
commit 086dec3235
@@ -117,6 +117,7 @@ def context_decorator(ctx, func):
    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        # pyrefly: ignore # bad-context-manager
        with ctx_factory():
            return func(*args, **kwargs)

@@ -41,7 +41,10 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
    )


# pyrefly: ignore # import-error
import optree

# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec  # direct import for type annotations

@@ -706,6 +709,7 @@ def tree_map_only(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    # pyrefly: ignore # no-matching-overload
    return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@@ -766,6 +770,7 @@ def tree_map_only_(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    # pyrefly: ignore # no-matching-overload
    return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@@ -1079,6 +1084,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
with python_pytree._NODE_REGISTRY_LOCK:
    # pyrefly: ignore # bad-assignment
    python_pytree._cxx_pytree_imported = True
    args, kwargs = (), {}  # type: ignore[var-annotated]
    for args, kwargs in python_pytree._cxx_pytree_pending_imports:
@@ -152,6 +152,7 @@ class DebugMode(TorchDispatchMode):
        super().__enter__()
        return self

    # pyrefly: ignore # bad-override
    def __exit__(self, *args):
        super().__exit__(*args)
        if self.record_torchfunction:

@@ -60,6 +60,7 @@ def _device_constructors():
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    def __init__(self, device):
        # pyrefly: ignore # read-only
        self.device = torch.device(device)

    def __enter__(self):

@@ -35,10 +35,12 @@ def cache_method(
        if not (cache := getattr(self, cache_name, None)):
            cache = {}
            setattr(self, cache_name, cache)
        # pyrefly: ignore # unbound-name
        cached_value = cache.get(args, _cache_sentinel)
        if cached_value is not _cache_sentinel:
            return cached_value
        value = f(self, *args, **kwargs)
        # pyrefly: ignore # unbound-name
        cache[args] = value
        return value

@@ -158,6 +158,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
    def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
        # MutableSet impl will iterate over other, iter over smaller of two sets
        if isinstance(other, OrderedSet) and len(self) < len(other):
            # pyrefly: ignore # unsupported-operation, bad-return
            return other & self
        return cast(OrderedSet[T], super().__and__(other))
@@ -708,6 +708,7 @@ class structseq(tuple[_T_co, ...]):
    def __new__(
        cls: type[Self],
        sequence: Iterable[_T_co],
        # pyrefly: ignore # bad-function-definition
        dict: dict[str, Any] = ...,
    ) -> Self:
        raise NotImplementedError

@@ -754,6 +755,7 @@ def _tuple_flatten_with_keys(
    d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _tuple_flatten(d)
    # pyrefly: ignore # bad-return
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@@ -767,6 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _list_flatten(d)
    # pyrefly: ignore # bad-return
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@@ -782,6 +785,7 @@ def _dict_flatten_with_keys(
    d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _dict_flatten(d)
    # pyrefly: ignore # bad-return
    return [(MappingKey(k), v) for k, v in zip(context, values)], context

@@ -797,6 +801,7 @@ def _namedtuple_flatten_with_keys(
    d: NamedTuple,
) -> tuple[list[tuple[KeyEntry, Any]], Context]:
    values, context = _namedtuple_flatten(d)
    # pyrefly: ignore # bad-return
    return (
        [(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
        context,

@@ -846,6 +851,7 @@ def _ordereddict_flatten_with_keys(
    d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _ordereddict_flatten(d)
    # pyrefly: ignore # bad-return
    return [(MappingKey(k), v) for k, v in zip(context, values)], context

@@ -870,6 +876,7 @@ def _defaultdict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _defaultdict_flatten(d)
    _, dict_context = context
    # pyrefly: ignore # bad-return
    return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context

@@ -918,6 +925,7 @@ def _deque_flatten_with_keys(
    d: deque[T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
    values, context = _deque_flatten(d)
    # pyrefly: ignore # bad-return
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@@ -1547,6 +1555,7 @@ def tree_map_only(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    # pyrefly: ignore # no-matching-overload
    return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@@ -1607,6 +1616,7 @@ def tree_map_only_(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    # pyrefly: ignore # no-matching-overload
    return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@@ -1819,6 +1829,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
        for attr in classname.split("."):
            enum_cls = getattr(enum_cls, attr)
        enum_cls = cast(type[Enum], enum_cls)
        # pyrefly: ignore # unsupported-operation
        return enum_cls[obj["name"]]
    return obj

@@ -305,6 +305,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
    @functools.wraps(work_function)
    def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
        # pyrefly: ignore # bad-argument-type
        return profiler.profile(work_function, *args, **kwargs)

    return wrapper_function

@@ -105,6 +105,7 @@ def _keep_float(
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
    @functools.wraps(f)
    def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
        # pyrefly: ignore # bad-argument-type
        r: Union[_T, sympy.Float] = f(*args)
        if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
            r, sympy.Float

@@ -112,6 +113,7 @@ def _keep_float(
            r = sympy.Float(float(r))
        return r

    # pyrefly: ignore # bad-return
    return inner
@@ -198,10 +200,12 @@ class FloorDiv(sympy.Function):
    @property
    def base(self) -> sympy.Basic:
        # pyrefly: ignore # missing-attribute
        return self.args[0]

    @property
    def divisor(self) -> sympy.Basic:
        # pyrefly: ignore # missing-attribute
        return self.args[1]

    def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:

@@ -370,6 +374,7 @@ class ModularIndexing(sympy.Function):
        return None

    def _eval_is_nonnegative(self) -> Optional[bool]:
        # pyrefly: ignore # missing-attribute
        p, q = self.args[:2]
        return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)  # type: ignore[attr-defined]

@@ -450,6 +455,7 @@ class PythonMod(sympy.Function):
        # - floor(p / q) = 0
        # - p % q = p - floor(p / q) * q = p
        less = p < q
        # pyrefly: ignore # missing-attribute
        if less.is_Boolean and bool(less) and r.is_positive:
            return p

@@ -466,8 +472,11 @@ class PythonMod(sympy.Function):
        return True if self.args[1].is_negative else None  # type: ignore[attr-defined]

    def _ccode(self, printer):
        # pyrefly: ignore # missing-attribute
        p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
        # pyrefly: ignore # missing-attribute
        q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
        # pyrefly: ignore # missing-attribute
        abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
        return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"

@@ -548,6 +557,7 @@ class CeilToInt(sympy.Function):
        return sympy.Integer(math.ceil(float(number)))

    def _ccode(self, printer):
        # pyrefly: ignore # missing-attribute
        number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
        return f"ceil({number})"

@@ -818,6 +828,7 @@ class MinMaxBase(Expr, LatticeOp):  # type: ignore[misc]
            if not cond:
                return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
        if isinstance(ai, cls):
            # pyrefly: ignore # missing-attribute
            return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
        return a

@@ -995,6 +1006,7 @@ class Max(MinMaxBase, Application):  # type: ignore[misc]
        return fuzzy_or(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]

    def _eval_is_negative(self):  # type:ignore[override]
        # pyrefly: ignore # missing-attribute
        return fuzzy_and(a.is_negative for a in self.args)

@@ -1013,6 +1025,7 @@ class Min(MinMaxBase, Application):  # type: ignore[misc]
        return fuzzy_and(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]

    def _eval_is_negative(self):  # type:ignore[override]
        # pyrefly: ignore # missing-attribute
        return fuzzy_or(a.is_negative for a in self.args)

@@ -1150,7 +1163,9 @@ class IntTrueDiv(sympy.Function):
        return sympy.Float(int(base) / int(divisor))

    def _ccode(self, printer):
        # pyrefly: ignore # missing-attribute
        base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
        # pyrefly: ignore # missing-attribute
        divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
        return f"((int){base}/(int){divisor})"

@@ -1310,9 +1325,11 @@ class Identity(sympy.Function):
    precedence = 10

    def __repr__(self):  # type: ignore[override]
        # pyrefly: ignore # missing-attribute
        return f"Identity({self.args[0]})"

    def _eval_is_real(self):
        # pyrefly: ignore # missing-attribute
        return self.args[0].is_real

    def _eval_is_integer(self):

@@ -1320,12 +1337,15 @@ class Identity(sympy.Function):
    def _eval_expand_identity(self, **hints):
        # Removes the identity op.
        # pyrefly: ignore # missing-attribute
        return self.args[0]

    def __int__(self) -> int:
        # pyrefly: ignore # missing-attribute
        return int(self.args[0])

    def __float__(self) -> float:
        # pyrefly: ignore # missing-attribute
        return float(self.args[0])
@@ -9,6 +9,7 @@ from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton


# pyrefly: ignore # invalid-inheritance
class IntInfinity(Number, metaclass=Singleton):
    r"""Positive integer infinite quantity.

@@ -203,6 +204,7 @@ class IntInfinity(Number, metaclass=Singleton):
int_oo = S.IntInfinity


# pyrefly: ignore # invalid-inheritance
class NegativeIntInfinity(Number, metaclass=Singleton):
    """Negative integer infinite quantity.

@@ -66,6 +66,7 @@ class ExprPrinter(StrPrinter):
    # NB: this pow by natural, you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural. These
    # means exp is guaranteed to be integer.
    # pyrefly: ignore # bad-override
    def _print_Pow(self, expr: sympy.Expr) -> str:
        base, exp = expr.args
        assert exp == int(exp), exp

@@ -175,6 +175,7 @@ class ReferenceAnalysis:
    @staticmethod
    def pow(a, b):
        # pyrefly: ignore # bad-argument-type
        return _keep_float(FloatPow)(a, b)

    @staticmethod
@@ -123,7 +123,9 @@ AllFn2 = Union[ExprFn2, BoolFn2]
class ValueRanges(Generic[_T]):
    if TYPE_CHECKING:
        # ruff doesn't understand circular references but mypy does
        # pyrefly: ignore # unbound-name
        ExprVR = ValueRanges[sympy.Expr]  # noqa: F821
        # pyrefly: ignore # unbound-name
        BoolVR = ValueRanges[SympyBoolean]  # noqa: F821
        AllVR = Union[ExprVR, BoolVR]

@@ -464,6 +466,7 @@ class SymPyValueRangeAnalysis:
    @staticmethod
    def to_dtype(a, dtype, src_dtype=None):
        if dtype == torch.float64:
            # pyrefly: ignore # bad-argument-type
            return ValueRanges.increasing_map(a, ToFloat)
        elif dtype == torch.bool:
            return ValueRanges.unknown_bool()

@@ -473,6 +476,7 @@ class SymPyValueRangeAnalysis:
    @staticmethod
    def trunc_to_int(a, dtype):
        # pyrefly: ignore # bad-argument-type
        return ValueRanges.increasing_map(a, TruncToInt)

    @staticmethod

@@ -621,7 +625,10 @@ class SymPyValueRangeAnalysis:
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(
                a, b, _keep_float(IntTrueDiv)
                a,
                b,
                # pyrefly: ignore # bad-argument-type
                _keep_float(IntTrueDiv),
            )

    @staticmethod

@@ -634,7 +641,10 @@ class SymPyValueRangeAnalysis:
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(
                a, b, _keep_float(FloatTrueDiv)
                a,
                b,
                # pyrefly: ignore # bad-argument-type
                _keep_float(FloatTrueDiv),
            )

    @staticmethod

@@ -713,6 +723,7 @@ class SymPyValueRangeAnalysis:
        # We should know that b >= 0 but we may have forgotten this fact due
        # to replacements, so don't assert it, but DO clamp it to prevent
        # degenerate problems
        # pyrefly: ignore # no-matching-overload
        return ValueRanges.coordinatewise_increasing_map(
            a, b & ValueRanges(0, int_oo), PowByNatural
        )

@@ -879,6 +890,7 @@ class SymPyValueRangeAnalysis:
    @classmethod
    def round_to_int(cls, number, dtype):
        # pyrefly: ignore # bad-argument-type
        return ValueRanges.increasing_map(number, RoundToInt)

    # It's used in some models on symints

@@ -992,6 +1004,7 @@ class SymPyValueRangeAnalysis:
    @staticmethod
    def trunc(x):
        # pyrefly: ignore # bad-argument-type
        return ValueRanges.increasing_map(x, TruncToFloat)
@@ -202,6 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
        Args:
            device (int, optional): if specified, all parameters will be copied to that device
        """
        # pyrefly: ignore # missing-attribute
        return self._apply(lambda t: getattr(t, custom_backend_name)(device))

    _check_register_once(torch.nn.Module, custom_backend_name)

@@ -63,6 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device):
    indices = torch.rand(sparse_dim, nnz, device=device)
    indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices))
    indices = indices.to(torch.long)
    # pyrefly: ignore # no-matching-overload
    values = torch.rand([nnz, ], dtype=dtype, device=device)
    return indices, values

@@ -15,6 +15,7 @@ _warned_tensor_cores = False
_default_float_32_precision = torch.get_float32_matmul_precision()

try:
    from tabulate import tabulate

    HAS_TABULATE = True

@@ -169,6 +170,7 @@ if HAS_TABULATE:
            _disable_tensor_cores()
        table.append([
            ("Training" if optimizer else "Inference"),
            # pyrefly: ignore # redundant-condition
            backend if backend else "-",
            mode if mode is not None else "-",
            f"{compilation_time} ms " if compilation_time else "-",

@@ -189,4 +191,5 @@ if HAS_TABULATE:
        ])

    # pyrefly: ignore # not-callable
    return tabulate(table, headers=field_names, tablefmt="github")

@@ -35,6 +35,7 @@ def _get_build_root() -> str:
    global _BUILD_ROOT
    if _BUILD_ROOT is None:
        _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
        # pyrefly: ignore # missing-argument
        atexit.register(shutil.rmtree, _BUILD_ROOT)
    return _BUILD_ROOT

@@ -91,6 +91,7 @@ class FuzzedSparseTensor(FuzzedTensor):
        return x

    def _make_tensor(self, params, state):
        # pyrefly: ignore # missing-attribute
        size, _, _ = self._get_size_and_steps(params)
        density = params['density']
        nnz = math.ceil(sum(size) * density)

@@ -99,8 +100,10 @@ class FuzzedSparseTensor(FuzzedTensor):
        is_coalesced = params['coalesced']
        sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
        sparse_dim = min(sparse_dim, len(size))
        # pyrefly: ignore # missing-attribute
        tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)

        # pyrefly: ignore # missing-attribute
        if self._cuda:
            tensor = tensor.cuda()
            sparse_dim = tensor.sparse_dim()

@@ -116,6 +119,7 @@ class FuzzedSparseTensor(FuzzedTensor):
            "sparse_dim": sparse_dim,
            "dense_dim": dense_dim,
            "is_hybrid": is_hybrid,
            # pyrefly: ignore # missing-attribute
            "dtype": str(self._dtype),
        }
        return tensor, properties

@@ -233,6 +233,7 @@ class Timer:
        setup = textwrap.dedent(setup)
        setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()

        # pyrefly: ignore # bad-instantiation
        self._timer = self._timer_cls(
            stmt=stmt,
            setup=setup,
@@ -448,11 +448,13 @@ class GlobalsBridge:
        load_lines = []
        for name, wrapped_value in self._globals.items():
            if wrapped_value.setup is not None:
                # pyrefly: ignore # bad-argument-type
                load_lines.append(textwrap.dedent(wrapped_value.setup))

            if wrapped_value.serialization == Serialization.PICKLE:
                path = os.path.join(self._data_dir, f"{name}.pkl")
                load_lines.append(
                    # pyrefly: ignore # bad-argument-type
                    f"with open({repr(path)}, 'rb') as f:\n    {name} = pickle.load(f)")
                with open(path, "wb") as f:
                    pickle.dump(wrapped_value.value, f)

@@ -462,11 +464,13 @@ class GlobalsBridge:
                # TODO: Figure out if we can use torch.serialization.add_safe_globals here
                # Using weights_only=False after the change in
                # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
                # pyrefly: ignore # bad-argument-type
                load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
                torch.save(wrapped_value.value, path)

            elif wrapped_value.serialization == Serialization.TORCH_JIT:
                path = os.path.join(self._data_dir, f"{name}.pt")
                # pyrefly: ignore # bad-argument-type
                load_lines.append(f"{name} = torch.jit.load({repr(path)})")
                with open(path, "wb") as f:
                    torch.jit.save(wrapped_value.value, f)  # type: ignore[no-untyped-call]

@@ -222,6 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"):
class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function

@@ -784,6 +785,7 @@ class _Holder:
class _NoopSaveInputs(torch.autograd.Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(*args):
        return torch.empty((0,))

@@ -1006,6 +1008,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
    def logging_mode():
        with LoggingTensorMode(), \
             capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
            # pyrefly: ignore # bad-assignment
            self.logs, self.tbs = logs_and_tb
            yield logs_and_tb
    return logging_mode()
@@ -787,6 +787,7 @@ class BuildExtension(build_ext):
        # Use absolute path for output_dir so that the object file paths
        # (`objects`) get generated with absolute paths.
        # pyrefly: ignore # no-matching-overload
        output_dir = os.path.abspath(output_dir)

        # See Note [Absolute include_dirs]

@@ -977,6 +978,7 @@ class BuildExtension(build_ext):
                      is_standalone=False):
            if not self.compiler.initialized:
                self.compiler.initialize()
            # pyrefly: ignore # no-matching-overload
            output_dir = os.path.abspath(output_dir)

            # Note [Absolute include_dirs]

@@ -1528,6 +1530,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str
    # Support CUDA_INC_PATH env variable supported by CMake files
    if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
            cuda_inc_path != '/usr/include':
        # pyrefly: ignore # unbound-name
        paths.append(cuda_inc_path)
    if CUDNN_HOME is not None:
        paths.append(os.path.join(CUDNN_HOME, 'include'))

@@ -2569,6 +2572,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]:
def _get_vc_env(vc_arch: str) -> dict[str, str]:
    try:
        from setuptools import distutils  # type: ignore[attr-defined]
        # pyrefly: ignore # missing-attribute
        return distutils._msvccompiler._get_vc_env(vc_arch)
    except AttributeError:
        try:

@@ -204,6 +204,7 @@ def collate(
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            # pyrefly: ignore # not-iterable
            if not all(len(elem) == elem_size for elem in it):
                raise RuntimeError("each element in list of batch should be of equal size")
            transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

@@ -70,6 +70,7 @@ def pin_memory(data, device=None):
            return clone
        else:
            return type(data)(
                # pyrefly: ignore # bad-argument-count
                {k: pin_memory(sample, device) for k, sample in data.items()}
            )  # type: ignore[call-arg]
    except TypeError:

@@ -674,6 +674,7 @@ class _BaseDataLoaderIter:
        # Set pin memory device based on the current accelerator.
        self._pin_memory_device = (
            # pyrefly: ignore # unbound-name
            acc.type
            if self._pin_memory
            and (acc := torch.accelerator.current_accelerator()) is not None
@@ -265,6 +265,7 @@ class _DataPipeType:
# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
# pyrefly: ignore # invalid-annotation
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])

@@ -283,6 +284,7 @@ class _DataPipeMeta(GenericMeta):
            return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

        # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
        # pyrefly: ignore # no-access
        cls.__origin__ = None
        if "type" in namespace:
            return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]

@@ -80,6 +80,7 @@ class Capture:
    def _ops_str(self):
        res = ""
        # pyrefly: ignore # not-iterable
        for op in self.ctx["operations"]:
            if len(res) > 0:
                res += "\n"

@@ -89,6 +90,7 @@ class Capture:
    def __getstate__(self):
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        self.ctx["schema_df"] = None
        # pyrefly: ignore # not-iterable
        for var in self.ctx["variables"]:
            var.calculated_value = None
        state = {}

@@ -112,11 +114,13 @@ class Capture:
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        # pyrefly: ignore # missing-attribute
        self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))

    def __add__(self, add_val):
        res = CaptureAdd(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        # pyrefly: ignore # missing-attribute
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )

@@ -125,6 +129,7 @@ class Capture:
    def __sub__(self, add_val):
        res = CaptureSub(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        # pyrefly: ignore # missing-attribute
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )

@@ -134,15 +139,19 @@ class Capture:
        res = CaptureMul(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        # pyrefly: ignore # missing-attribute
        self.ctx["operations"].append(t)
        return var

    def _is_context_empty(self):
        # pyrefly: ignore # bad-argument-type
        return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0

    def apply_ops_2(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        # pyrefly: ignore # unsupported-operation
        self.ctx["variables"][0].calculated_value = dataframe
        # pyrefly: ignore # not-iterable
        for op in self.ctx["operations"]:
            op.execute()

@@ -175,6 +184,7 @@ class Capture:
        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        # pyrefly: ignore # missing-attribute
        self.ctx["operations"].append(t)
        return var

@@ -273,7 +283,9 @@ class CaptureVariable(Capture):
    def apply_ops(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        # pyrefly: ignore # unsupported-operation
        self.ctx["variables"][0].calculated_value = dataframe
        # pyrefly: ignore # not-iterable
        for op in self.ctx["operations"]:
            op.execute()
        return self.calculated_value

@@ -373,6 +385,7 @@ def get_val(capture):
class CaptureInitial(CaptureVariable):
    def __init__(self, schema_df=None):
        # pyrefly: ignore # bad-assignment
        new_ctx: dict[str, list[Any]] = {
            "operations": [],
            "variables": [],

@@ -388,6 +401,7 @@ class CaptureDataFrame(CaptureInitial):
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    def as_datapipe(self):
        # pyrefly: ignore # unsupported-operation
        return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)

    def raw_iterator(self):

@@ -92,6 +92,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
        size = None
        all_buffer = []
        filter_res = []
        # pyrefly: ignore # bad-assignment
        for df in self.source_datapipe:
            if size is None:
                size = len(df.index)
@@ -135,6 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
    _fast_forward_iterator: Optional[Iterator] = None

    def __iter__(self) -> Iterator[_T_co]:
        # pyrefly: ignore # bad-return
        return self

    def __getattr__(self, attribute_name):

@@ -379,6 +380,7 @@ class _DataPipeSerializationWrapper:
            value = pickle.dumps(self._datapipe)
        except Exception:
            if HAS_DILL:
                # pyrefly: ignore # missing-attribute
                value = dill.dumps(self._datapipe)
                use_dill = True
            else:

@@ -388,6 +390,7 @@ class _DataPipeSerializationWrapper:
    def __setstate__(self, state):
        value, use_dill = state
        if use_dill:
            # pyrefly: ignore # missing-attribute
            self._datapipe = dill.loads(value)
        else:
            self._datapipe = pickle.loads(value)

@@ -404,6 +407,7 @@ class _DataPipeSerializationWrapper:
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
    def __init__(self, datapipe: IterDataPipe[_T_co]):
        super().__init__(datapipe)
        # pyrefly: ignore # invalid-type-var
        self._datapipe_iter: Optional[Iterator[_T_co]] = None

    def __iter__(self) -> "_IterDataPipeSerializationWrapper":

@@ -118,6 +118,7 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
                for idx in sorted(self.input_col[1:], reverse=True):
                    del data[idx]
            else:
                # pyrefly: ignore # unsupported-operation
                data[self.input_col] = res
        else:
            if self.output_col == -1:

@@ -42,6 +42,7 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
            "Sampler class requires input datapipe implemented `__len__`"
        )
        super().__init__()
        # pyrefly: ignore # bad-assignment
        self.datapipe = datapipe
        self.sampler_args = () if sampler_args is None else sampler_args
        self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs

@@ -59,6 +59,7 @@ class ConcaterIterDataPipe(IterDataPipe):
    def __len__(self) -> int:
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            # pyrefly: ignore # bad-argument-type
            return sum(len(dp) for dp in self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

@@ -179,6 +180,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
        self._child_stop: list[bool] = [True for _ in range(num_instances)]

    def __len__(self):
        # pyrefly: ignore # bad-argument-type
        return len(self.main_datapipe)

    def get_next_element_by_instance(self, instance_id: int):

@@ -238,6 +240,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
        return self.end_ptr is not None and all(self._child_stop)

    def get_length_by_instance(self, instance_id: int) -> int:
        # pyrefly: ignore # bad-argument-type
        return len(self.main_datapipe)

    def reset(self) -> None:

@@ -323,6 +326,7 @@ class _ChildDataPipe(IterDataPipe):
    def __init__(self, main_datapipe: IterDataPipe, instance_id: int):
        assert isinstance(main_datapipe, _ContainerTemplate)

        # pyrefly: ignore # bad-assignment
        self.main_datapipe: IterDataPipe = main_datapipe
        self.instance_id = instance_id

@@ -449,6 +453,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
        drop_none: bool,
        buffer_size: int,
    ):
        # pyrefly: ignore # invalid-type-var
        self.main_datapipe = datapipe
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances

@@ -460,7 +465,9 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
                UserWarning,
            )
        self.current_buffer_usage = 0
        # pyrefly: ignore # invalid-type-var
        self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)]
        # pyrefly: ignore # invalid-type-var
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.main_datapipe_exhausted = False

@@ -698,6 +705,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
    def __len__(self) -> int:
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            # pyrefly: ignore # bad-argument-type
            return min(len(dp) for dp in self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
@@ -203,7 +203,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
        drop_remaining: bool = False,
    ):
        _check_unpickable_fn(group_key_fn)
        # pyrefly: ignore # invalid-type-var
        self.datapipe = datapipe
        # pyrefly: ignore # invalid-type-var
        self.group_key_fn = group_key_fn

        self.keep_key = keep_key

@@ -214,9 +216,11 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
        self.guaranteed_group_size = None
        if group_size is not None and buffer_size is not None:
            assert 0 < group_size <= buffer_size
            # pyrefly: ignore # bad-assignment
            self.guaranteed_group_size = group_size
        if guaranteed_group_size is not None:
            assert group_size is not None and 0 < guaranteed_group_size <= group_size
            # pyrefly: ignore # bad-assignment
            self.guaranteed_group_size = guaranteed_group_size
        self.drop_remaining = drop_remaining
        self.wrapper_class = DataChunk

@@ -60,6 +60,7 @@ class MapperMapDataPipe(MapDataPipe[_T_co]):
        self.fn = fn  # type: ignore[assignment]

    def __len__(self) -> int:
        # pyrefly: ignore # bad-argument-type
        return len(self.datapipe)

    def __getitem__(self, index) -> _T_co:

@@ -64,6 +64,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # pyrefly: ignore # bad-argument-type
        self.indices = list(range(len(datapipe))) if indices is None else indices
        self._enabled = True
        self._seed = None

@@ -95,6 +96,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
        self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))

    def __len__(self) -> int:
        # pyrefly: ignore # bad-argument-type
        return len(self.datapipe)

    def __getstate__(self):

@@ -49,13 +49,16 @@ class ConcaterMapDataPipe(MapDataPipe):
    def __getitem__(self, index) -> _T_co:  # type: ignore[type-var]
        offset = 0
        for dp in self.datapipes:
            # pyrefly: ignore # bad-argument-type
            if index - offset < len(dp):
                return dp[index - offset]
            else:
                # pyrefly: ignore # bad-argument-type
                offset += len(dp)
        raise IndexError(f"Index {index} is out of range.")

    def __len__(self) -> int:
        # pyrefly: ignore # bad-argument-type
        return sum(len(dp) for dp in self.datapipes)

@@ -102,4 +105,5 @@ class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]):
        return tuple(res)

    def __len__(self) -> int:
        # pyrefly: ignore # bad-argument-type
        return min(len(dp) for dp in self.datapipes)

@@ -196,6 +196,7 @@ def get_file_pathnames_from_root(
                if match_masks(fname, masks):
                    yield path
    else:
        # pyrefly: ignore # bad-assignment
        for path, dirs, files in os.walk(root, onerror=onerror):
            if abspath:
                path = os.path.abspath(path)

@@ -43,6 +43,7 @@ def _simple_graph_snapshot_restoration(
    # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
    # the first reset will not actually reset.
    datapipe.reset()  # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
    # pyrefly: ignore # bad-argument-type
    apply_random_seed(datapipe, rng)

    remainder = n_iterations

@@ -131,6 +131,7 @@ class DistributedSampler(Sampler[_T_co]):
            indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        # pyrefly: ignore # bad-return
        return iter(indices)

    def __len__(self) -> int:
@@ -72,6 +72,7 @@ def _list_connected_datapipes(
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            if dill_available():
                # pyrefly: ignore # missing-attribute
                d.dump(scan_obj)
            else:
                raise

@@ -31,6 +31,7 @@ class FileBaton:
            True if the file could be created, else False.
        """
        try:
            # pyrefly: ignore # bad-assignment
            self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
            return True
        except FileExistsError:

@@ -149,6 +149,7 @@ def conv_flop_count(
@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
    """Count flops for convolution."""
    # pyrefly: ignore # bad-argument-type
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)

@@ -676,7 +677,9 @@ class FlopCounterMode:
        if depth is None:
            depth = 999999

        import tabulate

        # pyrefly: ignore # bad-assignment
        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []

@@ -48,6 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict(
    ]
)

# pyrefly: ignore # no-matching-overload
CUDA_TYPE_NAME_MAP = collections.OrderedDict(
    [
        ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)),

@@ -675,6 +676,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
    ]
)

# pyrefly: ignore # no-matching-overload
CUDA_IDENTIFIER_MAP = collections.OrderedDict(
    [
        ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)),

@@ -663,6 +663,7 @@ def is_caffe2_gpu_file(rel_filepath):
        return True
    filename = os.path.basename(rel_filepath)
    _, ext = os.path.splitext(filename)
    # pyrefly: ignore # unsupported-operation
    return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)

class TrieNode:

@@ -1137,6 +1138,7 @@ def hipify(
        out_of_place_only=out_of_place_only,
        is_pytorch_extension=is_pytorch_extension))
    all_files_set = set(all_files)
    # pyrefly: ignore # bad-assignment
    for f in extra_files:
        if not os.path.isabs(f):
            f = os.path.join(output_directory, f)
@@ -145,6 +145,7 @@ class BackwardHook:
        res = out

        # pyrefly: ignore # bad-assignment
        self.grad_outputs = None

        return self._unpack_none(self.input_tensors_index, res)

@@ -208,6 +208,7 @@ def get_model_info(
    with zipfile.ZipFile(path_or_file) as zf:
        path_prefix = None
        zip_files = []
        # pyrefly: ignore # bad-assignment
        for zi in zf.infolist():
            prefix = re.sub("/.*", "", zi.filename)
            if path_prefix is None:

@@ -359,9 +360,12 @@ def get_inline_skeleton():
    import importlib.resources

    # pyrefly: ignore # bad-argument-type
    skeleton = importlib.resources.read_text(__package__, "skeleton.html")
    # pyrefly: ignore # bad-argument-type
    js_code = importlib.resources.read_text(__package__, "code.js")
    for js_module in ["preact", "htm"]:
        # pyrefly: ignore # bad-argument-type
        js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
        js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
        js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)

@@ -31,5 +31,7 @@ def make_np(x: torch.Tensor) -> np.ndarray:
def _prepare_pytorch(x: torch.Tensor) -> np.ndarray:
    if x.dtype == torch.bfloat16:
        x = x.to(torch.float16)
    # pyrefly: ignore # bad-assignment
    x = x.detach().cpu().numpy()
    # pyrefly: ignore # bad-return
    return x

@@ -188,6 +188,7 @@ class GraphPy:
        for key, node in self.nodes_io.items():
            if type(node) == NodeBase:
                # pyrefly: ignore # unsupported-operation
                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
            if hasattr(node, "input_or_output"):
                self.unique_name_to_scoped_name[key] = (

@@ -198,6 +199,7 @@ class GraphPy:
            self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
            if node.scope == "" and self.shallowest_scope_name:
                self.unique_name_to_scoped_name[node.debugName] = (
                    # pyrefly: ignore # unsupported-operation
                    self.shallowest_scope_name + "/" + node.debugName
                )

@@ -57,11 +57,14 @@ def _prepare_video(V):
        return num != 0 and ((num & (num - 1)) == 0)

    # pad to nearest power of 2, all at once
    # pyrefly: ignore # index-error
    if not is_power2(V.shape[0]):
        # pyrefly: ignore # index-error
        len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
        V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)

    n_rows = 2 ** ((b.bit_length() - 1) // 2)
    # pyrefly: ignore # index-error
    n_cols = V.shape[0] // n_rows

    V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))
@@ -9,6 +9,7 @@ from typing import Any, Optional
import torch
import numpy as np

# pyrefly: ignore # import-error
from google.protobuf import struct_pb2

from tensorboard.compat.proto.summary_pb2 import (

@@ -497,6 +498,7 @@ def make_histogram(values, bins, max_bins=None):
        subsampling = num_bins // max_bins
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # pyrefly: ignore # no-matching-overload
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],

@@ -834,17 +836,21 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None):
        weights = 1.0

    # Compute bins of true positives and false positives.
    # pyrefly: ignore # unsupported-operation
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(np.float64)
    # pyrefly: ignore # unsupported-operation
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        # pyrefly: ignore # bad-argument-type
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        # pyrefly: ignore # bad-argument-type
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
@@ -254,7 +254,9 @@ class SummaryWriter:
        buckets = []
        neg_buckets = []
        while v < 1e20:
            # pyrefly: ignore # bad-argument-type
            buckets.append(v)
            # pyrefly: ignore # bad-argument-type
            neg_buckets.append(-v)
            v *= 1.1
        self.default_bins = neg_buckets[::-1] + [0] + buckets

@@ -262,15 +264,19 @@ class SummaryWriter:
    def _get_file_writer(self):
        """Return the default FileWriter instance. Recreates it if closed."""
        if self.all_writers is None or self.file_writer is None:
            # pyrefly: ignore # bad-assignment
            self.file_writer = FileWriter(
                self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
            )
            # pyrefly: ignore # bad-assignment, missing-attribute
            self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
            if self.purge_step is not None:
                most_recent_step = self.purge_step
                # pyrefly: ignore # missing-attribute
                self.file_writer.add_event(
                    Event(step=most_recent_step, file_version="brain.Event:2")
                )
                # pyrefly: ignore # missing-attribute
                self.file_writer.add_event(
                    Event(
                        step=most_recent_step,

@@ -950,6 +956,7 @@ class SummaryWriter:
        )
        self._projector_config.embeddings.extend([embedding_info])

        # pyrefly: ignore # import-error
        from google.protobuf import text_format

        config_pbtxt = text_format.MessageToString(self._projector_config)

@@ -1199,6 +1206,7 @@ class SummaryWriter:
        for writer in self.all_writers.values():
            writer.flush()
            writer.close()
        # pyrefly: ignore # bad-assignment
        self.file_writer = self.all_writers = None

    def __enter__(self):
@@ -461,6 +461,7 @@ def to_html(nodes):
        if n.context is None:
            continue
        s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}'))
        # pyrefly: ignore # bad-argument-type
        listeners.append(s)
    dot = to_dot(nodes)
    return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners))

@@ -292,6 +292,7 @@ class WeakIdKeyDictionary(MutableMapping):
        if o is not None:
            return o, value

    # pyrefly: ignore # bad-override
    def pop(self, key, *args):
        self._dirty_len = True
        return self.data.pop(self.ref_type(key), *args)  # CHANGED