Pyrefly suppressions 6/n (#164877)

Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

INFO 0 errors (5,064 ignored)

Only four directories left to enable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877
Approved by: https://github.com/oulgen
This commit is contained in:
Maggie Moss
2025-10-08 02:30:53 +00:00
committed by PyTorch MergeBot
parent ad7b2bebc6
commit 086dec3235
123 changed files with 355 additions and 72 deletions

View File

@ -117,6 +117,7 @@ def context_decorator(ctx, func):
@functools.wraps(func)
def decorate_context(*args, **kwargs):
# pyrefly: ignore # bad-context-manager
with ctx_factory():
return func(*args, **kwargs)

View File

@ -41,7 +41,10 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
)
# pyrefly: ignore # import-error
import optree
# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
@ -706,6 +709,7 @@ def tree_map_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -766,6 +770,7 @@ def tree_map_only_(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1079,6 +1084,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
with python_pytree._NODE_REGISTRY_LOCK:
# pyrefly: ignore # bad-assignment
python_pytree._cxx_pytree_imported = True
args, kwargs = (), {} # type: ignore[var-annotated]
for args, kwargs in python_pytree._cxx_pytree_pending_imports:

View File

@ -152,6 +152,7 @@ class DebugMode(TorchDispatchMode):
super().__enter__()
return self
# pyrefly: ignore # bad-override
def __exit__(self, *args):
super().__exit__(*args)
if self.record_torchfunction:

View File

@ -60,6 +60,7 @@ def _device_constructors():
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
# pyrefly: ignore # read-only
self.device = torch.device(device)
def __enter__(self):

View File

@ -35,10 +35,12 @@ def cache_method(
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
# pyrefly: ignore # unbound-name
cached_value = cache.get(args, _cache_sentinel)
if cached_value is not _cache_sentinel:
return cached_value
value = f(self, *args, **kwargs)
# pyrefly: ignore # unbound-name
cache[args] = value
return value

View File

@ -158,6 +158,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# MutableSet impl will iterate over other, iter over smaller of two sets
if isinstance(other, OrderedSet) and len(self) < len(other):
# pyrefly: ignore # unsupported-operation, bad-return
return other & self
return cast(OrderedSet[T], super().__and__(other))

View File

@ -708,6 +708,7 @@ class structseq(tuple[_T_co, ...]):
def __new__(
cls: type[Self],
sequence: Iterable[_T_co],
# pyrefly: ignore # bad-function-definition
dict: dict[str, Any] = ...,
) -> Self:
raise NotImplementedError
@ -754,6 +755,7 @@ def _tuple_flatten_with_keys(
d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _tuple_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -767,6 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _list_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -782,6 +785,7 @@ def _dict_flatten_with_keys(
d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d)
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -797,6 +801,7 @@ def _namedtuple_flatten_with_keys(
d: NamedTuple,
) -> tuple[list[tuple[KeyEntry, Any]], Context]:
values, context = _namedtuple_flatten(d)
# pyrefly: ignore # bad-return
return (
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
context,
@ -846,6 +851,7 @@ def _ordereddict_flatten_with_keys(
d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d)
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -870,6 +876,7 @@ def _defaultdict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _defaultdict_flatten(d)
_, dict_context = context
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
@ -918,6 +925,7 @@ def _deque_flatten_with_keys(
d: deque[T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _deque_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -1547,6 +1555,7 @@ def tree_map_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1607,6 +1616,7 @@ def tree_map_only_(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1819,6 +1829,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
for attr in classname.split("."):
enum_cls = getattr(enum_cls, attr)
enum_cls = cast(type[Enum], enum_cls)
# pyrefly: ignore # unsupported-operation
return enum_cls[obj["name"]]
return obj

View File

@ -305,6 +305,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
# pyrefly: ignore # bad-argument-type
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function

View File

@ -105,6 +105,7 @@ def _keep_float(
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
# pyrefly: ignore # bad-argument-type
r: Union[_T, sympy.Float] = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
@ -112,6 +113,7 @@ def _keep_float(
r = sympy.Float(float(r))
return r
# pyrefly: ignore # bad-return
return inner
@ -198,10 +200,12 @@ class FloorDiv(sympy.Function):
@property
def base(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
return self.args[0]
@property
def divisor(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
return self.args[1]
def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
@ -370,6 +374,7 @@ class ModularIndexing(sympy.Function):
return None
def _eval_is_nonnegative(self) -> Optional[bool]:
# pyrefly: ignore # missing-attribute
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
@ -450,6 +455,7 @@ class PythonMod(sympy.Function):
# - floor(p / q) = 0
# - p % q = p - floor(p / q) * q = p
less = p < q
# pyrefly: ignore # missing-attribute
if less.is_Boolean and bool(less) and r.is_positive:
return p
@ -466,8 +472,11 @@ class PythonMod(sympy.Function):
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"
@ -548,6 +557,7 @@ class CeilToInt(sympy.Function):
return sympy.Integer(math.ceil(float(number)))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
return f"ceil({number})"
@ -818,6 +828,7 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
if not cond:
return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
if isinstance(ai, cls):
# pyrefly: ignore # missing-attribute
return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
return a
@ -995,6 +1006,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
return fuzzy_and(a.is_negative for a in self.args)
@ -1013,6 +1025,7 @@ class Min(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
return fuzzy_or(a.is_negative for a in self.args)
@ -1150,7 +1163,9 @@ class IntTrueDiv(sympy.Function):
return sympy.Float(int(base) / int(divisor))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
return f"((int){base}/(int){divisor})"
@ -1310,9 +1325,11 @@ class Identity(sympy.Function):
precedence = 10
def __repr__(self): # type: ignore[override]
# pyrefly: ignore # missing-attribute
return f"Identity({self.args[0]})"
def _eval_is_real(self):
# pyrefly: ignore # missing-attribute
return self.args[0].is_real
def _eval_is_integer(self):
@ -1320,12 +1337,15 @@ class Identity(sympy.Function):
def _eval_expand_identity(self, **hints):
# Removes the identity op.
# pyrefly: ignore # missing-attribute
return self.args[0]
def __int__(self) -> int:
# pyrefly: ignore # missing-attribute
return int(self.args[0])
def __float__(self) -> float:
# pyrefly: ignore # missing-attribute
return float(self.args[0])

View File

@ -9,6 +9,7 @@ from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton
# pyrefly: ignore # invalid-inheritance
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.
@ -203,6 +204,7 @@ class IntInfinity(Number, metaclass=Singleton):
int_oo = S.IntInfinity
# pyrefly: ignore # invalid-inheritance
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.

View File

@ -66,6 +66,7 @@ class ExprPrinter(StrPrinter):
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
# pyrefly: ignore # bad-override
def _print_Pow(self, expr: sympy.Expr) -> str:
base, exp = expr.args
assert exp == int(exp), exp

View File

@ -175,6 +175,7 @@ class ReferenceAnalysis:
@staticmethod
def pow(a, b):
# pyrefly: ignore # bad-argument-type
return _keep_float(FloatPow)(a, b)
@staticmethod

View File

@ -123,7 +123,9 @@ AllFn2 = Union[ExprFn2, BoolFn2]
class ValueRanges(Generic[_T]):
if TYPE_CHECKING:
# ruff doesn't understand circular references but mypy does
# pyrefly: ignore # unbound-name
ExprVR = ValueRanges[sympy.Expr] # noqa: F821
# pyrefly: ignore # unbound-name
BoolVR = ValueRanges[SympyBoolean] # noqa: F821
AllVR = Union[ExprVR, BoolVR]
@ -464,6 +466,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64:
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(a, ToFloat)
elif dtype == torch.bool:
return ValueRanges.unknown_bool()
@ -473,6 +476,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc_to_int(a, dtype):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(a, TruncToInt)
@staticmethod
@ -621,7 +625,10 @@ class SymPyValueRangeAnalysis:
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(IntTrueDiv)
a,
b,
# pyrefly: ignore # bad-argument-type
_keep_float(IntTrueDiv),
)
@staticmethod
@ -634,7 +641,10 @@ class SymPyValueRangeAnalysis:
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(FloatTrueDiv)
a,
b,
# pyrefly: ignore # bad-argument-type
_keep_float(FloatTrueDiv),
)
@staticmethod
@ -713,6 +723,7 @@ class SymPyValueRangeAnalysis:
# We should know that b >= 0 but we may have forgotten this fact due
# to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems
# pyrefly: ignore # no-matching-overload
return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, int_oo), PowByNatural
)
@ -879,6 +890,7 @@ class SymPyValueRangeAnalysis:
@classmethod
def round_to_int(cls, number, dtype):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(number, RoundToInt)
# It's used in some models on symints
@ -992,6 +1004,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc(x):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(x, TruncToFloat)

View File

@ -202,6 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
# pyrefly: ignore # missing-attribute
return self._apply(lambda t: getattr(t, custom_backend_name)(device))
_check_register_once(torch.nn.Module, custom_backend_name)

View File

@ -63,6 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device):
indices = torch.rand(sparse_dim, nnz, device=device)
indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices))
indices = indices.to(torch.long)
# pyrefly: ignore # no-matching-overload
values = torch.rand([nnz, ], dtype=dtype, device=device)
return indices, values

View File

@ -15,6 +15,7 @@ _warned_tensor_cores = False
_default_float_32_precision = torch.get_float32_matmul_precision()
try:
from tabulate import tabulate
HAS_TABULATE = True
@ -169,6 +170,7 @@ if HAS_TABULATE:
_disable_tensor_cores()
table.append([
("Training" if optimizer else "Inference"),
# pyrefly: ignore # redundant-condition
backend if backend else "-",
mode if mode is not None else "-",
f"{compilation_time} ms " if compilation_time else "-",
@ -189,4 +191,5 @@ if HAS_TABULATE:
])
# pyrefly: ignore # not-callable
return tabulate(table, headers=field_names, tablefmt="github")

View File

@ -35,6 +35,7 @@ def _get_build_root() -> str:
global _BUILD_ROOT
if _BUILD_ROOT is None:
_BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
# pyrefly: ignore # missing-argument
atexit.register(shutil.rmtree, _BUILD_ROOT)
return _BUILD_ROOT

View File

@ -91,6 +91,7 @@ class FuzzedSparseTensor(FuzzedTensor):
return x
def _make_tensor(self, params, state):
# pyrefly: ignore # missing-attribute
size, _, _ = self._get_size_and_steps(params)
density = params['density']
nnz = math.ceil(sum(size) * density)
@ -99,8 +100,10 @@ class FuzzedSparseTensor(FuzzedTensor):
is_coalesced = params['coalesced']
sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
sparse_dim = min(sparse_dim, len(size))
# pyrefly: ignore # missing-attribute
tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)
# pyrefly: ignore # missing-attribute
if self._cuda:
tensor = tensor.cuda()
sparse_dim = tensor.sparse_dim()
@ -116,6 +119,7 @@ class FuzzedSparseTensor(FuzzedTensor):
"sparse_dim": sparse_dim,
"dense_dim": dense_dim,
"is_hybrid": is_hybrid,
# pyrefly: ignore # missing-attribute
"dtype": str(self._dtype),
}
return tensor, properties

View File

@ -233,6 +233,7 @@ class Timer:
setup = textwrap.dedent(setup)
setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()
# pyrefly: ignore # bad-instantiation
self._timer = self._timer_cls(
stmt=stmt,
setup=setup,

View File

@ -448,11 +448,13 @@ class GlobalsBridge:
load_lines = []
for name, wrapped_value in self._globals.items():
if wrapped_value.setup is not None:
# pyrefly: ignore # bad-argument-type
load_lines.append(textwrap.dedent(wrapped_value.setup))
if wrapped_value.serialization == Serialization.PICKLE:
path = os.path.join(self._data_dir, f"{name}.pkl")
load_lines.append(
# pyrefly: ignore # bad-argument-type
f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)")
with open(path, "wb") as f:
pickle.dump(wrapped_value.value, f)
@ -462,11 +464,13 @@ class GlobalsBridge:
# TODO: Figure out if we can use torch.serialization.add_safe_globals here
# Using weights_only=False after the change in
# https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
# pyrefly: ignore # bad-argument-type
load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
torch.save(wrapped_value.value, path)
elif wrapped_value.serialization == Serialization.TORCH_JIT:
path = os.path.join(self._data_dir, f"{name}.pt")
# pyrefly: ignore # bad-argument-type
load_lines.append(f"{name} = torch.jit.load({repr(path)})")
with open(path, "wb") as f:
torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call]

View File

@ -222,6 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
@ -784,6 +785,7 @@ class _Holder:
class _NoopSaveInputs(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(*args):
return torch.empty((0,))
@ -1006,6 +1008,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
def logging_mode():
with LoggingTensorMode(), \
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
# pyrefly: ignore # bad-assignment
self.logs, self.tbs = logs_and_tb
yield logs_and_tb
return logging_mode()

View File

@ -787,6 +787,7 @@ class BuildExtension(build_ext):
# Use absolute path for output_dir so that the object file paths
# (`objects`) get generated with absolute paths.
# pyrefly: ignore # no-matching-overload
output_dir = os.path.abspath(output_dir)
# See Note [Absolute include_dirs]
@ -977,6 +978,7 @@ class BuildExtension(build_ext):
is_standalone=False):
if not self.compiler.initialized:
self.compiler.initialize()
# pyrefly: ignore # no-matching-overload
output_dir = os.path.abspath(output_dir)
# Note [Absolute include_dirs]
@ -1528,6 +1530,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str
# Support CUDA_INC_PATH env variable supported by CMake files
if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
cuda_inc_path != '/usr/include':
# pyrefly: ignore # unbound-name
paths.append(cuda_inc_path)
if CUDNN_HOME is not None:
paths.append(os.path.join(CUDNN_HOME, 'include'))
@ -2569,6 +2572,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]:
def _get_vc_env(vc_arch: str) -> dict[str, str]:
try:
from setuptools import distutils # type: ignore[attr-defined]
# pyrefly: ignore # missing-attribute
return distutils._msvccompiler._get_vc_env(vc_arch)
except AttributeError:
try:

View File

@ -204,6 +204,7 @@ def collate(
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
# pyrefly: ignore # not-iterable
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.

View File

@ -70,6 +70,7 @@ def pin_memory(data, device=None):
return clone
else:
return type(data)(
# pyrefly: ignore # bad-argument-count
{k: pin_memory(sample, device) for k, sample in data.items()}
) # type: ignore[call-arg]
except TypeError:

View File

@ -674,6 +674,7 @@ class _BaseDataLoaderIter:
# Set pin memory device based on the current accelerator.
self._pin_memory_device = (
# pyrefly: ignore # unbound-name
acc.type
if self._pin_memory
and (acc := torch.accelerator.current_accelerator()) is not None

View File

@ -265,6 +265,7 @@ class _DataPipeType:
# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
# pyrefly: ignore # invalid-annotation
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])
@ -283,6 +284,7 @@ class _DataPipeMeta(GenericMeta):
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
# TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
# pyrefly: ignore # no-access
cls.__origin__ = None
if "type" in namespace:
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]

View File

@ -80,6 +80,7 @@ class Capture:
def _ops_str(self):
res = ""
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
if len(res) > 0:
res += "\n"
@ -89,6 +90,7 @@ class Capture:
def __getstate__(self):
# TODO(VitalyFedyunin): Currently can't pickle (why?)
self.ctx["schema_df"] = None
# pyrefly: ignore # not-iterable
for var in self.ctx["variables"]:
var.calculated_value = None
state = {}
@ -112,11 +114,13 @@ class Capture:
return CaptureGetItem(self, key, ctx=self.ctx)
def __setitem__(self, key, value):
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))
def __add__(self, add_val):
res = CaptureAdd(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)
@ -125,6 +129,7 @@ class Capture:
def __sub__(self, add_val):
res = CaptureSub(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)
@ -134,15 +139,19 @@ class Capture:
res = CaptureMul(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(t)
return var
def _is_context_empty(self):
# pyrefly: ignore # bad-argument-type
return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0
def apply_ops_2(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
op.execute()
@ -175,6 +184,7 @@ class Capture:
res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
var = CaptureVariable(None, ctx=self.ctx)
t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(t)
return var
@ -273,7 +283,9 @@ class CaptureVariable(Capture):
def apply_ops(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
op.execute()
return self.calculated_value
@ -373,6 +385,7 @@ def get_val(capture):
class CaptureInitial(CaptureVariable):
def __init__(self, schema_df=None):
# pyrefly: ignore # bad-assignment
new_ctx: dict[str, list[Any]] = {
"operations": [],
"variables": [],
@ -388,6 +401,7 @@ class CaptureDataFrame(CaptureInitial):
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
def as_datapipe(self):
# pyrefly: ignore # unsupported-operation
return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)
def raw_iterator(self):

View File

@ -92,6 +92,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
size = None
all_buffer = []
filter_res = []
# pyrefly: ignore # bad-assignment
for df in self.source_datapipe:
if size is None:
size = len(df.index)

View File

@ -135,6 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
_fast_forward_iterator: Optional[Iterator] = None
def __iter__(self) -> Iterator[_T_co]:
# pyrefly: ignore # bad-return
return self
def __getattr__(self, attribute_name):
@ -379,6 +380,7 @@ class _DataPipeSerializationWrapper:
value = pickle.dumps(self._datapipe)
except Exception:
if HAS_DILL:
# pyrefly: ignore # missing-attribute
value = dill.dumps(self._datapipe)
use_dill = True
else:
@ -388,6 +390,7 @@ class _DataPipeSerializationWrapper:
def __setstate__(self, state):
value, use_dill = state
if use_dill:
# pyrefly: ignore # missing-attribute
self._datapipe = dill.loads(value)
else:
self._datapipe = pickle.loads(value)
@ -404,6 +407,7 @@ class _DataPipeSerializationWrapper:
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
def __init__(self, datapipe: IterDataPipe[_T_co]):
super().__init__(datapipe)
# pyrefly: ignore # invalid-type-var
self._datapipe_iter: Optional[Iterator[_T_co]] = None
def __iter__(self) -> "_IterDataPipeSerializationWrapper":

View File

@ -118,6 +118,7 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
for idx in sorted(self.input_col[1:], reverse=True):
del data[idx]
else:
# pyrefly: ignore # unsupported-operation
data[self.input_col] = res
else:
if self.output_col == -1:

View File

@ -42,6 +42,7 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
"Sampler class requires input datapipe implemented `__len__`"
)
super().__init__()
# pyrefly: ignore # bad-assignment
self.datapipe = datapipe
self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs

View File

@ -59,6 +59,7 @@ class ConcaterIterDataPipe(IterDataPipe):
def __len__(self) -> int:
if all(isinstance(dp, Sized) for dp in self.datapipes):
# pyrefly: ignore # bad-argument-type
return sum(len(dp) for dp in self.datapipes)
else:
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
@ -179,6 +180,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
self._child_stop: list[bool] = [True for _ in range(num_instances)]
def __len__(self):
# pyrefly: ignore # bad-argument-type
return len(self.main_datapipe)
def get_next_element_by_instance(self, instance_id: int):
@ -238,6 +240,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
return self.end_ptr is not None and all(self._child_stop)
def get_length_by_instance(self, instance_id: int) -> int:
# pyrefly: ignore # bad-argument-type
return len(self.main_datapipe)
def reset(self) -> None:
@ -323,6 +326,7 @@ class _ChildDataPipe(IterDataPipe):
def __init__(self, main_datapipe: IterDataPipe, instance_id: int):
assert isinstance(main_datapipe, _ContainerTemplate)
# pyrefly: ignore # bad-assignment
self.main_datapipe: IterDataPipe = main_datapipe
self.instance_id = instance_id
@ -449,6 +453,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
drop_none: bool,
buffer_size: int,
):
# pyrefly: ignore # invalid-type-var
self.main_datapipe = datapipe
self._datapipe_iterator: Optional[Iterator[Any]] = None
self.num_instances = num_instances
@ -460,7 +465,9 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
UserWarning,
)
self.current_buffer_usage = 0
# pyrefly: ignore # invalid-type-var
self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)]
# pyrefly: ignore # invalid-type-var
self.classifier_fn = classifier_fn
self.drop_none = drop_none
self.main_datapipe_exhausted = False
@ -698,6 +705,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
def __len__(self) -> int:
if all(isinstance(dp, Sized) for dp in self.datapipes):
# pyrefly: ignore # bad-argument-type
return min(len(dp) for dp in self.datapipes)
else:
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

