mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add initial suppressions for pyrefly (#164177)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: `python3 scripts/lintrunner.py` `pyrefly check` --- Pyrefly check before: https://gist.github.com/maggiemoss/3a0aa0b6cdda0e449cd5743d5fce2c60 After: ``` INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml` INFO 0 errors (1,063 ignored) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164177 Approved by: https://github.com/Lucaskabela
This commit is contained in:
committed by
PyTorch MergeBot
parent
6b7970192f
commit
5f18f240de
@ -43,13 +43,17 @@ project-excludes = [
|
||||
"torch/profiler/**",
|
||||
"torch/_prims_common/**",
|
||||
"torch/backends/**",
|
||||
"torch/testing/**",
|
||||
# "torch/testing/**",
|
||||
"torch/_C/**",
|
||||
"torch/sparse/**",
|
||||
"torch/_library/**",
|
||||
"torch/_prims/**",
|
||||
"torch/_decomp/**",
|
||||
"torch/_meta_registrations.py",
|
||||
# formatting issues
|
||||
"torch/linalg/__init__.py",
|
||||
"torch/package/importer.py",
|
||||
"torch/package/_package_pickler.py",
|
||||
# ====
|
||||
"benchmarks/instruction_counts/main.py",
|
||||
"benchmarks/instruction_counts/definitions/setup.py",
|
||||
|
@ -58,17 +58,20 @@ class TestBundledInputs(TestCase):
|
||||
# Make sure the model only grew a little bit,
|
||||
# despite having nominally large bundled inputs.
|
||||
augmented_size = model_size(sm)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.assertLess(augmented_size, original_size + (1 << 12))
|
||||
|
||||
loaded = save_and_load(sm)
|
||||
inflated = loaded.get_all_bundled_inputs()
|
||||
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
|
||||
self.assertEqual(len(inflated), len(samples))
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
|
||||
|
||||
for idx, inp in enumerate(inflated):
|
||||
self.assertIsInstance(inp, tuple)
|
||||
self.assertIsInstance(inp, tuple) # pyrefly: ignore # missing-attribute
|
||||
self.assertEqual(len(inp), 1)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.assertIsInstance(inp[0], torch.Tensor)
|
||||
if idx != 5:
|
||||
# Strides might be important for benchmarking.
|
||||
@ -136,6 +139,7 @@ class TestBundledInputs(TestCase):
|
||||
loaded = save_and_load(sm)
|
||||
inflated = loaded.get_all_bundled_inputs()
|
||||
self.assertEqual(inflated, samples)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.assertTrue(loaded(*inflated[0]) == "first 1")
|
||||
|
||||
def test_multiple_methods_with_inputs(self):
|
||||
@ -182,6 +186,7 @@ class TestBundledInputs(TestCase):
|
||||
self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())
|
||||
|
||||
# Check running and size helpers
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
|
||||
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
|
||||
|
||||
@ -414,6 +419,7 @@ class TestBundledInputs(TestCase):
|
||||
)
|
||||
augmented_size = model_size(sm)
|
||||
# assert the size has not increased more than 8KB
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.assertLess(augmented_size, original_size + (1 << 13))
|
||||
|
||||
loaded = save_and_load(sm)
|
||||
|
@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
|
||||
def test_all(self, device, dtype):
|
||||
# issue: https://github.com/pytorch/pytorch/issues/120875
|
||||
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
|
||||
self.assertTrue(torch.all(x))
|
||||
self.assertTrue(torch.all(x)) # pyrefly: ignore # missing-attribute
|
||||
|
||||
@dtypes(*complex_types())
|
||||
def test_any(self, device, dtype):
|
||||
@ -56,7 +56,7 @@ class TestComplexTensor(TestCase):
|
||||
x = torch.tensor(
|
||||
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
|
||||
)
|
||||
self.assertFalse(torch.any(x))
|
||||
self.assertFalse(torch.any(x)) # pyrefly: ignore # missing-attribute
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*complex_types())
|
||||
|
@ -142,6 +142,7 @@ class TestTypeHints(TestCase):
|
||||
]
|
||||
)
|
||||
if result != 0:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
self.fail(f"mypy failed:\n{stderr}\n{stdout}")
|
||||
|
||||
|
||||
|
@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/124868
|
||||
# If reference count is leaked this would be a set of 10 elements
|
||||
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
|
||||
self.assertLess(len(ref_cnt), 3)
|
||||
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
|
||||
|
||||
self.assertEqual(torch.float64.to_complex(), torch.complex128)
|
||||
self.assertEqual(torch.float32.to_complex(), torch.complex64)
|
||||
@ -135,7 +135,7 @@ class TestDTypeInfo(TestCase):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/124868
|
||||
# If reference count is leaked this would be a set of 10 elements
|
||||
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
|
||||
self.assertLess(len(ref_cnt), 3)
|
||||
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
|
||||
|
||||
self.assertEqual(torch.complex128.to_real(), torch.double)
|
||||
self.assertEqual(torch.complex64.to_real(), torch.float32)
|
||||
|
@ -1699,7 +1699,7 @@ def _check(cond, message=None): # noqa: F811
|
||||
an object that has a ``__str__()`` method to be used as the error
|
||||
message. Default: ``None``
|
||||
"""
|
||||
_check_with(RuntimeError, cond, message)
|
||||
_check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type
|
||||
|
||||
|
||||
def _check_is_size(i, message=None, *, max=None):
|
||||
@ -1748,7 +1748,7 @@ def _check_index(cond, message=None): # noqa: F811
|
||||
an object that has a ``__str__()`` method to be used as the error
|
||||
message. Default: ``None``
|
||||
"""
|
||||
_check_with(IndexError, cond, message)
|
||||
_check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type
|
||||
|
||||
|
||||
def _check_value(cond, message=None): # noqa: F811
|
||||
@ -1766,7 +1766,7 @@ def _check_value(cond, message=None): # noqa: F811
|
||||
an object that has a ``__str__()`` method to be used as the error
|
||||
message. Default: ``None``
|
||||
"""
|
||||
_check_with(ValueError, cond, message)
|
||||
_check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type
|
||||
|
||||
|
||||
def _check_type(cond, message=None): # noqa: F811
|
||||
@ -1784,7 +1784,7 @@ def _check_type(cond, message=None): # noqa: F811
|
||||
an object that has a ``__str__()`` method to be used as the error
|
||||
message. Default: ``None``
|
||||
"""
|
||||
_check_with(TypeError, cond, message)
|
||||
_check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type
|
||||
|
||||
|
||||
def _check_not_implemented(cond, message=None): # noqa: F811
|
||||
@ -1802,7 +1802,12 @@ def _check_not_implemented(cond, message=None): # noqa: F811
|
||||
an object that has a ``__str__()`` method to be used as the error
|
||||
message. Default: ``None``
|
||||
"""
|
||||
_check_with(NotImplementedError, cond, message)
|
||||
_check_with(
|
||||
NotImplementedError,
|
||||
cond,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
message,
|
||||
)
|
||||
|
||||
|
||||
def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
|
||||
@ -2612,7 +2617,7 @@ def compile(
|
||||
def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
|
||||
if model is None:
|
||||
raise RuntimeError("Model can't be None")
|
||||
return compile(
|
||||
return compile( # pyrefly: ignore # no-matching-overload
|
||||
model,
|
||||
fullgraph=fullgraph,
|
||||
dynamic=dynamic,
|
||||
|
@ -101,7 +101,7 @@ def custom_op(
|
||||
lib, ns, function_schema, name, ophandle, _private_access=True
|
||||
)
|
||||
|
||||
result.__name__ = func.__name__
|
||||
result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment
|
||||
result.__module__ = func.__module__
|
||||
result.__doc__ = func.__doc__
|
||||
|
||||
|
@ -154,7 +154,7 @@ def make_crossref_functionalize(
|
||||
maybe_detach, (f_args, f_kwargs)
|
||||
)
|
||||
with fake_mode:
|
||||
f_r = op(*f_args, **f_kwargs)
|
||||
f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec
|
||||
r = op._op_dk(final_key, *args, **kwargs)
|
||||
|
||||
def desc():
|
||||
|
@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str:
|
||||
|
||||
# If the module is actually a torchbind module, then we should short circuit
|
||||
if module_name == "torch._classes":
|
||||
return obj.qualified_name
|
||||
return obj.qualified_name # pyrefly: ignore # missing-attribute
|
||||
|
||||
# The Python docs are very clear that `__module__` can be None, but I can't
|
||||
# figure out when it actually would be.
|
||||
@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
|
||||
)
|
||||
|
||||
return prop
|
||||
return prop # pyrefly: ignore # bad-return
|
||||
|
||||
fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
|
||||
return fn
|
||||
@ -844,6 +844,7 @@ def ignore(drop=False, **kwargs):
|
||||
# @torch.jit.ignore
|
||||
# def fn(...):
|
||||
fn = drop
|
||||
# pyrefly: ignore # missing-attribute
|
||||
fn._torchscript_modifier = FunctionModifiers.IGNORE
|
||||
return fn
|
||||
|
||||
@ -1250,7 +1251,10 @@ def _get_named_tuple_properties(
|
||||
|
||||
obj_annotations = inspect.get_annotations(obj)
|
||||
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
|
||||
obj_annotations = inspect.get_annotations(obj.__base__)
|
||||
obj_annotations = inspect.get_annotations(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
obj.__base__
|
||||
)
|
||||
|
||||
annotations = []
|
||||
for field in obj._fields:
|
||||
@ -1439,7 +1443,9 @@ def container_checker(obj, target_type) -> bool:
|
||||
return False
|
||||
return True
|
||||
elif origin_type is Union or issubclass(
|
||||
origin_type, BuiltinUnionType
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
origin_type,
|
||||
BuiltinUnionType,
|
||||
): # also handles Optional
|
||||
if obj is None: # check before recursion because None is always fine
|
||||
return True
|
||||
|
@ -63,8 +63,10 @@ class AsyncClosureHandler(ClosureHandler):
|
||||
self._closure_exception.put(e)
|
||||
return
|
||||
|
||||
self._closure_event_loop = threading.Thread(target=event_loop)
|
||||
self._closure_event_loop.start()
|
||||
self._closure_event_loop = threading.Thread(
|
||||
target=event_loop
|
||||
) # pyrefly: ignore # bad-assignment
|
||||
self._closure_event_loop.start() # pyrefly: ignore # missing-attribute
|
||||
|
||||
def run(self, closure):
|
||||
with self._closure_lock:
|
||||
|
@ -301,7 +301,7 @@ class LOBPCGAutogradFunction(torch.autograd.Function):
|
||||
return D, U
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, D_grad, U_grad):
|
||||
def backward(ctx, D_grad, U_grad): # pyrefly: ignore # bad-override
|
||||
A_grad = B_grad = None
|
||||
grads = [None] * 14
|
||||
|
||||
@ -1048,7 +1048,11 @@ class LOBPCG:
|
||||
else:
|
||||
E[(torch.where(E < t))[0]] = t
|
||||
|
||||
return torch.matmul(U * d_col.mT, Z * E**-0.5)
|
||||
return torch.matmul(
|
||||
U * d_col.mT,
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
Z * E**-0.5,
|
||||
)
|
||||
|
||||
def _get_ortho(self, U, V):
|
||||
"""Return B-orthonormal U with columns are B-orthogonal to V.
|
||||
|
@ -803,7 +803,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
|
||||
|
||||
# Logic replicated from aten/src/ATen/native/MathBitsFallback.h
|
||||
is_write = None
|
||||
for a in self._schema.arguments:
|
||||
for a in self._schema.arguments: # pyrefly: ignore # bad-assignment
|
||||
if a.alias_info is None:
|
||||
continue
|
||||
if is_write is None:
|
||||
@ -885,7 +885,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
|
||||
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
|
||||
return self._op_dk(dk, *args, **kwargs)
|
||||
else:
|
||||
return NotImplemented
|
||||
return NotImplemented # pyrefly: ignore # bad-return
|
||||
|
||||
# Remove a dispatch key from the dispatch cache. This will force it to get
|
||||
# recomputed the next time. Does nothing
|
||||
@ -990,9 +990,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
|
||||
|
||||
r = self.py_kernels.get(final_key, final_key)
|
||||
if cache_result:
|
||||
self._dispatch_cache[key] = r
|
||||
self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation
|
||||
add_cached_op(self)
|
||||
return r
|
||||
return r # pyrefly: ignore # bad-return
|
||||
|
||||
def name(self):
|
||||
return self._name
|
||||
@ -1122,7 +1122,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
|
||||
)
|
||||
|
||||
assert isinstance(handler, Callable) # type: ignore[arg-type]
|
||||
return handler(*args, **kwargs)
|
||||
return handler(*args, **kwargs) # pyrefly: ignore # bad-return
|
||||
|
||||
|
||||
def _must_dispatch_in_python(args, kwargs):
|
||||
@ -1251,6 +1251,7 @@ class OpOverloadPacket(Generic[_P, _T]):
|
||||
# the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
|
||||
# intercept it here and call TorchBindOpverload instead.
|
||||
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return _call_overload_packet_from_python(self, *args, **kwargs)
|
||||
return self._op(*args, **kwargs)
|
||||
|
||||
|
@ -314,6 +314,7 @@ def strobelight(
|
||||
) -> Callable[_P, Optional[_R]]:
|
||||
@functools.wraps(work_function)
|
||||
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return profiler.profile(work_function, *args, **kwargs)
|
||||
|
||||
return wrapper_function
|
||||
|
@ -145,7 +145,7 @@ class StrobelightCompileTimeProfiler:
|
||||
async_stack_max_len=cls.max_stack_length,
|
||||
run_user_name="pt2-profiler/"
|
||||
+ os.environ.get("USER", os.environ.get("USERNAME", "")),
|
||||
sample_tags={cls.identifier},
|
||||
sample_tags={cls.identifier}, # pyrefly: ignore # bad-argument-type
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -756,7 +756,10 @@ class Tensor(torch._C.TensorBase):
|
||||
"post accumulate grad hooks cannot be registered on non-leaf tensors"
|
||||
)
|
||||
if self._post_accumulate_grad_hooks is None:
|
||||
self._post_accumulate_grad_hooks: dict[Any, Any] = OrderedDict()
|
||||
self._post_accumulate_grad_hooks: dict[Any, Any] = (
|
||||
# pyrefly: ignore # bad-assignment
|
||||
OrderedDict()
|
||||
)
|
||||
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
@ -1056,7 +1059,12 @@ class Tensor(torch._C.TensorBase):
|
||||
if isinstance(split_size, (int, torch.SymInt)):
|
||||
return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined]
|
||||
else:
|
||||
return torch._VF.split_with_sizes(self, split_size, dim)
|
||||
return torch._VF.split_with_sizes(
|
||||
self,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
split_size,
|
||||
dim,
|
||||
)
|
||||
|
||||
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
|
||||
r"""Returns the unique elements of the input tensor.
|
||||
@ -1101,6 +1109,7 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||
def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return _C._VariableFunctions.rsub(self, other)
|
||||
|
||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||
@ -1126,7 +1135,7 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||
def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
||||
return torch.remainder(other, self)
|
||||
return torch.remainder(other, self) # pyrefly: ignore # no-matching-overload
|
||||
|
||||
def __format__(self, format_spec):
|
||||
if has_torch_function_unary(self):
|
||||
@ -1139,7 +1148,7 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||
def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
||||
return torch.pow(other, self)
|
||||
return torch.pow(other, self) # pyrefly: ignore # no-matching-overload
|
||||
|
||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||
def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
|
||||
@ -1155,12 +1164,14 @@ class Tensor(torch._C.TensorBase):
|
||||
def __rlshift__(
|
||||
self, other: Union["Tensor", int, float, bool, complex]
|
||||
) -> "Tensor":
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return torch.bitwise_left_shift(other, self)
|
||||
|
||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||
def __rrshift__(
|
||||
self, other: Union["Tensor", int, float, bool, complex]
|
||||
) -> "Tensor":
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return torch.bitwise_right_shift(other, self)
|
||||
|
||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||
@ -1335,7 +1346,7 @@ class Tensor(torch._C.TensorBase):
|
||||
|
||||
return self._typed_storage()._get_legacy_storage_class()
|
||||
|
||||
def refine_names(self, *names):
|
||||
def refine_names(self, *names): # pyrefly: ignore # bad-override
|
||||
r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
|
||||
|
||||
Refining is a special case of renaming that "lifts" unnamed dimensions.
|
||||
@ -1379,7 +1390,7 @@ class Tensor(torch._C.TensorBase):
|
||||
names = resolve_ellipsis(names, self.names, "refine_names")
|
||||
return super().refine_names(names)
|
||||
|
||||
def align_to(self, *names):
|
||||
def align_to(self, *names): # pyrefly: ignore # bad-override
|
||||
r"""Permutes the dimensions of the :attr:`self` tensor to match the order
|
||||
specified in :attr:`names`, adding size-one dims for any new names.
|
||||
|
||||
|
@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit):
|
||||
if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
|
||||
yield buf_and_size[0]
|
||||
buf_and_size = buf_dict[t] = [[], 0]
|
||||
buf_and_size[0].append(tensor)
|
||||
buf_and_size[1] += size
|
||||
buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute
|
||||
buf_and_size[1] += size # pyrefly: ignore # unsupported-operation
|
||||
for buf, _ in buf_dict.values():
|
||||
if len(buf) > 0:
|
||||
yield buf
|
||||
@ -744,14 +744,17 @@ class ExceptionWrapper:
|
||||
if exc_info is None:
|
||||
exc_info = sys.exc_info()
|
||||
self.exc_type = exc_info[0]
|
||||
self.exc_msg = "".join(traceback.format_exception(*exc_info))
|
||||
self.exc_msg = "".join(
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
traceback.format_exception(*exc_info)
|
||||
)
|
||||
self.where = where
|
||||
|
||||
def reraise(self):
|
||||
r"""Reraises the wrapped exception in the current thread"""
|
||||
# Format a message such as: "Caught ValueError in DataLoader worker
|
||||
# process 2. Original Traceback:", followed by the traceback.
|
||||
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
|
||||
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute
|
||||
if self.exc_type == KeyError:
|
||||
# KeyError calls repr() on its argument (usually a dict key). This
|
||||
# makes stack traces unreadable. It will not be changed in Python
|
||||
@ -760,9 +763,13 @@ class ExceptionWrapper:
|
||||
elif getattr(self.exc_type, "message", None):
|
||||
# Some exceptions have first argument as non-str but explicitly
|
||||
# have message field
|
||||
raise self.exc_type(message=msg)
|
||||
# pyrefly: ignore # not-callable
|
||||
raise self.exc_type(
|
||||
# pyrefly: ignore # unexpected-keyword
|
||||
message=msg
|
||||
)
|
||||
try:
|
||||
exception = self.exc_type(msg)
|
||||
exception = self.exc_type(msg) # pyrefly: ignore # not-callable
|
||||
except Exception:
|
||||
# If the exception takes multiple arguments or otherwise can't
|
||||
# be constructed, don't try to instantiate since we don't know how to
|
||||
@ -1014,12 +1021,12 @@ class _LazySeedTracker:
|
||||
self.call_order = []
|
||||
|
||||
def queue_seed_all(self, cb, traceback):
|
||||
self.manual_seed_all_cb = (cb, traceback)
|
||||
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
|
||||
# update seed_all to be latest
|
||||
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
|
||||
|
||||
def queue_seed(self, cb, traceback):
|
||||
self.manual_seed_cb = (cb, traceback)
|
||||
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
|
||||
# update seed to be latest
|
||||
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
||||
|
||||
|
@ -84,7 +84,11 @@ def compile_time_strobelight_meta(
|
||||
) -> Callable[_P, _T]:
|
||||
@functools.wraps(function)
|
||||
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
|
||||
if "skip" in kwargs and isinstance(
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
skip := kwargs["skip"],
|
||||
int,
|
||||
):
|
||||
kwargs["skip"] = skip + 1
|
||||
|
||||
# This is not needed but we have it here to avoid having profile_compile_time
|
||||
@ -327,7 +331,10 @@ def deprecated():
|
||||
|
||||
# public deprecated alias
|
||||
alias = typing_extensions.deprecated(
|
||||
warning_msg, category=UserWarning, stacklevel=1
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
warning_msg,
|
||||
category=UserWarning,
|
||||
stacklevel=1,
|
||||
)(func)
|
||||
|
||||
alias.__name__ = public_name
|
||||
|
@ -464,7 +464,11 @@ def _cast(value, device_type: str, dtype: _dtype):
|
||||
return value.to(dtype) if is_eligible else value
|
||||
elif isinstance(value, (str, bytes)):
|
||||
return value
|
||||
elif HAS_NUMPY and isinstance(value, np.ndarray):
|
||||
elif HAS_NUMPY and isinstance(
|
||||
value,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
np.ndarray,
|
||||
):
|
||||
return value
|
||||
elif isinstance(value, collections.abc.Mapping):
|
||||
return {
|
||||
@ -521,18 +525,18 @@ def custom_fwd(
|
||||
args[0]._dtype = torch.get_autocast_dtype(device_type)
|
||||
if cast_inputs is None:
|
||||
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
|
||||
return fwd(*args, **kwargs)
|
||||
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
|
||||
else:
|
||||
autocast_context = torch.is_autocast_enabled(device_type)
|
||||
args[0]._fwd_used_autocast = False
|
||||
if autocast_context:
|
||||
with autocast(device_type=device_type, enabled=False):
|
||||
return fwd(
|
||||
return fwd( # pyrefly: ignore # not-callable
|
||||
*_cast(args, device_type, cast_inputs),
|
||||
**_cast(kwargs, device_type, cast_inputs),
|
||||
)
|
||||
else:
|
||||
return fwd(*args, **kwargs)
|
||||
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
|
||||
|
||||
return decorate_fwd
|
||||
|
||||
@ -567,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str):
|
||||
enabled=args[0]._fwd_used_autocast,
|
||||
dtype=args[0]._dtype,
|
||||
):
|
||||
return bwd(*args, **kwargs)
|
||||
return bwd(*args, **kwargs) # pyrefly: ignore # not-callable
|
||||
|
||||
return decorate_bwd
|
||||
|
@ -1,2 +1,3 @@
|
||||
# pyrefly: ignore # deprecated
|
||||
from .autocast_mode import autocast
|
||||
from .grad_scaler import GradScaler
|
||||
|
@ -1784,7 +1784,9 @@ def norm( # noqa: F811
|
||||
|
||||
if isinstance(p, str):
|
||||
if p == "fro" and (
|
||||
dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
|
||||
dim is None
|
||||
or isinstance(dim, (int, torch.SymInt))
|
||||
or len(dim) <= 2 # pyrefly: ignore # bad-argument-type
|
||||
):
|
||||
if out is None:
|
||||
return torch.linalg.vector_norm(
|
||||
@ -1950,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
|
||||
)
|
||||
|
||||
if isinstance(shape, (int, torch.SymInt)):
|
||||
shape = torch.Size([shape])
|
||||
shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type
|
||||
else:
|
||||
for dim in shape:
|
||||
torch._check_type(
|
||||
|
@ -421,7 +421,7 @@ def set_dir(d: Union[str, os.PathLike]) -> None:
|
||||
d (str): path to a local folder to save downloaded models & weights.
|
||||
"""
|
||||
global _hub_dir
|
||||
_hub_dir = os.path.expanduser(d)
|
||||
_hub_dir = os.path.expanduser(d) # pyrefly: ignore # no-matching-overload
|
||||
|
||||
|
||||
def list(
|
||||
|
@ -242,6 +242,7 @@ class Library:
|
||||
|
||||
if dispatch_key == "":
|
||||
dispatch_key = self.dispatch_key
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
|
||||
|
||||
if isinstance(op_name, str):
|
||||
@ -643,6 +644,7 @@ def impl(
|
||||
>>> y2 = torch.sin(x) + 1
|
||||
>>> assert torch.allclose(y1, y2)
|
||||
"""
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
|
||||
|
||||
|
||||
@ -829,6 +831,7 @@ def register_kernel(
|
||||
if device_types is None:
|
||||
device_types = "CompositeExplicitAutograd"
|
||||
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
from torch._C import ( # type: ignore[attr-defined]
|
||||
from torch._C import ( # type: ignore[attr-defined] # pyrefly: ignore # missing-module-attribute
|
||||
_add_docstr,
|
||||
_linalg,
|
||||
_LinAlgError as LinAlgError,
|
||||
_LinAlgError as LinAlgError, # pyrefly: ignore # missing-module-attribute
|
||||
)
|
||||
|
||||
|
||||
|
@ -303,7 +303,7 @@ class StreamContext:
|
||||
self.idx = _get_device_index(None, True)
|
||||
if not torch.jit.is_scripting():
|
||||
if self.idx is None:
|
||||
self.idx = -1
|
||||
self.idx = -1 # pyrefly: ignore # bad-assignment
|
||||
|
||||
self.src_prev_stream = (
|
||||
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
|
||||
|
@ -119,6 +119,7 @@ class ProcessContext:
|
||||
"""Attempt to join all processes with a shared timeout."""
|
||||
end = time.monotonic() + timeout
|
||||
for process in self.processes:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
time_to_wait = max(0, end - time.monotonic())
|
||||
process.join(time_to_wait)
|
||||
|
||||
@ -274,7 +275,7 @@ def start_processes(
|
||||
tf.close()
|
||||
os.unlink(tf.name)
|
||||
|
||||
process = mp.Process(
|
||||
process = mp.Process( # pyrefly: ignore # missing-attribute
|
||||
target=_wrap,
|
||||
args=(fn, i, args, tf.name),
|
||||
daemon=daemon,
|
||||
|
@ -406,7 +406,7 @@ class ViewBufferFromNested(torch.autograd.Function):
|
||||
# Not actually a view!
|
||||
class ViewNestedFromBuffer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
def forward( # pyrefly: ignore # bad-override
|
||||
ctx,
|
||||
values: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
|
@ -46,11 +46,14 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
|
||||
if canonicalize:
|
||||
dim = canonicalize_dims(ndim, dim)
|
||||
|
||||
assert dim >= 0 and dim < ndim
|
||||
assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation
|
||||
|
||||
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
|
||||
# For other dims, subtract 1 to convert to inner space.
|
||||
return ragged_dim - 1 if dim == 0 else dim - 1
|
||||
return (
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
ragged_dim - 1 if dim == 0 else dim - 1
|
||||
)
|
||||
|
||||
|
||||
def _wrap_jagged_dim(
|
||||
@ -2005,6 +2008,7 @@ def index_put_(func, *args, **kwargs):
|
||||
else:
|
||||
lengths = inp.lengths()
|
||||
torch._assert_async(
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
torch.all(indices[inp._ragged_idx] < lengths),
|
||||
"Some indices in the ragged dimension are out of bounds!",
|
||||
)
|
||||
|
@ -134,7 +134,8 @@ def _raise_if_logical_cpu_indices_invalid(*, logical_cpu_indices: set[int]) -> N
|
||||
|
||||
def _bind_current_thread_to_logical_cpus(*, logical_cpu_indices: set[int]) -> None:
|
||||
# 0 represents the current thread
|
||||
os.sched_setaffinity(0, logical_cpu_indices)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
os.sched_setaffinity(0, logical_cpu_indices) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _get_logical_cpus_to_bind_to(
|
||||
@ -544,4 +545,5 @@ def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]:
|
||||
|
||||
def _get_allowed_cpu_indices_for_current_thread() -> set[int]:
|
||||
# 0 denotes current thread
|
||||
return os.sched_getaffinity(0)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return os.sched_getaffinity(0) # type:ignore[attr-defined]
|
||||
|
@ -1,4 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# pyrefly: ignore # missing-module-attribute
|
||||
from pickle import ( # type: ignore[attr-defined]
|
||||
_compat_pickle,
|
||||
_extension_registry,
|
||||
|
@ -2,10 +2,12 @@
|
||||
import importlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# pyrefly: ignore # missing-module-attribute
|
||||
from pickle import ( # type: ignore[attr-defined]
|
||||
_getattribute,
|
||||
_Pickler,
|
||||
whichmodule as _pickle_whichmodule,
|
||||
whichmodule as _pickle_whichmodule, # pyrefly: ignore # missing-module-attribute
|
||||
)
|
||||
from types import ModuleType
|
||||
from typing import Any, Optional
|
||||
|
@ -219,7 +219,7 @@ class PackageExporter:
|
||||
torch._C._log_api_usage_once("torch.package.PackageExporter")
|
||||
self.debug = debug
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
f = os.fspath(f)
|
||||
f = os.fspath(f) # pyrefly: ignore # no-matching-overload
|
||||
self.buffer: Optional[IO[bytes]] = None
|
||||
else: # is a byte buffer
|
||||
self.buffer = f
|
||||
@ -652,6 +652,7 @@ class PackageExporter:
|
||||
memo: defaultdict[int, str] = defaultdict(None)
|
||||
memo_count = 0
|
||||
# pickletools.dis(data_value)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
for opcode, arg, _pos in pickletools.genops(data_value):
|
||||
if pickle_protocol == 4:
|
||||
if (
|
||||
|
@ -108,6 +108,7 @@ class PackageImporter(Importer):
|
||||
self.filename = "<pytorch_file_reader>"
|
||||
self.zip_reader = file_or_buffer
|
||||
elif isinstance(file_or_buffer, (os.PathLike, str)):
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
self.filename = os.fspath(file_or_buffer)
|
||||
if not os.path.isdir(self.filename):
|
||||
self.zip_reader = torch._C.PyTorchFileReader(self.filename)
|
||||
|
@ -27,5 +27,5 @@ from torch.ao.quantization.qconfig import (
|
||||
QConfig,
|
||||
qconfig_equals,
|
||||
QConfigAny,
|
||||
QConfigDynamic,
|
||||
QConfigDynamic, # pyrefly: ignore # deprecated
|
||||
)
|
||||
|
@ -774,7 +774,10 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
|
||||
|
||||
class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
|
||||
def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
|
||||
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
|
||||
super().__init__(
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
torch._C.PyTorchFileReader(name_or_buffer)
|
||||
)
|
||||
|
||||
|
||||
class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
|
||||
@ -787,9 +790,10 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
|
||||
# PyTorchFileWriter only supports ascii filename.
|
||||
# For filenames with non-ascii characters, we rely on Python
|
||||
# for writing out the file.
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self.file_stream = io.FileIO(self.name, mode="w")
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(
|
||||
torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload
|
||||
self.file_stream, get_crc32_options(), _get_storage_alignment()
|
||||
)
|
||||
)
|
||||
@ -966,7 +970,7 @@ def save(
|
||||
_check_save_filelike(f)
|
||||
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
f = os.fspath(f)
|
||||
f = os.fspath(f) # pyrefly: ignore # no-matching-overload
|
||||
|
||||
if _use_new_zipfile_serialization:
|
||||
with _open_zipfile_writer(f) as opened_zipfile:
|
||||
@ -1520,7 +1524,10 @@ def load(
|
||||
else:
|
||||
shared = False
|
||||
overall_storage = torch.UntypedStorage.from_file(
|
||||
os.fspath(f), shared, size
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
os.fspath(f),
|
||||
shared,
|
||||
size,
|
||||
)
|
||||
if weights_only:
|
||||
try:
|
||||
|
@ -326,7 +326,7 @@ def gaussian(
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
|
||||
return torch.exp(-(k**2))
|
||||
return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation
|
||||
|
||||
|
||||
@_add_docstr(
|
||||
@ -397,11 +397,17 @@ def kaiser(
|
||||
)
|
||||
|
||||
# Avoid NaNs by casting `beta` to the appropriate dtype.
|
||||
# pyrefly: ignore # bad-assignment
|
||||
beta = torch.tensor(beta, dtype=dtype, device=device)
|
||||
|
||||
start = -beta
|
||||
constant = 2.0 * beta / (M if not sym else M - 1)
|
||||
end = torch.minimum(beta, start + (M - 1) * constant)
|
||||
end = torch.minimum(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
beta,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
start + (M - 1) * constant,
|
||||
)
|
||||
|
||||
k = torch.linspace(
|
||||
start=start,
|
||||
@ -413,7 +419,10 @@ def kaiser(
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
|
||||
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(beta)
|
||||
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
beta
|
||||
)
|
||||
|
||||
|
||||
@_add_docstr(
|
||||
|
@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
|
||||
|
||||
def _isint(x):
|
||||
if HAS_NUMPY:
|
||||
return isinstance(x, (int, np.integer))
|
||||
return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute
|
||||
else:
|
||||
return isinstance(x, int)
|
||||
|
||||
|
@ -247,7 +247,9 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
|
||||
def get_device_properties(
|
||||
device: Optional[_device_t] = None,
|
||||
) -> _XpuDeviceProperties: # pyrefly: ignore # not-a-type
|
||||
r"""Get the properties of a device.
|
||||
|
||||
Args:
|
||||
@ -315,7 +317,7 @@ class StreamContext:
|
||||
self.stream = stream
|
||||
self.idx = _get_device_index(None, True)
|
||||
if self.idx is None:
|
||||
self.idx = -1
|
||||
self.idx = -1 # pyrefly: ignore # bad-assignment
|
||||
|
||||
def __enter__(self):
|
||||
cur_stream = self.stream
|
||||
|
@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase):
|
||||
"""
|
||||
if stream is None:
|
||||
stream = torch.xpu.current_stream()
|
||||
super().record(stream)
|
||||
super().record(stream) # pyrefly: ignore # bad-argument-type
|
||||
|
||||
def wait(self, stream=None) -> None:
|
||||
r"""Make all future work submitted to the given stream wait for this event.
|
||||
|
Reference in New Issue
Block a user