mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add pyrefly suppressions (3/n) (#164588)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: uncomment lines in the pyrefly.toml file step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d after: 0 errors (1,970 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588 Approved by: https://github.com/oulgen
This commit is contained in:
committed by
PyTorch MergeBot
parent
e438db2546
commit
f414aa8e0d
@ -404,7 +404,9 @@ class FakeTensorConverter:
|
||||
with no_dispatch():
|
||||
return FakeTensor(
|
||||
fake_mode,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
make_meta_t(),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
device,
|
||||
# TODO: callback might be used in recursive contexts, in
|
||||
# which case using t is wrong! BUG!
|
||||
@ -679,6 +681,7 @@ class FakeTensor(Tensor):
|
||||
_mode_key = torch._C._TorchDispatchModeKey.FAKE
|
||||
|
||||
@property
|
||||
# pyrefly: ignore # bad-override
|
||||
def device(self) -> torch.device:
|
||||
if self.fake_mode.in_kernel_invocation:
|
||||
return torch.device("meta")
|
||||
@ -706,6 +709,7 @@ class FakeTensor(Tensor):
|
||||
|
||||
# We don't support named tensors; graph break
|
||||
@property
|
||||
# pyrefly: ignore # bad-override
|
||||
def names(self) -> list[str]:
|
||||
raise UnsupportedFakeTensorException(
|
||||
"torch.compile doesn't support named tensors"
|
||||
@ -764,6 +768,7 @@ class FakeTensor(Tensor):
|
||||
)
|
||||
else:
|
||||
device = torch.device(f"{device.type}:0")
|
||||
# pyrefly: ignore # read-only
|
||||
self.fake_device = device
|
||||
self.fake_mode = fake_mode
|
||||
self.constant = constant
|
||||
@ -1493,6 +1498,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# Do this dispatch outside the above except handler so if it
|
||||
# generates its own exception there won't be a __context__ caused by
|
||||
# the caching mechanism.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return self._dispatch_impl(func, types, args, kwargs)
|
||||
|
||||
assert state is not None
|
||||
@ -1510,22 +1516,27 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# This represents a negative cache entry - we already saw that the
|
||||
# output is uncachable. Compute it from first principals.
|
||||
FakeTensorMode.cache_bypasses[entry.reason] += 1
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
return self._dispatch_impl(func, types, args, kwargs)
|
||||
|
||||
# We have a cache entry.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
output = self._output_from_cache_entry(state, entry, key, func, args)
|
||||
FakeTensorMode.cache_hits += 1
|
||||
if self.cache_crosscheck_enabled:
|
||||
# For debugging / testing: Validate that the output synthesized
|
||||
# from the cache matches the output created by normal dispatch.
|
||||
with disable_fake_tensor_cache(self):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self._crosscheck_cache_output(output, func, types, args, kwargs)
|
||||
return output
|
||||
|
||||
# We don't have a cache entry.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
output = self._dispatch_impl(func, types, args, kwargs)
|
||||
|
||||
try:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self._validate_cache_key(func, args, kwargs)
|
||||
except _BypassDispatchCache as e:
|
||||
# We ran "extra" checks on the cache key and determined that it's no
|
||||
@ -1545,6 +1556,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
return output
|
||||
|
||||
try:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
|
||||
except _BypassDispatchCache as e:
|
||||
# We had trouble making the cache entry. Record the reason and mark
|
||||
@ -1587,13 +1599,16 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
if state.known_symbols:
|
||||
# If there are symbols then include the epoch - this is really more
|
||||
# of a Shape env var which lives on the FakeTensorMode.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
key_values.append(self.epoch)
|
||||
# Collect the id_hashed objects to attach a weakref finalize later
|
||||
id_hashed_objects: list[object] = []
|
||||
# Translate any FakeTensor args to metadata.
|
||||
if args:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self._prep_args_for_hash(key_values, args, state, id_hashed_objects)
|
||||
if kwargs:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects)
|
||||
key = _DispatchCacheKey(tuple(key_values))
|
||||
|
||||
@ -1909,27 +1924,53 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
if isinstance(output, tuple):
|
||||
for out_element in output:
|
||||
self._validate_output_for_cache_entry(
|
||||
state, key, func, args, kwargs, out_element
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
out_element,
|
||||
)
|
||||
else:
|
||||
self._validate_output_for_cache_entry(
|
||||
state, key, func, args, kwargs, output
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
output,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output_infos = [
|
||||
self._get_output_info_for_cache_entry(
|
||||
state, key, func, args, kwargs, out_elem
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
out_elem,
|
||||
)
|
||||
for out_elem in output
|
||||
]
|
||||
return _DispatchCacheValidEntry(
|
||||
output_infos=tuple(output_infos), is_output_tuple=True
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
output_infos=tuple(output_infos),
|
||||
is_output_tuple=True,
|
||||
)
|
||||
|
||||
else:
|
||||
output_info = self._get_output_info_for_cache_entry(
|
||||
state, key, func, args, kwargs, output
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
output,
|
||||
)
|
||||
return _DispatchCacheValidEntry(
|
||||
output_infos=(output_info,), is_output_tuple=False
|
||||
@ -2472,6 +2513,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
)
|
||||
|
||||
with self, maybe_ignore_fresh_unbacked_symbols():
|
||||
# pyrefly: ignore # index-error
|
||||
return registered_hop_fake_fns[func](*args, **kwargs)
|
||||
|
||||
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
|
||||
@ -2625,6 +2667,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# TODO: Is this really needed?
|
||||
compute_unbacked_bindings(self.shape_env, fake_out, peek=True)
|
||||
|
||||
# pyrefly: ignore # bad-return
|
||||
return fake_out
|
||||
|
||||
# Try for fastpath
|
||||
@ -2906,6 +2949,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
self, e, device or common_device
|
||||
)
|
||||
else:
|
||||
# pyrefly: ignore # bad-return
|
||||
return e
|
||||
|
||||
return tree_map(wrap, r)
|
||||
|
||||
Reference in New Issue
Block a user