From 5f18f240de43fc24481ead4d740dda64f174fa86 Mon Sep 17 00:00:00 2001 From: Maggie Moss Date: Thu, 2 Oct 2025 20:57:37 +0000 Subject: [PATCH] Add initial suppressions for pyrefly (#164177) Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: `python3 scripts/lintrunner.py` `pyrefly check` --- Pyrefly check before: https://gist.github.com/maggiemoss/3a0aa0b6cdda0e449cd5743d5fce2c60 After: ``` INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml` INFO 0 errors (1,063 ignored) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164177 Approved by: https://github.com/Lucaskabela --- pyrefly.toml | 6 +++++- test/test_bundled_inputs.py | 8 ++++++- test/test_complex.py | 4 ++-- test/test_type_hints.py | 1 + test/test_type_info.py | 4 ++-- torch/__init__.py | 17 +++++++++------ torch/_custom_op/impl.py | 2 +- torch/_dispatch/python.py | 2 +- torch/_jit_internal.py | 14 +++++++++---- torch/_lazy/closure.py | 6 ++++-- torch/_lobpcg.py | 8 +++++-- torch/_ops.py | 11 +++++----- torch/_strobelight/cli_function_profiler.py | 1 + torch/_strobelight/compile_time_profiler.py | 2 +- torch/_tensor.py | 23 +++++++++++++++------ torch/_utils.py | 23 ++++++++++++++------- torch/_utils_internal.py | 11 ++++++++-- torch/amp/autocast_mode.py | 14 ++++++++----- torch/cpu/amp/__init__.py | 1 + torch/functional.py | 6 ++++-- torch/hub.py | 2 +- torch/library.py | 3 +++ torch/linalg/__init__.py | 4 ++-- torch/mtia/__init__.py | 2 +- torch/multiprocessing/spawn.py | 3 ++- torch/nested/_internal/nested_tensor.py | 2 +- torch/nested/_internal/ops.py | 8 +++++-- torch/numa/binding.py | 6 ++++-- torch/package/_package_pickler.py | 1 + torch/package/importer.py | 4 +++- torch/package/package_exporter.py | 3 ++- torch/package/package_importer.py | 1 + torch/quantization/qconfig.py | 2 +- torch/serialization.py | 15 ++++++++++---- torch/signal/windows/windows.py | 15 +++++++++++--- torch/storage.py | 2 +- torch/xpu/__init__.py | 6 ++++-- torch/xpu/streams.py | 2 +- 38 files changed, 170 insertions(+), 75 deletions(-) diff --git a/pyrefly.toml b/pyrefly.toml index 43de7694184e..b619ed3a860b 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -43,13 +43,17 @@ project-excludes = [ "torch/profiler/**", "torch/_prims_common/**", "torch/backends/**", - "torch/testing/**", + # "torch/testing/**", "torch/_C/**", "torch/sparse/**", "torch/_library/**", "torch/_prims/**", "torch/_decomp/**", "torch/_meta_registrations.py", + # formatting issues + "torch/linalg/__init__.py", + "torch/package/importer.py", + "torch/package/_package_pickler.py", # ==== "benchmarks/instruction_counts/main.py", "benchmarks/instruction_counts/definitions/setup.py", diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index 0ff5373993cf..200658f6e47a 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -58,17 +58,20 @@ class TestBundledInputs(TestCase): # Make sure the model only grew a little bit, # despite having nominally large bundled inputs. augmented_size = model_size(sm) + # pyrefly: ignore # missing-attribute self.assertLess(augmented_size, original_size + (1 << 12)) loaded = save_and_load(sm) inflated = loaded.get_all_bundled_inputs() self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) self.assertEqual(len(inflated), len(samples)) + # pyrefly: ignore # missing-attribute self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) for idx, inp in enumerate(inflated): - self.assertIsInstance(inp, tuple) + self.assertIsInstance(inp, tuple) # pyrefly: ignore # missing-attribute self.assertEqual(len(inp), 1) + # pyrefly: ignore # missing-attribute self.assertIsInstance(inp[0], torch.Tensor) if idx != 5: # Strides might be important for benchmarking. @@ -136,6 +139,7 @@ class TestBundledInputs(TestCase): loaded = save_and_load(sm) inflated = loaded.get_all_bundled_inputs() self.assertEqual(inflated, samples) + # pyrefly: ignore # missing-attribute self.assertTrue(loaded(*inflated[0]) == "first 1") def test_multiple_methods_with_inputs(self): @@ -182,6 +186,7 @@ class TestBundledInputs(TestCase): self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo()) # Check running and size helpers + # pyrefly: ignore # missing-attribute self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) @@ -414,6 +419,7 @@ class TestBundledInputs(TestCase): ) augmented_size = model_size(sm) # assert the size has not increased more than 8KB + # pyrefly: ignore # missing-attribute self.assertLess(augmented_size, original_size + (1 << 13)) loaded = save_and_load(sm) diff --git a/test/test_complex.py b/test/test_complex.py index 159f3e18aaee..5646952612d6 100644 --- a/test/test_complex.py +++ b/test/test_complex.py @@ -48,7 +48,7 @@ class TestComplexTensor(TestCase): def test_all(self, device, dtype): # issue: https://github.com/pytorch/pytorch/issues/120875 x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype) - self.assertTrue(torch.all(x)) + self.assertTrue(torch.all(x)) # pyrefly: ignore # missing-attribute @dtypes(*complex_types()) def test_any(self, device, dtype): @@ -56,7 +56,7 @@ class TestComplexTensor(TestCase): x = torch.tensor( [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype ) - self.assertFalse(torch.any(x)) + self.assertFalse(torch.any(x)) # pyrefly: ignore # missing-attribute @onlyCPU @dtypes(*complex_types()) diff --git a/test/test_type_hints.py b/test/test_type_hints.py index 0aae54be9b63..c982ae19b6df 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -142,6 +142,7 @@ class TestTypeHints(TestCase): ] ) if result != 0: + # pyrefly: ignore # missing-attribute self.fail(f"mypy failed:\n{stderr}\n{stdout}") diff --git a/test/test_type_info.py b/test/test_type_info.py index 80a21bc5e9dd..38bace6ff2fd 100644 --- a/test/test_type_info.py +++ b/test/test_type_info.py @@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase): # Regression test for https://github.com/pytorch/pytorch/issues/124868 # If reference count is leaked this would be a set of 10 elements ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)} - self.assertLess(len(ref_cnt), 3) + self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute self.assertEqual(torch.float64.to_complex(), torch.complex128) self.assertEqual(torch.float32.to_complex(), torch.complex64) @@ -135,7 +135,7 @@ class TestDTypeInfo(TestCase): # Regression test for https://github.com/pytorch/pytorch/issues/124868 # If reference count is leaked this would be a set of 10 elements ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)} - self.assertLess(len(ref_cnt), 3) + self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute self.assertEqual(torch.complex128.to_real(), torch.double) self.assertEqual(torch.complex64.to_real(), torch.float32) diff --git a/torch/__init__.py b/torch/__init__.py index 9be88b832fa8..6dc1dfc18ce8 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1699,7 +1699,7 @@ def _check(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(RuntimeError, cond, message) + _check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type def _check_is_size(i, message=None, *, max=None): @@ -1748,7 +1748,7 @@ def _check_index(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(IndexError, cond, message) + _check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type def _check_value(cond, message=None): # noqa: F811 @@ -1766,7 +1766,7 @@ def _check_value(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(ValueError, cond, message) + _check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type def _check_type(cond, message=None): # noqa: F811 @@ -1784,7 +1784,7 @@ def _check_type(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(TypeError, cond, message) + _check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type def _check_not_implemented(cond, message=None): # noqa: F811 @@ -1802,7 +1802,12 @@ def _check_not_implemented(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(NotImplementedError, cond, message) + _check_with( + NotImplementedError, + cond, + # pyrefly: ignore # bad-argument-type + message, + ) def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 @@ -2612,7 +2617,7 @@ def compile( def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: if model is None: raise RuntimeError("Model can't be None") - return compile( + return compile( # pyrefly: ignore # no-matching-overload model, fullgraph=fullgraph, dynamic=dynamic, diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index 208c18e392a4..b445907b5d1a 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -101,7 +101,7 @@ def custom_op( lib, ns, function_schema, name, ophandle, _private_access=True ) - result.__name__ = func.__name__ + result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment result.__module__ = func.__module__ result.__doc__ = func.__doc__ diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index 7b790fe18ea7..4cf1d1b5cffc 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -154,7 +154,7 @@ def make_crossref_functionalize( maybe_detach, (f_args, f_kwargs) ) with fake_mode: - f_r = op(*f_args, **f_kwargs) + f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec r = op._op_dk(final_key, *args, **kwargs) def desc(): diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 8d42421d8b90..501169cddc2b 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str: # If the module is actually a torchbind module, then we should short circuit if module_name == "torch._classes": - return obj.qualified_name + return obj.qualified_name # pyrefly: ignore # missing-attribute # The Python docs are very clear that `__module__` can be None, but I can't # figure out when it actually would be. @@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]: prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED ) - return prop + return prop # pyrefly: ignore # bad-return fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined] return fn @@ -844,6 +844,7 @@ def ignore(drop=False, **kwargs): # @torch.jit.ignore # def fn(...): fn = drop + # pyrefly: ignore # missing-attribute fn._torchscript_modifier = FunctionModifiers.IGNORE return fn @@ -1250,7 +1251,10 @@ def _get_named_tuple_properties( obj_annotations = inspect.get_annotations(obj) if len(obj_annotations) == 0 and hasattr(obj, "__base__"): - obj_annotations = inspect.get_annotations(obj.__base__) + obj_annotations = inspect.get_annotations( + # pyrefly: ignore # bad-argument-type + obj.__base__ + ) annotations = [] for field in obj._fields: @@ -1439,7 +1443,9 @@ def container_checker(obj, target_type) -> bool: return False return True elif origin_type is Union or issubclass( - origin_type, BuiltinUnionType + # pyrefly: ignore # bad-argument-type + origin_type, + BuiltinUnionType, ): # also handles Optional if obj is None: # check before recursion because None is always fine return True diff --git a/torch/_lazy/closure.py b/torch/_lazy/closure.py index dce2a58a5d88..864591f84b56 100644 --- a/torch/_lazy/closure.py +++ b/torch/_lazy/closure.py @@ -63,8 +63,10 @@ class AsyncClosureHandler(ClosureHandler): self._closure_exception.put(e) return - self._closure_event_loop = threading.Thread(target=event_loop) - self._closure_event_loop.start() + self._closure_event_loop = threading.Thread( + target=event_loop + ) # pyrefly: ignore # bad-assignment + self._closure_event_loop.start() # pyrefly: ignore # missing-attribute def run(self, closure): with self._closure_lock: diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index a3f57411b8f5..a35116f8c62b 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -301,7 +301,7 @@ class LOBPCGAutogradFunction(torch.autograd.Function): return D, U @staticmethod - def backward(ctx, D_grad, U_grad): + def backward(ctx, D_grad, U_grad): # pyrefly: ignore # bad-override A_grad = B_grad = None grads = [None] * 14 @@ -1048,7 +1048,11 @@ class LOBPCG: else: E[(torch.where(E < t))[0]] = t - return torch.matmul(U * d_col.mT, Z * E**-0.5) + return torch.matmul( + U * d_col.mT, + # pyrefly: ignore # unsupported-operation + Z * E**-0.5, + ) def _get_ortho(self, U, V): """Return B-orthonormal U with columns are B-orthogonal to V. diff --git a/torch/_ops.py b/torch/_ops.py index 8f91e072c23a..e568536b0869 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -803,7 +803,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]): # Logic replicated from aten/src/ATen/native/MathBitsFallback.h is_write = None - for a in self._schema.arguments: + for a in self._schema.arguments: # pyrefly: ignore # bad-assignment if a.alias_info is None: continue if is_write is None: @@ -885,7 +885,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]): elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk): return self._op_dk(dk, *args, **kwargs) else: - return NotImplemented + return NotImplemented # pyrefly: ignore # bad-return # Remove a dispatch key from the dispatch cache. This will force it to get # recomputed the next time. Does nothing @@ -990,9 +990,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]): r = self.py_kernels.get(final_key, final_key) if cache_result: - self._dispatch_cache[key] = r + self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation add_cached_op(self) - return r + return r # pyrefly: ignore # bad-return def name(self): return self._name @@ -1122,7 +1122,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]): ) assert isinstance(handler, Callable) # type: ignore[arg-type] - return handler(*args, **kwargs) + return handler(*args, **kwargs) # pyrefly: ignore # bad-return def _must_dispatch_in_python(args, kwargs): @@ -1251,6 +1251,7 @@ class OpOverloadPacket(Generic[_P, _T]): # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we # intercept it here and call TorchBindOpverload instead. if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): + # pyrefly: ignore # bad-argument-type return _call_overload_packet_from_python(self, *args, **kwargs) return self._op(*args, **kwargs) diff --git a/torch/_strobelight/cli_function_profiler.py b/torch/_strobelight/cli_function_profiler.py index 0cc7db12fe28..8f901b0b264f 100644 --- a/torch/_strobelight/cli_function_profiler.py +++ b/torch/_strobelight/cli_function_profiler.py @@ -314,6 +314,7 @@ def strobelight( ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + # pyrefly: ignore # bad-argument-type return profiler.profile(work_function, *args, **kwargs) return wrapper_function diff --git a/torch/_strobelight/compile_time_profiler.py b/torch/_strobelight/compile_time_profiler.py index 436f9a2c8b59..89b44632e278 100644 --- a/torch/_strobelight/compile_time_profiler.py +++ b/torch/_strobelight/compile_time_profiler.py @@ -145,7 +145,7 @@ class StrobelightCompileTimeProfiler: async_stack_max_len=cls.max_stack_length, run_user_name="pt2-profiler/" + os.environ.get("USER", os.environ.get("USERNAME", "")), - sample_tags={cls.identifier}, + sample_tags={cls.identifier}, # pyrefly: ignore # bad-argument-type ) @classmethod diff --git a/torch/_tensor.py b/torch/_tensor.py index f91539b7533d..a07fc65aee0a 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -756,7 +756,10 @@ class Tensor(torch._C.TensorBase): "post accumulate grad hooks cannot be registered on non-leaf tensors" ) if self._post_accumulate_grad_hooks is None: - self._post_accumulate_grad_hooks: dict[Any, Any] = OrderedDict() + self._post_accumulate_grad_hooks: dict[Any, Any] = ( + # pyrefly: ignore # bad-assignment + OrderedDict() + ) from torch.utils.hooks import RemovableHandle @@ -1056,7 +1059,12 @@ class Tensor(torch._C.TensorBase): if isinstance(split_size, (int, torch.SymInt)): return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined] else: - return torch._VF.split_with_sizes(self, split_size, dim) + return torch._VF.split_with_sizes( + self, + # pyrefly: ignore # bad-argument-type + split_size, + dim, + ) def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): r"""Returns the unique elements of the input tensor. @@ -1101,6 +1109,7 @@ class Tensor(torch._C.TensorBase): @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": + # pyrefly: ignore # no-matching-overload return _C._VariableFunctions.rsub(self, other) @_handle_torch_function_and_wrap_type_error_to_not_implemented @@ -1126,7 +1135,7 @@ class Tensor(torch._C.TensorBase): @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": - return torch.remainder(other, self) + return torch.remainder(other, self) # pyrefly: ignore # no-matching-overload def __format__(self, format_spec): if has_torch_function_unary(self): @@ -1139,7 +1148,7 @@ class Tensor(torch._C.TensorBase): @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": - return torch.pow(other, self) + return torch.pow(other, self) # pyrefly: ignore # no-matching-overload @_handle_torch_function_and_wrap_type_error_to_not_implemented def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override] @@ -1155,12 +1164,14 @@ class Tensor(torch._C.TensorBase): def __rlshift__( self, other: Union["Tensor", int, float, bool, complex] ) -> "Tensor": + # pyrefly: ignore # no-matching-overload return torch.bitwise_left_shift(other, self) @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rrshift__( self, other: Union["Tensor", int, float, bool, complex] ) -> "Tensor": + # pyrefly: ignore # no-matching-overload return torch.bitwise_right_shift(other, self) @_handle_torch_function_and_wrap_type_error_to_not_implemented @@ -1335,7 +1346,7 @@ class Tensor(torch._C.TensorBase): return self._typed_storage()._get_legacy_storage_class() - def refine_names(self, *names): + def refine_names(self, *names): # pyrefly: ignore # bad-override r"""Refines the dimension names of :attr:`self` according to :attr:`names`. Refining is a special case of renaming that "lifts" unnamed dimensions. @@ -1379,7 +1390,7 @@ class Tensor(torch._C.TensorBase): names = resolve_ellipsis(names, self.names, "refine_names") return super().refine_names(names) - def align_to(self, *names): + def align_to(self, *names): # pyrefly: ignore # bad-override r"""Permutes the dimensions of the :attr:`self` tensor to match the order specified in :attr:`names`, adding size-one dims for any new names. diff --git a/torch/_utils.py b/torch/_utils.py index 68d395a90c9a..bf431553abc8 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit): if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: yield buf_and_size[0] buf_and_size = buf_dict[t] = [[], 0] - buf_and_size[0].append(tensor) - buf_and_size[1] += size + buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute + buf_and_size[1] += size # pyrefly: ignore # unsupported-operation for buf, _ in buf_dict.values(): if len(buf) > 0: yield buf @@ -744,14 +744,17 @@ class ExceptionWrapper: if exc_info is None: exc_info = sys.exc_info() self.exc_type = exc_info[0] - self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.exc_msg = "".join( + # pyrefly: ignore # no-matching-overload + traceback.format_exception(*exc_info) + ) self.where = where def reraise(self): r"""Reraises the wrapped exception in the current thread""" # Format a message such as: "Caught ValueError in DataLoader worker # process 2. Original Traceback:", followed by the traceback. - msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" + msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute if self.exc_type == KeyError: # KeyError calls repr() on its argument (usually a dict key). This # makes stack traces unreadable. It will not be changed in Python @@ -760,9 +763,13 @@ class ExceptionWrapper: elif getattr(self.exc_type, "message", None): # Some exceptions have first argument as non-str but explicitly # have message field - raise self.exc_type(message=msg) + # pyrefly: ignore # not-callable + raise self.exc_type( + # pyrefly: ignore # unexpected-keyword + message=msg + ) try: - exception = self.exc_type(msg) + exception = self.exc_type(msg) # pyrefly: ignore # not-callable except Exception: # If the exception takes multiple arguments or otherwise can't # be constructed, don't try to instantiate since we don't know how to @@ -1014,12 +1021,12 @@ class _LazySeedTracker: self.call_order = [] def queue_seed_all(self, cb, traceback): - self.manual_seed_all_cb = (cb, traceback) + self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment # update seed_all to be latest self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] def queue_seed(self, cb, traceback): - self.manual_seed_cb = (cb, traceback) + self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment # update seed to be latest self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 0d56facc7ca4..10c0bf23f85b 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -84,7 +84,11 @@ def compile_time_strobelight_meta( ) -> Callable[_P, _T]: @functools.wraps(function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T: - if "skip" in kwargs and isinstance(skip := kwargs["skip"], int): + if "skip" in kwargs and isinstance( + # pyrefly: ignore # unsupported-operation + skip := kwargs["skip"], + int, + ): kwargs["skip"] = skip + 1 # This is not needed but we have it here to avoid having profile_compile_time @@ -327,7 +331,10 @@ def deprecated(): # public deprecated alias alias = typing_extensions.deprecated( - warning_msg, category=UserWarning, stacklevel=1 + # pyrefly: ignore # bad-argument-type + warning_msg, + category=UserWarning, + stacklevel=1, )(func) alias.__name__ = public_name diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 81903d37e5de..9196cb2de695 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -464,7 +464,11 @@ def _cast(value, device_type: str, dtype: _dtype): return value.to(dtype) if is_eligible else value elif isinstance(value, (str, bytes)): return value - elif HAS_NUMPY and isinstance(value, np.ndarray): + elif HAS_NUMPY and isinstance( + value, + # pyrefly: ignore # missing-attribute + np.ndarray, + ): return value elif isinstance(value, collections.abc.Mapping): return { @@ -521,18 +525,18 @@ def custom_fwd( args[0]._dtype = torch.get_autocast_dtype(device_type) if cast_inputs is None: args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type) - return fwd(*args, **kwargs) + return fwd(*args, **kwargs) # pyrefly: ignore # not-callable else: autocast_context = torch.is_autocast_enabled(device_type) args[0]._fwd_used_autocast = False if autocast_context: with autocast(device_type=device_type, enabled=False): - return fwd( + return fwd( # pyrefly: ignore # not-callable *_cast(args, device_type, cast_inputs), **_cast(kwargs, device_type, cast_inputs), ) else: - return fwd(*args, **kwargs) + return fwd(*args, **kwargs) # pyrefly: ignore # not-callable return decorate_fwd @@ -567,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str): enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype, ): - return bwd(*args, **kwargs) + return bwd(*args, **kwargs) # pyrefly: ignore # not-callable return decorate_bwd diff --git a/torch/cpu/amp/__init__.py b/torch/cpu/amp/__init__.py index e72eb3b92a7f..147d39b4a20a 100644 --- a/torch/cpu/amp/__init__.py +++ b/torch/cpu/amp/__init__.py @@ -1,2 +1,3 @@ +# pyrefly: ignore # deprecated from .autocast_mode import autocast from .grad_scaler import GradScaler diff --git a/torch/functional.py b/torch/functional.py index b5fcf8240c83..802f178d2043 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1784,7 +1784,9 @@ def norm( # noqa: F811 if isinstance(p, str): if p == "fro" and ( - dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2 + dim is None + or isinstance(dim, (int, torch.SymInt)) + or len(dim) <= 2 # pyrefly: ignore # bad-argument-type ): if out is None: return torch.linalg.vector_norm( @@ -1950,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: ) if isinstance(shape, (int, torch.SymInt)): - shape = torch.Size([shape]) + shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type else: for dim in shape: torch._check_type( diff --git a/torch/hub.py b/torch/hub.py index 65324a044bbe..4b68e997162a 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -421,7 +421,7 @@ def set_dir(d: Union[str, os.PathLike]) -> None: d (str): path to a local folder to save downloaded models & weights. """ global _hub_dir - _hub_dir = os.path.expanduser(d) + _hub_dir = os.path.expanduser(d) # pyrefly: ignore # no-matching-overload def list( diff --git a/torch/library.py b/torch/library.py index 0ac29cfde3f7..d962c08c3905 100644 --- a/torch/library.py +++ b/torch/library.py @@ -242,6 +242,7 @@ class Library: if dispatch_key == "": dispatch_key = self.dispatch_key + # pyrefly: ignore # bad-argument-type assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense) if isinstance(op_name, str): @@ -643,6 +644,7 @@ def impl( >>> y2 = torch.sin(x) + 1 >>> assert torch.allclose(y1, y2) """ + # pyrefly: ignore # no-matching-overload return _impl(qualname, types, func, lib=lib, disable_dynamo=False) @@ -829,6 +831,7 @@ def register_kernel( if device_types is None: device_types = "CompositeExplicitAutograd" + # pyrefly: ignore # no-matching-overload return _impl(op, device_types, func, lib=lib, disable_dynamo=True) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 355ad00d491a..336798f2a3cf 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -1,7 +1,7 @@ -from torch._C import ( # type: ignore[attr-defined] +from torch._C import ( # type: ignore[attr-defined] # pyrefly: ignore # missing-module-attribute _add_docstr, _linalg, - _LinAlgError as LinAlgError, + _LinAlgError as LinAlgError, # pyrefly: ignore # missing-module-attribute ) diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 55bcadf0c2a1..092e9f2cc5cb 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -303,7 +303,7 @@ class StreamContext: self.idx = _get_device_index(None, True) if not torch.jit.is_scripting(): if self.idx is None: - self.idx = -1 + self.idx = -1 # pyrefly: ignore # bad-assignment self.src_prev_stream = ( None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index b11e5714fc2e..d4652ab32ff7 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -119,6 +119,7 @@ class ProcessContext: """Attempt to join all processes with a shared timeout.""" end = time.monotonic() + timeout for process in self.processes: + # pyrefly: ignore # no-matching-overload time_to_wait = max(0, end - time.monotonic()) process.join(time_to_wait) @@ -274,7 +275,7 @@ def start_processes( tf.close() os.unlink(tf.name) - process = mp.Process( + process = mp.Process( # pyrefly: ignore # missing-attribute target=_wrap, args=(fn, i, args, tf.name), daemon=daemon, diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index d3c4ba8c9166..8d446a7bd518 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -406,7 +406,7 @@ class ViewBufferFromNested(torch.autograd.Function): # Not actually a view! class ViewNestedFromBuffer(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore # bad-override ctx, values: torch.Tensor, offsets: torch.Tensor, diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 5d32c6ace9ad..9ac4f53b60eb 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -46,11 +46,14 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False): if canonicalize: dim = canonicalize_dims(ndim, dim) - assert dim >= 0 and dim < ndim + assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1. # For other dims, subtract 1 to convert to inner space. - return ragged_dim - 1 if dim == 0 else dim - 1 + return ( + # pyrefly: ignore # unsupported-operation + ragged_dim - 1 if dim == 0 else dim - 1 + ) def _wrap_jagged_dim( @@ -2005,6 +2008,7 @@ def index_put_(func, *args, **kwargs): else: lengths = inp.lengths() torch._assert_async( + # pyrefly: ignore # no-matching-overload torch.all(indices[inp._ragged_idx] < lengths), "Some indices in the ragged dimension are out of bounds!", ) diff --git a/torch/numa/binding.py b/torch/numa/binding.py index 34a61e2b9c56..140457845fde 100644 --- a/torch/numa/binding.py +++ b/torch/numa/binding.py @@ -134,7 +134,8 @@ def _raise_if_logical_cpu_indices_invalid(*, logical_cpu_indices: set[int]) -> N def _bind_current_thread_to_logical_cpus(*, logical_cpu_indices: set[int]) -> None: # 0 represents the current thread - os.sched_setaffinity(0, logical_cpu_indices) + # pyrefly: ignore # missing-attribute + os.sched_setaffinity(0, logical_cpu_indices) # type: ignore[attr-defined] def _get_logical_cpus_to_bind_to( @@ -544,4 +545,5 @@ def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]: def _get_allowed_cpu_indices_for_current_thread() -> set[int]: # 0 denotes current thread - return os.sched_getaffinity(0) + # pyrefly: ignore # missing-attribute + return os.sched_getaffinity(0) # type:ignore[attr-defined] diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index 8384a3ce2c16..31898c96f1b0 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +# pyrefly: ignore # missing-module-attribute from pickle import ( # type: ignore[attr-defined] _compat_pickle, _extension_registry, diff --git a/torch/package/importer.py b/torch/package/importer.py index 8cfc1e336a45..3984ddfc40fb 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -2,10 +2,12 @@ import importlib import logging from abc import ABC, abstractmethod + +# pyrefly: ignore # missing-module-attribute from pickle import ( # type: ignore[attr-defined] _getattribute, _Pickler, - whichmodule as _pickle_whichmodule, + whichmodule as _pickle_whichmodule, # pyrefly: ignore # missing-module-attribute ) from types import ModuleType from typing import Any, Optional diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 4ac2c33f1633..7b686f008201 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -219,7 +219,7 @@ class PackageExporter: torch._C._log_api_usage_once("torch.package.PackageExporter") self.debug = debug if isinstance(f, (str, os.PathLike)): - f = os.fspath(f) + f = os.fspath(f) # pyrefly: ignore # no-matching-overload self.buffer: Optional[IO[bytes]] = None else: # is a byte buffer self.buffer = f @@ -652,6 +652,7 @@ class PackageExporter: memo: defaultdict[int, str] = defaultdict(None) memo_count = 0 # pickletools.dis(data_value) + # pyrefly: ignore # bad-assignment for opcode, arg, _pos in pickletools.genops(data_value): if pickle_protocol == 4: if ( diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 10bf8981e28a..8f2a009f9121 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -108,6 +108,7 @@ class PackageImporter(Importer): self.filename = "" self.zip_reader = file_or_buffer elif isinstance(file_or_buffer, (os.PathLike, str)): + # pyrefly: ignore # no-matching-overload self.filename = os.fspath(file_or_buffer) if not os.path.isdir(self.filename): self.zip_reader = torch._C.PyTorchFileReader(self.filename) diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index a02ff7d6f738..75398d3343f9 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -27,5 +27,5 @@ from torch.ao.quantization.qconfig import ( QConfig, qconfig_equals, QConfigAny, - QConfigDynamic, + QConfigDynamic, # pyrefly: ignore # deprecated ) diff --git a/torch/serialization.py b/torch/serialization.py index 1cda549821d0..dcdbf0c3cef9 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -774,7 +774,10 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]: class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]): def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None: - super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) + super().__init__( + # pyrefly: ignore # no-matching-overload + torch._C.PyTorchFileReader(name_or_buffer) + ) class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]): @@ -787,9 +790,10 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]): # PyTorchFileWriter only supports ascii filename. # For filenames with non-ascii characters, we rely on Python # for writing out the file. + # pyrefly: ignore # bad-assignment self.file_stream = io.FileIO(self.name, mode="w") super().__init__( - torch._C.PyTorchFileWriter( + torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload self.file_stream, get_crc32_options(), _get_storage_alignment() ) ) @@ -966,7 +970,7 @@ def save( _check_save_filelike(f) if isinstance(f, (str, os.PathLike)): - f = os.fspath(f) + f = os.fspath(f) # pyrefly: ignore # no-matching-overload if _use_new_zipfile_serialization: with _open_zipfile_writer(f) as opened_zipfile: @@ -1520,7 +1524,10 @@ def load( else: shared = False overall_storage = torch.UntypedStorage.from_file( - os.fspath(f), shared, size + # pyrefly: ignore # no-matching-overload + os.fspath(f), + shared, + size, ) if weights_only: try: diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 83d62c503feb..ed240bda8160 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -326,7 +326,7 @@ def gaussian( requires_grad=requires_grad, ) - return torch.exp(-(k**2)) + return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation @_add_docstr( @@ -397,11 +397,17 @@ def kaiser( ) # Avoid NaNs by casting `beta` to the appropriate dtype. + # pyrefly: ignore # bad-assignment beta = torch.tensor(beta, dtype=dtype, device=device) start = -beta constant = 2.0 * beta / (M if not sym else M - 1) - end = torch.minimum(beta, start + (M - 1) * constant) + end = torch.minimum( + # pyrefly: ignore # bad-argument-type + beta, + # pyrefly: ignore # bad-argument-type + start + (M - 1) * constant, + ) k = torch.linspace( start=start, @@ -413,7 +419,10 @@ def kaiser( requires_grad=requires_grad, ) - return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(beta) + return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0( + # pyrefly: ignore # bad-argument-type + beta + ) @_add_docstr( diff --git a/torch/storage.py b/torch/storage.py index e651bc9d16eb..5fc60055cd71 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device): def _isint(x): if HAS_NUMPY: - return isinstance(x, (int, np.integer)) + return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute else: return isinstance(x, int) diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index df2b8a6f5334..0c7f4cd3ec6b 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -247,7 +247,9 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: } -def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties: +def get_device_properties( + device: Optional[_device_t] = None, +) -> _XpuDeviceProperties: # pyrefly: ignore # not-a-type r"""Get the properties of a device. Args: @@ -315,7 +317,7 @@ class StreamContext: self.stream = stream self.idx = _get_device_index(None, True) if self.idx is None: - self.idx = -1 + self.idx = -1 # pyrefly: ignore # bad-assignment def __enter__(self): cur_stream = self.stream diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index dd381cf83419..378e71074c18 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase): """ if stream is None: stream = torch.xpu.current_stream() - super().record(stream) + super().record(stream) # pyrefly: ignore # bad-argument-type def wait(self, stream=None) -> None: r"""Make all future work submitted to the given stream wait for this event.