View File

@ -203,7 +203,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
drop_remaining: bool = False,
):
_check_unpickable_fn(group_key_fn)
# pyrefly: ignore # invalid-type-var
self.datapipe = datapipe
# pyrefly: ignore # invalid-type-var
self.group_key_fn = group_key_fn
self.keep_key = keep_key
@ -214,9 +216,11 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
self.guaranteed_group_size = None
if group_size is not None and buffer_size is not None:
assert 0 < group_size <= buffer_size
# pyrefly: ignore # bad-assignment
self.guaranteed_group_size = group_size
if guaranteed_group_size is not None:
assert group_size is not None and 0 < guaranteed_group_size <= group_size
# pyrefly: ignore # bad-assignment
self.guaranteed_group_size = guaranteed_group_size
self.drop_remaining = drop_remaining
self.wrapper_class = DataChunk

View File

@ -60,6 +60,7 @@ class MapperMapDataPipe(MapDataPipe[_T_co]):
self.fn = fn # type: ignore[assignment]
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
return len(self.datapipe)
def __getitem__(self, index) -> _T_co:

View File

@ -64,6 +64,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
) -> None:
super().__init__()
self.datapipe = datapipe
# pyrefly: ignore # bad-argument-type
self.indices = list(range(len(datapipe))) if indices is None else indices
self._enabled = True
self._seed = None
@ -95,6 +96,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
return len(self.datapipe)
def __getstate__(self):

