Add pyrefly suppressions (3/n) (#164588)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
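
For reference, each suppression is a standalone "# pyrefly: ignore # <error-code>" comment placed directly above the flagged line; it silences only that error code at that location. A minimal sketch of the pattern (the function and variable below are made up for illustration, not taken from this PR):

    def takes_int(x: int) -> int:
        return x + 1

    value: object = 41
    # pyrefly: ignore # bad-argument-type
    print(takes_int(value))

Without the comment, pyrefly reports bad-argument-type at the call site because an object is passed where an int is expected; with it, the error is suppressed and counted toward the "(N ignored)" total below.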
before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d

after:

0 errors (1,970 ignored)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
Author:    Maggie Moss
Date:      2025-10-03 22:02:59 +00:00
Committer: PyTorch MergeBot
Parent:    e438db2546
Commit:    f414aa8e0d

49 changed files with 244 additions and 29 deletions


@@ -404,7 +404,9 @@ class FakeTensorConverter:
         with no_dispatch():
             return FakeTensor(
                 fake_mode,
+                # pyrefly: ignore # bad-argument-type
                 make_meta_t(),
+                # pyrefly: ignore # bad-argument-type
                 device,
                 # TODO: callback might be used in recursive contexts, in
                 # which case using t is wrong! BUG!
@@ -679,6 +681,7 @@ class FakeTensor(Tensor):
     _mode_key = torch._C._TorchDispatchModeKey.FAKE

     @property
+    # pyrefly: ignore # bad-override
     def device(self) -> torch.device:
         if self.fake_mode.in_kernel_invocation:
             return torch.device("meta")
@@ -706,6 +709,7 @@ class FakeTensor(Tensor):

     # We don't support named tensors; graph break
     @property
+    # pyrefly: ignore # bad-override
     def names(self) -> list[str]:
         raise UnsupportedFakeTensorException(
             "torch.compile doesn't support named tensors"
@@ -764,6 +768,7 @@ class FakeTensor(Tensor):
             )
         else:
             device = torch.device(f"{device.type}:0")
+        # pyrefly: ignore # read-only
         self.fake_device = device
         self.fake_mode = fake_mode
         self.constant = constant
@@ -1493,6 +1498,7 @@ class FakeTensorMode(TorchDispatchMode):
                 # Do this dispatch outside the above except handler so if it
                 # generates its own exception there won't be a __context__ caused by
                 # the caching mechanism.
+                # pyrefly: ignore # bad-argument-type
                 return self._dispatch_impl(func, types, args, kwargs)

         assert state is not None
@@ -1510,22 +1516,27 @@ class FakeTensorMode(TorchDispatchMode):
                 # This represents a negative cache entry - we already saw that the
                 # output is uncachable. Compute it from first principles.
                 FakeTensorMode.cache_bypasses[entry.reason] += 1
+                # pyrefly: ignore # bad-argument-type
                 return self._dispatch_impl(func, types, args, kwargs)

             # We have a cache entry.
+            # pyrefly: ignore # bad-argument-type
             output = self._output_from_cache_entry(state, entry, key, func, args)
             FakeTensorMode.cache_hits += 1
             if self.cache_crosscheck_enabled:
                 # For debugging / testing: Validate that the output synthesized
                 # from the cache matches the output created by normal dispatch.
                 with disable_fake_tensor_cache(self):
+                    # pyrefly: ignore # bad-argument-type
                     self._crosscheck_cache_output(output, func, types, args, kwargs)
             return output

         # We don't have a cache entry.
+        # pyrefly: ignore # bad-argument-type
         output = self._dispatch_impl(func, types, args, kwargs)

         try:
+            # pyrefly: ignore # bad-argument-type
             self._validate_cache_key(func, args, kwargs)
         except _BypassDispatchCache as e:
             # We ran "extra" checks on the cache key and determined that it's no
@@ -1545,6 +1556,7 @@ class FakeTensorMode(TorchDispatchMode):
             return output

         try:
+            # pyrefly: ignore # bad-argument-type
             entry = self._make_cache_entry(state, key, func, args, kwargs, output)
         except _BypassDispatchCache as e:
             # We had trouble making the cache entry. Record the reason and mark
@@ -1587,13 +1599,16 @@ class FakeTensorMode(TorchDispatchMode):
         if state.known_symbols:
             # If there are symbols then include the epoch - this is really more
             # of a Shape env var which lives on the FakeTensorMode.
+            # pyrefly: ignore # bad-argument-type
             key_values.append(self.epoch)
         # Collect the id_hashed objects to attach a weakref finalize later
         id_hashed_objects: list[object] = []
         # Translate any FakeTensor args to metadata.
         if args:
+            # pyrefly: ignore # bad-argument-type
             self._prep_args_for_hash(key_values, args, state, id_hashed_objects)
         if kwargs:
+            # pyrefly: ignore # bad-argument-type
             self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects)

         key = _DispatchCacheKey(tuple(key_values))
@@ -1909,27 +1924,53 @@ class FakeTensorMode(TorchDispatchMode):
         if isinstance(output, tuple):
             for out_element in output:
                 self._validate_output_for_cache_entry(
-                    state, key, func, args, kwargs, out_element
+                    state,
+                    key,
+                    # pyrefly: ignore # bad-argument-type
+                    func,
+                    args,
+                    kwargs,
+                    out_element,
                 )
         else:
             self._validate_output_for_cache_entry(
-                state, key, func, args, kwargs, output
+                state,
+                key,
+                # pyrefly: ignore # bad-argument-type
+                func,
+                args,
+                kwargs,
+                output,
             )

         if isinstance(output, tuple):
             output_infos = [
                 self._get_output_info_for_cache_entry(
-                    state, key, func, args, kwargs, out_elem
+                    state,
+                    key,
+                    # pyrefly: ignore # bad-argument-type
+                    func,
+                    args,
+                    kwargs,
+                    out_elem,
                 )
                 for out_elem in output
             ]
             return _DispatchCacheValidEntry(
-                output_infos=tuple(output_infos), is_output_tuple=True
+                # pyrefly: ignore # bad-argument-type
+                output_infos=tuple(output_infos),
+                is_output_tuple=True,
             )
         else:
             output_info = self._get_output_info_for_cache_entry(
-                state, key, func, args, kwargs, output
+                state,
+                key,
+                # pyrefly: ignore # bad-argument-type
+                func,
+                args,
+                kwargs,
+                output,
             )
             return _DispatchCacheValidEntry(
                 output_infos=(output_info,), is_output_tuple=False
@@ -2472,6 +2513,7 @@ class FakeTensorMode(TorchDispatchMode):
             )
             with self, maybe_ignore_fresh_unbacked_symbols():
+                # pyrefly: ignore # index-error
                 return registered_hop_fake_fns[func](*args, **kwargs)

         self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
@@ -2625,6 +2667,7 @@ class FakeTensorMode(TorchDispatchMode):
                 # TODO: Is this really needed?
                 compute_unbacked_bindings(self.shape_env, fake_out, peek=True)

+            # pyrefly: ignore # bad-return
             return fake_out

         # Try for fastpath
@@ -2906,6 +2949,7 @@ class FakeTensorMode(TorchDispatchMode):
                     self, e, device or common_device
                 )
             else:
+                # pyrefly: ignore # bad-return
                 return e

         return tree_map(wrap, r)
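
Several of the FakeTensor hunks above suppress bad-override on the device and names properties. That code fires when a subclass redeclares an inherited attribute with an incompatible shape, e.g. turning a plain attribute into a property. A minimal standalone sketch of the pattern (Base/Sub are illustrative stand-ins, not the real Tensor/FakeTensor hierarchy):

    class Base:
        device: str = "cpu"

    class Sub(Base):
        @property
        # pyrefly: ignore # bad-override
        def device(self) -> str:
            # Redeclaring the plain attribute as a read-only property changes
            # its type from the base class's perspective, which pyrefly
            # reports as bad-override; the comment suppresses just that code.
            return "meta"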