Add initial suppressions for pyrefly (#164177)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
`python3 scripts/lintrunner.py`
`pyrefly check`

---

Pyrefly check before: https://gist.github.com/maggiemoss/3a0aa0b6cdda0e449cd5743d5fce2c60
Pyrefly check after:

```
 INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
 INFO 0 errors (1,063 ignored)
```
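
For context on what the added lines look like: each suppression is an inline `# pyrefly: ignore # <error-code>` comment, placed either at the end of the offending line or on the line directly above the offending expression or argument. A minimal sketch of both placements follows; the `DynamicWrapper` class and its attributes are hypothetical and not taken from this diff.

```python
# Minimal sketch of the two suppression placements used in this PR.
# DynamicWrapper and its attributes are hypothetical examples.
import torch


class DynamicWrapper:
    def __init__(self) -> None:
        # Attribute assigned via setattr, so a static checker may report
        # later reads of it as missing-attribute.
        setattr(self, "scale", 2.0)

    def apply(self, t: torch.Tensor) -> torch.Tensor:
        # Trailing form: suppress the diagnostic on this line only.
        return t * self.scale  # pyrefly: ignore # missing-attribute

    def apply_split(self, t: torch.Tensor) -> torch.Tensor:
        # Line-above form: the comment sits directly above the offending
        # expression, as used for multi-line calls elsewhere in this diff.
        # pyrefly: ignore # missing-attribute
        return torch.mul(
            t,
            self.scale,
        )


print(DynamicWrapper().apply(torch.ones(2)).tolist())  # -> [2.0, 2.0]
```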

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164177
Approved by: https://github.com/Lucaskabela
Commit: 5f18f240de (parent 6b7970192f)
Author: Maggie Moss, 2025-10-02 20:57:37 +00:00
Committed by: PyTorch MergeBot
38 changed files with 170 additions and 75 deletions


@ -43,13 +43,17 @@ project-excludes = [
"torch/profiler/**",
"torch/_prims_common/**",
"torch/backends/**",
"torch/testing/**",
# "torch/testing/**",
"torch/_C/**",
"torch/sparse/**",
"torch/_library/**",
"torch/_prims/**",
"torch/_decomp/**",
"torch/_meta_registrations.py",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",


@ -58,17 +58,20 @@ class TestBundledInputs(TestCase):
# Make sure the model only grew a little bit,
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
for idx, inp in enumerate(inflated):
self.assertIsInstance(inp, tuple)
self.assertIsInstance(inp, tuple) # pyrefly: ignore # missing-attribute
self.assertEqual(len(inp), 1)
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.
@ -136,6 +139,7 @@ class TestBundledInputs(TestCase):
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")
def test_multiple_methods_with_inputs(self):
@ -182,6 +186,7 @@ class TestBundledInputs(TestCase):
self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())
# Check running and size helpers
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
@ -414,6 +419,7 @@ class TestBundledInputs(TestCase):
)
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))
loaded = save_and_load(sm)


@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
self.assertTrue(torch.all(x))
self.assertTrue(torch.all(x)) # pyrefly: ignore # missing-attribute
@dtypes(*complex_types())
def test_any(self, device, dtype):
@ -56,7 +56,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
self.assertFalse(torch.any(x))
self.assertFalse(torch.any(x)) # pyrefly: ignore # missing-attribute
@onlyCPU
@dtypes(*complex_types())


@ -142,6 +142,7 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")


@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
self.assertLess(len(ref_cnt), 3)
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
self.assertEqual(torch.float64.to_complex(), torch.complex128)
self.assertEqual(torch.float32.to_complex(), torch.complex64)
@ -135,7 +135,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
self.assertLess(len(ref_cnt), 3)
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
self.assertEqual(torch.complex128.to_real(), torch.double)
self.assertEqual(torch.complex64.to_real(), torch.float32)


@ -1699,7 +1699,7 @@ def _check(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(RuntimeError, cond, message)
_check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type
def _check_is_size(i, message=None, *, max=None):
@ -1748,7 +1748,7 @@ def _check_index(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(IndexError, cond, message)
_check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type
def _check_value(cond, message=None): # noqa: F811
@ -1766,7 +1766,7 @@ def _check_value(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(ValueError, cond, message)
_check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type
def _check_type(cond, message=None): # noqa: F811
@ -1784,7 +1784,7 @@ def _check_type(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(TypeError, cond, message)
_check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type
def _check_not_implemented(cond, message=None): # noqa: F811
@ -1802,7 +1802,12 @@ def _check_not_implemented(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(NotImplementedError, cond, message)
_check_with(
NotImplementedError,
cond,
# pyrefly: ignore # bad-argument-type
message,
)
def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
@ -2612,7 +2617,7 @@ def compile(
def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
if model is None:
raise RuntimeError("Model can't be None")
return compile(
return compile( # pyrefly: ignore # no-matching-overload
model,
fullgraph=fullgraph,
dynamic=dynamic,


@ -101,7 +101,7 @@ def custom_op(
lib, ns, function_schema, name, ophandle, _private_access=True
)
result.__name__ = func.__name__
result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment
result.__module__ = func.__module__
result.__doc__ = func.__doc__


@ -154,7 +154,7 @@ def make_crossref_functionalize(
maybe_detach, (f_args, f_kwargs)
)
with fake_mode:
f_r = op(*f_args, **f_kwargs)
f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec
r = op._op_dk(final_key, *args, **kwargs)
def desc():


@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str:
# If the module is actually a torchbind module, then we should short circuit
if module_name == "torch._classes":
return obj.qualified_name
return obj.qualified_name # pyrefly: ignore # missing-attribute
# The Python docs are very clear that `__module__` can be None, but I can't
# figure out when it actually would be.
@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
)
return prop
return prop # pyrefly: ignore # bad-return
fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
return fn
@ -844,6 +844,7 @@ def ignore(drop=False, **kwargs):
# @torch.jit.ignore
# def fn(...):
fn = drop
# pyrefly: ignore # missing-attribute
fn._torchscript_modifier = FunctionModifiers.IGNORE
return fn
@ -1250,7 +1251,10 @@ def _get_named_tuple_properties(
obj_annotations = inspect.get_annotations(obj)
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
obj_annotations = inspect.get_annotations(obj.__base__)
obj_annotations = inspect.get_annotations(
# pyrefly: ignore # bad-argument-type
obj.__base__
)
annotations = []
for field in obj._fields:
@ -1439,7 +1443,9 @@ def container_checker(obj, target_type) -> bool:
return False
return True
elif origin_type is Union or issubclass(
origin_type, BuiltinUnionType
# pyrefly: ignore # bad-argument-type
origin_type,
BuiltinUnionType,
): # also handles Optional
if obj is None: # check before recursion because None is always fine
return True


@ -63,8 +63,10 @@ class AsyncClosureHandler(ClosureHandler):
self._closure_exception.put(e)
return
self._closure_event_loop = threading.Thread(target=event_loop)
self._closure_event_loop.start()
self._closure_event_loop = threading.Thread(
target=event_loop
) # pyrefly: ignore # bad-assignment
self._closure_event_loop.start() # pyrefly: ignore # missing-attribute
def run(self, closure):
with self._closure_lock:


@ -301,7 +301,7 @@ class LOBPCGAutogradFunction(torch.autograd.Function):
return D, U
@staticmethod
def backward(ctx, D_grad, U_grad):
def backward(ctx, D_grad, U_grad): # pyrefly: ignore # bad-override
A_grad = B_grad = None
grads = [None] * 14
@ -1048,7 +1048,11 @@ class LOBPCG:
else:
E[(torch.where(E < t))[0]] = t
return torch.matmul(U * d_col.mT, Z * E**-0.5)
return torch.matmul(
U * d_col.mT,
# pyrefly: ignore # unsupported-operation
Z * E**-0.5,
)
def _get_ortho(self, U, V):
"""Return B-orthonormal U with columns are B-orthogonal to V.


@ -803,7 +803,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
# Logic replicated from aten/src/ATen/native/MathBitsFallback.h
is_write = None
for a in self._schema.arguments:
for a in self._schema.arguments: # pyrefly: ignore # bad-assignment
if a.alias_info is None:
continue
if is_write is None:
@ -885,7 +885,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
return self._op_dk(dk, *args, **kwargs)
else:
return NotImplemented
return NotImplemented # pyrefly: ignore # bad-return
# Remove a dispatch key from the dispatch cache. This will force it to get
# recomputed the next time. Does nothing
@ -990,9 +990,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
r = self.py_kernels.get(final_key, final_key)
if cache_result:
self._dispatch_cache[key] = r
self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation
add_cached_op(self)
return r
return r # pyrefly: ignore # bad-return
def name(self):
return self._name
@ -1122,7 +1122,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
)
assert isinstance(handler, Callable) # type: ignore[arg-type]
return handler(*args, **kwargs)
return handler(*args, **kwargs) # pyrefly: ignore # bad-return
def _must_dispatch_in_python(args, kwargs):
@ -1251,6 +1251,7 @@ class OpOverloadPacket(Generic[_P, _T]):
# the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
# intercept it here and call TorchBindOpverload instead.
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
# pyrefly: ignore # bad-argument-type
return _call_overload_packet_from_python(self, *args, **kwargs)
return self._op(*args, **kwargs)


@ -314,6 +314,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
# pyrefly: ignore # bad-argument-type
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function


@ -145,7 +145,7 @@ class StrobelightCompileTimeProfiler:
async_stack_max_len=cls.max_stack_length,
run_user_name="pt2-profiler/"
+ os.environ.get("USER", os.environ.get("USERNAME", "")),
sample_tags={cls.identifier},
sample_tags={cls.identifier}, # pyrefly: ignore # bad-argument-type
)
@classmethod


@ -756,7 +756,10 @@ class Tensor(torch._C.TensorBase):
"post accumulate grad hooks cannot be registered on non-leaf tensors"
)
if self._post_accumulate_grad_hooks is None:
self._post_accumulate_grad_hooks: dict[Any, Any] = OrderedDict()
self._post_accumulate_grad_hooks: dict[Any, Any] = (
# pyrefly: ignore # bad-assignment
OrderedDict()
)
from torch.utils.hooks import RemovableHandle
@ -1056,7 +1059,12 @@ class Tensor(torch._C.TensorBase):
if isinstance(split_size, (int, torch.SymInt)):
return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined]
else:
return torch._VF.split_with_sizes(self, split_size, dim)
return torch._VF.split_with_sizes(
self,
# pyrefly: ignore # bad-argument-type
split_size,
dim,
)
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
r"""Returns the unique elements of the input tensor.
@ -1101,6 +1109,7 @@ class Tensor(torch._C.TensorBase):
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
# pyrefly: ignore # no-matching-overload
return _C._VariableFunctions.rsub(self, other)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
@ -1126,7 +1135,7 @@ class Tensor(torch._C.TensorBase):
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
return torch.remainder(other, self)
return torch.remainder(other, self) # pyrefly: ignore # no-matching-overload
def __format__(self, format_spec):
if has_torch_function_unary(self):
@ -1139,7 +1148,7 @@ class Tensor(torch._C.TensorBase):
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
return torch.pow(other, self)
return torch.pow(other, self) # pyrefly: ignore # no-matching-overload
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
@ -1155,12 +1164,14 @@ class Tensor(torch._C.TensorBase):
def __rlshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
# pyrefly: ignore # no-matching-overload
return torch.bitwise_left_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rrshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
# pyrefly: ignore # no-matching-overload
return torch.bitwise_right_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
@ -1335,7 +1346,7 @@ class Tensor(torch._C.TensorBase):
return self._typed_storage()._get_legacy_storage_class()
def refine_names(self, *names):
def refine_names(self, *names): # pyrefly: ignore # bad-override
r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
Refining is a special case of renaming that "lifts" unnamed dimensions.
@ -1379,7 +1390,7 @@ class Tensor(torch._C.TensorBase):
names = resolve_ellipsis(names, self.names, "refine_names")
return super().refine_names(names)
def align_to(self, *names):
def align_to(self, *names): # pyrefly: ignore # bad-override
r"""Permutes the dimensions of the :attr:`self` tensor to match the order
specified in :attr:`names`, adding size-one dims for any new names.


@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit):
if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
yield buf_and_size[0]
buf_and_size = buf_dict[t] = [[], 0]
buf_and_size[0].append(tensor)
buf_and_size[1] += size
buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute
buf_and_size[1] += size # pyrefly: ignore # unsupported-operation
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf
@ -744,14 +744,17 @@ class ExceptionWrapper:
if exc_info is None:
exc_info = sys.exc_info()
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
self.exc_msg = "".join(
# pyrefly: ignore # no-matching-overload
traceback.format_exception(*exc_info)
)
self.where = where
def reraise(self):
r"""Reraises the wrapped exception in the current thread"""
# Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback.
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute
if self.exc_type == KeyError:
# KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python
@ -760,9 +763,13 @@ class ExceptionWrapper:
elif getattr(self.exc_type, "message", None):
# Some exceptions have first argument as non-str but explicitly
# have message field
raise self.exc_type(message=msg)
# pyrefly: ignore # not-callable
raise self.exc_type(
# pyrefly: ignore # unexpected-keyword
message=msg
)
try:
exception = self.exc_type(msg)
exception = self.exc_type(msg) # pyrefly: ignore # not-callable
except Exception:
# If the exception takes multiple arguments or otherwise can't
# be constructed, don't try to instantiate since we don't know how to
@ -1014,12 +1021,12 @@ class _LazySeedTracker:
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]


@ -84,7 +84,11 @@ def compile_time_strobelight_meta(
) -> Callable[_P, _T]:
@functools.wraps(function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
if "skip" in kwargs and isinstance(
# pyrefly: ignore # unsupported-operation
skip := kwargs["skip"],
int,
):
kwargs["skip"] = skip + 1
# This is not needed but we have it here to avoid having profile_compile_time
@ -327,7 +331,10 @@ def deprecated():
# public deprecated alias
alias = typing_extensions.deprecated(
warning_msg, category=UserWarning, stacklevel=1
# pyrefly: ignore # bad-argument-type
warning_msg,
category=UserWarning,
stacklevel=1,
)(func)
alias.__name__ = public_name


@ -464,7 +464,11 @@ def _cast(value, device_type: str, dtype: _dtype):
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
elif HAS_NUMPY and isinstance(
value,
# pyrefly: ignore # missing-attribute
np.ndarray,
):
return value
elif isinstance(value, collections.abc.Mapping):
return {
@ -521,18 +525,18 @@ def custom_fwd(
args[0]._dtype = torch.get_autocast_dtype(device_type)
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
return fwd(*args, **kwargs)
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
else:
autocast_context = torch.is_autocast_enabled(device_type)
args[0]._fwd_used_autocast = False
if autocast_context:
with autocast(device_type=device_type, enabled=False):
return fwd(
return fwd( # pyrefly: ignore # not-callable
*_cast(args, device_type, cast_inputs),
**_cast(kwargs, device_type, cast_inputs),
)
else:
return fwd(*args, **kwargs)
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
return decorate_fwd
@ -567,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str):
enabled=args[0]._fwd_used_autocast,
dtype=args[0]._dtype,
):
return bwd(*args, **kwargs)
return bwd(*args, **kwargs) # pyrefly: ignore # not-callable
return decorate_bwd


@ -1,2 +1,3 @@
# pyrefly: ignore # deprecated
from .autocast_mode import autocast
from .grad_scaler import GradScaler


@ -1784,7 +1784,9 @@ def norm( # noqa: F811
if isinstance(p, str):
if p == "fro" and (
dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
dim is None
or isinstance(dim, (int, torch.SymInt))
or len(dim) <= 2 # pyrefly: ignore # bad-argument-type
):
if out is None:
return torch.linalg.vector_norm(
@ -1950,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
)
if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape])
shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type
else:
for dim in shape:
torch._check_type(


@ -421,7 +421,7 @@ def set_dir(d: Union[str, os.PathLike]) -> None:
d (str): path to a local folder to save downloaded models & weights.
"""
global _hub_dir
_hub_dir = os.path.expanduser(d)
_hub_dir = os.path.expanduser(d) # pyrefly: ignore # no-matching-overload
def list(


@ -242,6 +242,7 @@ class Library:
if dispatch_key == "":
dispatch_key = self.dispatch_key
# pyrefly: ignore # bad-argument-type
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
if isinstance(op_name, str):
@ -643,6 +644,7 @@ def impl(
>>> y2 = torch.sin(x) + 1
>>> assert torch.allclose(y1, y2)
"""
# pyrefly: ignore # no-matching-overload
return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
@ -829,6 +831,7 @@ def register_kernel(
if device_types is None:
device_types = "CompositeExplicitAutograd"
# pyrefly: ignore # no-matching-overload
return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


@ -1,7 +1,7 @@
from torch._C import ( # type: ignore[attr-defined]
from torch._C import ( # type: ignore[attr-defined] # pyrefly: ignore # missing-module-attribute
_add_docstr,
_linalg,
_LinAlgError as LinAlgError,
_LinAlgError as LinAlgError, # pyrefly: ignore # missing-module-attribute
)


@ -303,7 +303,7 @@ class StreamContext:
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
self.idx = -1
self.idx = -1 # pyrefly: ignore # bad-assignment
self.src_prev_stream = (
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)


@ -119,6 +119,7 @@ class ProcessContext:
"""Attempt to join all processes with a shared timeout."""
end = time.monotonic() + timeout
for process in self.processes:
# pyrefly: ignore # no-matching-overload
time_to_wait = max(0, end - time.monotonic())
process.join(time_to_wait)
@ -274,7 +275,7 @@ def start_processes(
tf.close()
os.unlink(tf.name)
process = mp.Process(
process = mp.Process( # pyrefly: ignore # missing-attribute
target=_wrap,
args=(fn, i, args, tf.name),
daemon=daemon,


@ -406,7 +406,7 @@ class ViewBufferFromNested(torch.autograd.Function):
# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
@staticmethod
def forward(
def forward( # pyrefly: ignore # bad-override
ctx,
values: torch.Tensor,
offsets: torch.Tensor,


@ -46,11 +46,14 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
if canonicalize:
dim = canonicalize_dims(ndim, dim)
assert dim >= 0 and dim < ndim
assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
# For other dims, subtract 1 to convert to inner space.
return ragged_dim - 1 if dim == 0 else dim - 1
return (
# pyrefly: ignore # unsupported-operation
ragged_dim - 1 if dim == 0 else dim - 1
)
def _wrap_jagged_dim(
@ -2005,6 +2008,7 @@ def index_put_(func, *args, **kwargs):
else:
lengths = inp.lengths()
torch._assert_async(
# pyrefly: ignore # no-matching-overload
torch.all(indices[inp._ragged_idx] < lengths),
"Some indices in the ragged dimension are out of bounds!",
)


@ -134,7 +134,8 @@ def _raise_if_logical_cpu_indices_invalid(*, logical_cpu_indices: set[int]) -> N
def _bind_current_thread_to_logical_cpus(*, logical_cpu_indices: set[int]) -> None:
# 0 represents the current thread
os.sched_setaffinity(0, logical_cpu_indices)
# pyrefly: ignore # missing-attribute
os.sched_setaffinity(0, logical_cpu_indices) # type: ignore[attr-defined]
def _get_logical_cpus_to_bind_to(
@ -544,4 +545,5 @@ def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]:
def _get_allowed_cpu_indices_for_current_thread() -> set[int]:
# 0 denotes current thread
return os.sched_getaffinity(0)
# pyrefly: ignore # missing-attribute
return os.sched_getaffinity(0) # type:ignore[attr-defined]


@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
# pyrefly: ignore # missing-module-attribute
from pickle import ( # type: ignore[attr-defined]
_compat_pickle,
_extension_registry,


@ -2,10 +2,12 @@
import importlib
import logging
from abc import ABC, abstractmethod
# pyrefly: ignore # missing-module-attribute
from pickle import ( # type: ignore[attr-defined]
_getattribute,
_Pickler,
whichmodule as _pickle_whichmodule,
whichmodule as _pickle_whichmodule, # pyrefly: ignore # missing-module-attribute
)
from types import ModuleType
from typing import Any, Optional


@ -219,7 +219,7 @@ class PackageExporter:
torch._C._log_api_usage_once("torch.package.PackageExporter")
self.debug = debug
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
f = os.fspath(f) # pyrefly: ignore # no-matching-overload
self.buffer: Optional[IO[bytes]] = None
else: # is a byte buffer
self.buffer = f
@ -652,6 +652,7 @@ class PackageExporter:
memo: defaultdict[int, str] = defaultdict(None)
memo_count = 0
# pickletools.dis(data_value)
# pyrefly: ignore # bad-assignment
for opcode, arg, _pos in pickletools.genops(data_value):
if pickle_protocol == 4:
if (


@ -108,6 +108,7 @@ class PackageImporter(Importer):
self.filename = "<pytorch_file_reader>"
self.zip_reader = file_or_buffer
elif isinstance(file_or_buffer, (os.PathLike, str)):
# pyrefly: ignore # no-matching-overload
self.filename = os.fspath(file_or_buffer)
if not os.path.isdir(self.filename):
self.zip_reader = torch._C.PyTorchFileReader(self.filename)


@ -27,5 +27,5 @@ from torch.ao.quantization.qconfig import (
QConfig,
qconfig_equals,
QConfigAny,
QConfigDynamic,
QConfigDynamic, # pyrefly: ignore # deprecated
)


@ -774,7 +774,10 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
super().__init__(
# pyrefly: ignore # no-matching-overload
torch._C.PyTorchFileReader(name_or_buffer)
)
class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
@ -787,9 +790,10 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
# PyTorchFileWriter only supports ascii filename.
# For filenames with non-ascii characters, we rely on Python
# for writing out the file.
# pyrefly: ignore # bad-assignment
self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(
torch._C.PyTorchFileWriter(
torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload
self.file_stream, get_crc32_options(), _get_storage_alignment()
)
)
@ -966,7 +970,7 @@ def save(
_check_save_filelike(f)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
f = os.fspath(f) # pyrefly: ignore # no-matching-overload
if _use_new_zipfile_serialization:
with _open_zipfile_writer(f) as opened_zipfile:
@ -1520,7 +1524,10 @@ def load(
else:
shared = False
overall_storage = torch.UntypedStorage.from_file(
os.fspath(f), shared, size
# pyrefly: ignore # no-matching-overload
os.fspath(f),
shared,
size,
)
if weights_only:
try:


@ -326,7 +326,7 @@ def gaussian(
requires_grad=requires_grad,
)
return torch.exp(-(k**2))
return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation
@_add_docstr(
@ -397,11 +397,17 @@ def kaiser(
)
# Avoid NaNs by casting `beta` to the appropriate dtype.
# pyrefly: ignore # bad-assignment
beta = torch.tensor(beta, dtype=dtype, device=device)
start = -beta
constant = 2.0 * beta / (M if not sym else M - 1)
end = torch.minimum(beta, start + (M - 1) * constant)
end = torch.minimum(
# pyrefly: ignore # bad-argument-type
beta,
# pyrefly: ignore # bad-argument-type
start + (M - 1) * constant,
)
k = torch.linspace(
start=start,
@ -413,7 +419,10 @@ def kaiser(
requires_grad=requires_grad,
)
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(beta)
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
# pyrefly: ignore # bad-argument-type
beta
)
@_add_docstr(


@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
def _isint(x):
if HAS_NUMPY:
return isinstance(x, (int, np.integer))
return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute
else:
return isinstance(x, int)


@ -247,7 +247,9 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]:
}
def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
def get_device_properties(
device: Optional[_device_t] = None,
) -> _XpuDeviceProperties: # pyrefly: ignore # not-a-type
r"""Get the properties of a device.
Args:
@ -315,7 +317,7 @@ class StreamContext:
self.stream = stream
self.idx = _get_device_index(None, True)
if self.idx is None:
self.idx = -1
self.idx = -1 # pyrefly: ignore # bad-assignment
def __enter__(self):
cur_stream = self.stream


@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase):
"""
if stream is None:
stream = torch.xpu.current_stream()
super().record(stream)
super().record(stream) # pyrefly: ignore # bad-argument-type
def wait(self, stream=None) -> None:
r"""Make all future work submitted to the given stream wait for this event.