View File

@ -49,13 +49,16 @@ class ConcaterMapDataPipe(MapDataPipe):
def __getitem__(self, index) -> _T_co: # type: ignore[type-var]
offset = 0
for dp in self.datapipes:
# pyrefly: ignore # bad-argument-type
if index - offset < len(dp):
return dp[index - offset]
else:
# pyrefly: ignore # bad-argument-type
offset += len(dp)
raise IndexError(f"Index {index} is out of range.")
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
return sum(len(dp) for dp in self.datapipes)
@ -102,4 +105,5 @@ class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]):
return tuple(res)
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
return min(len(dp) for dp in self.datapipes)

View File

@ -196,6 +196,7 @@ def get_file_pathnames_from_root(
if match_masks(fname, masks):
yield path
else:
# pyrefly: ignore # bad-assignment
for path, dirs, files in os.walk(root, onerror=onerror):
if abspath:
path = os.path.abspath(path)

View File

@ -43,6 +43,7 @@ def _simple_graph_snapshot_restoration(
# simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
# the first reset will not actually reset.
datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
# pyrefly: ignore # bad-argument-type
apply_random_seed(datapipe, rng)
remainder = n_iterations

View File

@ -131,6 +131,7 @@ class DistributedSampler(Sampler[_T_co]):
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
# pyrefly: ignore # bad-return
return iter(indices)
def __len__(self) -> int:

View File

@ -72,6 +72,7 @@ def _list_connected_datapipes(
p.dump(scan_obj)
except (pickle.PickleError, AttributeError, TypeError):
if dill_available():
# pyrefly: ignore # missing-attribute
d.dump(scan_obj)
else:
raise

View File

@ -31,6 +31,7 @@ class FileBaton:
True if the file could be created, else False.
"""
try:
# pyrefly: ignore # bad-assignment
self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
return True
except FileExistsError:

View File

@ -149,6 +149,7 @@ def conv_flop_count(
@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
"""Count flops for convolution."""
# pyrefly: ignore # bad-argument-type
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
@ -676,7 +677,9 @@ class FlopCounterMode:
if depth is None:
depth = 999999
import tabulate
# pyrefly: ignore # bad-assignment
tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOP", "% Total"]
values = []

View File

@ -48,6 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict(
]
)
# pyrefly: ignore # no-matching-overload
CUDA_TYPE_NAME_MAP = collections.OrderedDict(
[
("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)),
@ -675,6 +676,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
]
)
# pyrefly: ignore # no-matching-overload
CUDA_IDENTIFIER_MAP = collections.OrderedDict(
[
("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)),

View File

@ -663,6 +663,7 @@ def is_caffe2_gpu_file(rel_filepath):
return True
filename = os.path.basename(rel_filepath)
_, ext = os.path.splitext(filename)
# pyrefly: ignore # unsupported-operation
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
class TrieNode:
@ -1137,6 +1138,7 @@ def hipify(
out_of_place_only=out_of_place_only,
is_pytorch_extension=is_pytorch_extension))
all_files_set = set(all_files)
# pyrefly: ignore # bad-assignment
for f in extra_files:
if not os.path.isabs(f):
f = os.path.join(output_directory, f)

View File

@ -145,6 +145,7 @@ class BackwardHook:
res = out
# pyrefly: ignore # bad-assignment
self.grad_outputs = None
return self._unpack_none(self.input_tensors_index, res)

View File

@ -208,6 +208,7 @@ def get_model_info(
with zipfile.ZipFile(path_or_file) as zf:
path_prefix = None
zip_files = []
# pyrefly: ignore # bad-assignment
for zi in zf.infolist():
prefix = re.sub("/.*", "", zi.filename)
if path_prefix is None:
@ -359,9 +360,12 @@ def get_inline_skeleton():
import importlib.resources
# pyrefly: ignore # bad-argument-type
skeleton = importlib.resources.read_text(__package__, "skeleton.html")
# pyrefly: ignore # bad-argument-type
js_code = importlib.resources.read_text(__package__, "code.js")
for js_module in ["preact", "htm"]:
# pyrefly: ignore # bad-argument-type
js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)

View File

@ -31,5 +31,7 @@ def make_np(x: torch.Tensor) -> np.ndarray:
def _prepare_pytorch(x: torch.Tensor) -> np.ndarray:
if x.dtype == torch.bfloat16:
x = x.to(torch.float16)
# pyrefly: ignore # bad-assignment
x = x.detach().cpu().numpy()
# pyrefly: ignore # bad-return
return x

View File

@ -188,6 +188,7 @@ class GraphPy:
for key, node in self.nodes_io.items():
if type(node) == NodeBase:
# pyrefly: ignore # unsupported-operation
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if hasattr(node, "input_or_output"):
self.unique_name_to_scoped_name[key] = (
@ -198,6 +199,7 @@ class GraphPy:
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if node.scope == "" and self.shallowest_scope_name:
self.unique_name_to_scoped_name[node.debugName] = (
# pyrefly: ignore # unsupported-operation
self.shallowest_scope_name + "/" + node.debugName
)

View File

@ -57,11 +57,14 @@ def _prepare_video(V):
return num != 0 and ((num & (num - 1)) == 0)
# pad to nearest power of 2, all at once
# pyrefly: ignore # index-error
if not is_power2(V.shape[0]):
# pyrefly: ignore # index-error
len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)
n_rows = 2 ** ((b.bit_length() - 1) // 2)
# pyrefly: ignore # index-error
n_cols = V.shape[0] // n_rows
V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))

View File

@ -9,6 +9,7 @@ from typing import Any, Optional
import torch
import numpy as np
# pyrefly: ignore # import-error
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import (
@ -497,6 +498,7 @@ def make_histogram(values, bins, max_bins=None):
subsampling = num_bins // max_bins
subsampling_remainder = num_bins % subsampling
if subsampling_remainder != 0:
# pyrefly: ignore # no-matching-overload
counts = np.pad(
counts,
pad_width=[[0, subsampling - subsampling_remainder]],
@ -834,17 +836,21 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None):
weights = 1.0
# Compute bins of true positives and false positives.
# pyrefly: ignore # unsupported-operation
bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
float_labels = labels.astype(np.float64)
# pyrefly: ignore # unsupported-operation
histogram_range = (0, num_thresholds - 1)
tp_buckets, _ = np.histogram(
bucket_indices,
# pyrefly: ignore # bad-argument-type
bins=num_thresholds,
range=histogram_range,
weights=float_labels * weights,
)
fp_buckets, _ = np.histogram(
bucket_indices,
# pyrefly: ignore # bad-argument-type
bins=num_thresholds,
range=histogram_range,
weights=(1.0 - float_labels) * weights,

View File

@ -254,7 +254,9 @@ class SummaryWriter:
buckets = []
neg_buckets = []
while v < 1e20:
# pyrefly: ignore # bad-argument-type
buckets.append(v)
# pyrefly: ignore # bad-argument-type
neg_buckets.append(-v)
v *= 1.1
self.default_bins = neg_buckets[::-1] + [0] + buckets
@ -262,15 +264,19 @@ class SummaryWriter:
def _get_file_writer(self):
"""Return the default FileWriter instance. Recreates it if closed."""
if self.all_writers is None or self.file_writer is None:
# pyrefly: ignore # bad-assignment
self.file_writer = FileWriter(
self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
)
# pyrefly: ignore # bad-assignment, missing-attribute
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
if self.purge_step is not None:
most_recent_step = self.purge_step
# pyrefly: ignore # missing-attribute
self.file_writer.add_event(
Event(step=most_recent_step, file_version="brain.Event:2")
)
# pyrefly: ignore # missing-attribute
self.file_writer.add_event(
Event(
step=most_recent_step,
@ -950,6 +956,7 @@ class SummaryWriter:
)
self._projector_config.embeddings.extend([embedding_info])
# pyrefly: ignore # import-error
from google.protobuf import text_format
config_pbtxt = text_format.MessageToString(self._projector_config)
@ -1199,6 +1206,7 @@ class SummaryWriter:
for writer in self.all_writers.values():
writer.flush()
writer.close()
# pyrefly: ignore # bad-assignment
self.file_writer = self.all_writers = None
def __enter__(self):

View File

@ -461,6 +461,7 @@ def to_html(nodes):
if n.context is None:
continue
s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}'))
# pyrefly: ignore # bad-argument-type
listeners.append(s)
dot = to_dot(nodes)
return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners))

View File

@ -292,6 +292,7 @@ class WeakIdKeyDictionary(MutableMapping):
if o is not None:
return o, value
# pyrefly: ignore # bad-override
def pop(self, key, *args):
self._dirty_len = True
return self.data.pop(self.ref_type(key), *args) # CHANGED