Add pyrefly suppressions (3/n) (#164588)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions and clean up unused suppressions
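
For reference, each suppression is a "# pyrefly: ignore # <error-code>" comment, typically placed on its own line directly above the flagged statement (the diffs below also delete a few older trailing-comment suppressions that are no longer needed). A minimal sketch of the placement, using a made-up function and attribute rather than code from this PR:

import torch

def example(t: torch.Tensor) -> int:
    # pyrefly: ignore # missing-attribute
    return t.made_up_attribute  # attribute does not exist; the comment silences only the type checker
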
before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d

after: 0 errors (1,970 ignored)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
Maggie Moss
2025-10-03 22:02:59 +00:00
committed by PyTorch MergeBot
parent e438db2546
commit f414aa8e0d
49 changed files with 244 additions and 29 deletions

View File

@ -34,12 +34,6 @@ project-excludes = [
"torch/jit/**",
"torch/optim/**",
"torch/_higher_order_ops/**",
"torch/_functorch/**",
"torch/masked/**",
"torch/_subclasses/**",
"torch/autograd/**",
"torch/cuda/**",
"torch/export/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",

View File

@ -58,20 +58,20 @@ class TestBundledInputs(TestCase):
# Make sure the model only grew a little bit,
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
for idx, inp in enumerate(inflated):
self.assertIsInstance(inp, tuple) # pyrefly: ignore # missing-attribute
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.
@ -139,7 +139,7 @@ class TestBundledInputs(TestCase):
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")
def test_multiple_methods_with_inputs(self):
@ -186,7 +186,7 @@ class TestBundledInputs(TestCase):
self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())
# Check running and size helpers
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
@ -419,7 +419,7 @@ class TestBundledInputs(TestCase):
)
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))
loaded = save_and_load(sm)

View File

@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
self.assertTrue(torch.all(x)) # pyrefly: ignore # missing-attribute
self.assertTrue(torch.all(x))
@dtypes(*complex_types())
def test_any(self, device, dtype):
@ -56,7 +56,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
self.assertFalse(torch.any(x)) # pyrefly: ignore # missing-attribute
self.assertFalse(torch.any(x))
@onlyCPU
@dtypes(*complex_types())

View File

@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")

View File

@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.float64.to_complex(), torch.complex128)
self.assertEqual(torch.float32.to_complex(), torch.complex64)
@ -135,7 +135,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.complex128.to_real(), torch.double)
self.assertEqual(torch.complex64.to_real(), torch.float32)

View File

@ -2653,6 +2653,7 @@ def compile(
dynamic=dynamic,
disable=disable,
guard_filter_fn=guard_filter_fn,
# pyrefly: ignore # bad-argument-type
)(model)(*args, **kwargs)
return export_wrapped_fn

View File

@ -384,6 +384,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
class AOTAutogradCachePickler(FxGraphCachePickler):
def __init__(self, gm: torch.fx.GraphModule):
super().__init__(gm)
# pyrefly: ignore # bad-override
self.dispatch_table: dict
self.dispatch_table.update(
{

View File

@ -86,8 +86,10 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
memory_format = MemoryFormatMeta.from_tensor(out)
# pyrefly: ignore # missing-attribute
if memory_format.memory_format is not None:
was = out
# pyrefly: ignore # bad-argument-type
out = out.contiguous(memory_format=memory_format.memory_format)
updated = was is not out
@ -117,6 +119,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined]
if is_subclass:
# pyrefly: ignore # missing-attribute
attrs = out.__tensor_flatten__()[0]
for attr in attrs:
@ -126,6 +129,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
new_elem_memory_format,
elem_updated,
) = coerce_tangent_and_suggest_memory_format(elem)
# pyrefly: ignore # missing-attribute
out_memory_format.append(new_elem_memory_format)
if elem_updated:
setattr(out, attr, new_elem)
@ -492,6 +496,7 @@ def run_functionalized_fw_and_collect_metadata(
curr_storage in inp_storage_refs
and not functional_tensor_storage_changed
):
# pyrefly: ignore # index-error
base_idx = inp_storage_refs[curr_storage]
is_input_tensor = id(o) in inp_tensor_ids
num_aliased_outs = out_tensor_alias_counts[curr_storage]
@ -699,6 +704,7 @@ from a multi-output view call"
# Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
# are *regenerated* later, and not used directly in the autograd graph
def _plain_fake_tensor_like_subclass(x):
# pyrefly: ignore # bad-context-manager
with detect_fake_mode():
return torch.empty(
x.shape, dtype=x.dtype, device=x.device, layout=x.layout

View File

@ -78,6 +78,7 @@ def get_all_input_and_grad_nodes(
continue
if isinstance(desc, SubclassGetAttrAOTInput):
_raise_autograd_subclass_not_implemented(n, desc)
# pyrefly: ignore # unsupported-operation
input_index[desc] = (n, None)
elif n.op == "output":
assert "desc" in n.meta, (n, n.meta)
@ -129,6 +130,7 @@ def get_all_output_and_tangent_nodes(
continue
if isinstance(sub_d, SubclassGetAttrAOTOutput):
_raise_autograd_subclass_not_implemented(sub_n, sub_d)
# pyrefly: ignore # unsupported-operation
output_index[sub_d] = (sub_n, None)
for n in g.nodes:
if n.op == "placeholder":

View File

@ -1305,10 +1305,12 @@ def aot_dispatch_subclass(
# See Note: [Partitioner handling for Subclasses, Part 2] for more info.
meta_updated = run_functionalized_fw_and_collect_metadata(
without_output_descs(metadata_fn),
# pyrefly: ignore # bad-argument-type
flat_args_descs=primals_unwrapped_descs,
static_input_indices=remapped_static_indices,
keep_input_mutations=meta.keep_input_mutations,
is_train=meta.is_train,
# pyrefly: ignore # not-iterable
)(*primals_unwrapped)
subclass_meta.fw_metadata = meta_updated

View File

@ -425,6 +425,7 @@ def collect_fw_donated_buffer_idxs(
"""
storage_refs = set()
# pyrefly: ignore # bad-assignment
for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
# Only access storage if a tensor has storage (not sparse)
if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t):
@ -494,6 +495,7 @@ def collect_bw_donated_buffer_idxs(
fw_ins,
user_fw_outs,
bw_outs,
# pyrefly: ignore # bad-argument-type
saved_tensors,
)
@ -1762,6 +1764,7 @@ def aot_stage2_autograd(
# (2408448, 1, 21504, 192). The solution mentioned will
# decide a stride of (802816, 1, 7168, 64) for this
# tensor which is wrong.
# pyrefly: ignore # bad-argument-type
placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)
compiled_bw_func = None

View File

@ -225,6 +225,7 @@ def make_output_handler(info, runtime_metadata, trace_joint):
# not sure why AOTDispatcher needs to manually set this
def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]):
if hasattr(t, "_dynamo_weak_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices |= dims
else:
t._dynamo_weak_dynamic_indices = dims.copy() # type: ignore[attr-defined]
@ -1142,6 +1143,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
f_args_inner = []
# pyrefly: ignore # not-iterable
for inner_idx_or_tuple in synthetic_base_info:
if isinstance(inner_idx_or_tuple, int):
f_args_inner.append(primals[inner_idx_or_tuple])
@ -2112,6 +2114,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
return (ctx._autograd_function_id, *ctx.symints)
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, *deduped_flat_tensor_args):
args = deduped_flat_tensor_args
if backward_state_indices:
@ -2148,6 +2151,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
# in the fw output order.
fw_outs = call_func_at_runtime_with_args(
CompiledFunction.compiled_fw,
# pyrefly: ignore # bad-argument-type
args,
disable_amp=disable_amp,
)
@ -2343,6 +2347,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
_aot_id = aot_config.aot_id
@staticmethod
# pyrefly: ignore # bad-override
def forward(double_ctx, *unused_args):
return impl_fn(double_ctx)

View File

@ -1231,7 +1231,9 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
output_code_ty: type[TOutputCode],
compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
):
# pyrefly: ignore # invalid-type-var
self.output_code_ty = output_code_ty
# pyrefly: ignore # invalid-type-var
self.compiler_fn = compiler_fn
def __call__(

View File

@ -90,6 +90,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
"""
for name, tensor in itertools.chain(
list(module.named_parameters(recurse=False)),
# pyrefly: ignore # no-matching-overload
list(module.named_buffers(recurse=False)),
):
if is_traceable_wrapper_subclass(tensor):

View File

@ -232,11 +232,13 @@ def unwrap_tensor_subclasses(
attrs, _ = t.__tensor_flatten__()
# pyrefly: ignore # bad-assignment
for attr in attrs:
inner_tensor = getattr(t, attr)
n_desc: Any = (
SubclassGetAttrAOTInput(desc, attr)
if isinstance(desc, AOTInput)
# pyrefly: ignore # bad-argument-type
else SubclassGetAttrAOTOutput(desc, attr)
)
flatten_subclass(inner_tensor, n_desc, out=out)
@ -257,6 +259,7 @@ def unwrap_tensor_subclasses(
descs_inner: list[AOTDescriptor] = []
for x, desc in zip(wrapped_args, wrapped_args_descs):
# pyrefly: ignore # bad-argument-type
flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner))
return xs_inner, descs_inner
@ -281,6 +284,7 @@ def runtime_unwrap_tensor_subclasses(
for attr in attrs:
inner_tensor = getattr(x, attr)
# pyrefly: ignore # missing-attribute
inner_meta = meta.attrs.get(attr)
flatten_subclass(inner_tensor, inner_meta, out=out)
@ -310,6 +314,7 @@ def runtime_unwrap_tensor_subclasses(
for idx, x in enumerate(wrapped_args):
if not is_traceable_wrapper_subclass(x):
# pyrefly: ignore # bad-argument-type
xs_inner.append(x)
continue

View File

@ -328,6 +328,7 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
and out.args[1] == 0
and out.args[0] in with_effect_nodes
):
# pyrefly: ignore # missing-attribute
output_token_nodes.append(out)
else:
other_output_nodes.append(out)
@ -529,8 +530,10 @@ def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]:
@wraps(f)
@simple_wraps(f)
def inner(*args, **kwargs):
# pyrefly: ignore # invalid-param-spec
return f(*args, **kwargs)[0]
# pyrefly: ignore # bad-return
return inner

View File

@ -753,6 +753,7 @@ class AutogradFunctionApply(HigherOrderOperator):
class ApplyTemplate(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, *args):
nonlocal saved_values
output, saved_values = fwd(None, *fwd_args)

View File

@ -42,7 +42,9 @@ def create_names_map(
This function creates a mapping from the names in named_params to the
names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
"""
# pyrefly: ignore # no-matching-overload
named_params = dict(named_params)
# pyrefly: ignore # no-matching-overload
tied_named_params = dict(tied_named_params)
tensors_dict_keys = set(named_params.keys())
@ -51,9 +53,11 @@ def create_names_map(
tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {}
for key, tensor in named_params.items():
# pyrefly: ignore # unsupported-operation
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
assert tensor in tensor_to_mapping
# pyrefly: ignore # bad-argument-type
tensor_to_mapping[tensor][1].append(key)
return dict(tensor_to_mapping.values())

View File

@ -1174,6 +1174,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
# critical path first.
cur_nodes += node.all_input_nodes
# pyrefly: ignore # bad-assignment
insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
for node in insertable_nodes:
env[node] = new_graph.node_copy(node, lambda x: env[x])
@ -2849,6 +2850,7 @@ def min_cut_rematerialization_partition(
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
# pyrefly: ignore # bad-argument-type
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,

View File

@ -131,6 +131,7 @@ class VmapInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CVmapInterpreterPtr(self._cdata)
@ -170,6 +171,7 @@ class GradInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CGradInterpreterPtr(self._cdata)
@ -207,6 +209,7 @@ class JvpInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CJvpInterpreterPtr(self._cdata)
@ -243,6 +246,7 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CFunctionalizeInterpreterPtr(self._cdata)

View File

@ -279,7 +279,6 @@ def out_wrapper(
TensorLikeType
if is_tensor
else NamedTuple(
# pyrefly: ignore # bad-argument-count
f"return_types_{fn.__name__}",
# pyrefly: ignore # bad-argument-count
[(o, TensorLikeType) for o in out_names],

View File

@ -33,7 +33,12 @@ class _DeconstructedSymNode:
@staticmethod
def from_node(node: SymNode) -> _DeconstructedSymNode:
return _DeconstructedSymNode(
node._expr, node.pytype, node._hint, node.constant, node.fx_node
node._expr,
node.pytype,
node._hint,
node.constant,
# pyrefly: ignore # bad-argument-type
node.fx_node,
)
def extract(self, shape_env: ShapeEnv) -> SymNode:

View File

@ -404,7 +404,9 @@ class FakeTensorConverter:
with no_dispatch():
return FakeTensor(
fake_mode,
# pyrefly: ignore # bad-argument-type
make_meta_t(),
# pyrefly: ignore # bad-argument-type
device,
# TODO: callback might be used in recursive contexts, in
# which case using t is wrong! BUG!
@ -679,6 +681,7 @@ class FakeTensor(Tensor):
_mode_key = torch._C._TorchDispatchModeKey.FAKE
@property
# pyrefly: ignore # bad-override
def device(self) -> torch.device:
if self.fake_mode.in_kernel_invocation:
return torch.device("meta")
@ -706,6 +709,7 @@ class FakeTensor(Tensor):
# We don't support named tensors; graph break
@property
# pyrefly: ignore # bad-override
def names(self) -> list[str]:
raise UnsupportedFakeTensorException(
"torch.compile doesn't support named tensors"
@ -764,6 +768,7 @@ class FakeTensor(Tensor):
)
else:
device = torch.device(f"{device.type}:0")
# pyrefly: ignore # read-only
self.fake_device = device
self.fake_mode = fake_mode
self.constant = constant
@ -1493,6 +1498,7 @@ class FakeTensorMode(TorchDispatchMode):
# Do this dispatch outside the above except handler so if it
# generates its own exception there won't be a __context__ caused by
# the caching mechanism.
# pyrefly: ignore # bad-argument-type
return self._dispatch_impl(func, types, args, kwargs)
assert state is not None
@ -1510,22 +1516,27 @@ class FakeTensorMode(TorchDispatchMode):
# This represents a negative cache entry - we already saw that the
# output is uncachable. Compute it from first principals.
FakeTensorMode.cache_bypasses[entry.reason] += 1
# pyrefly: ignore # bad-argument-type
return self._dispatch_impl(func, types, args, kwargs)
# We have a cache entry.
# pyrefly: ignore # bad-argument-type
output = self._output_from_cache_entry(state, entry, key, func, args)
FakeTensorMode.cache_hits += 1
if self.cache_crosscheck_enabled:
# For debugging / testing: Validate that the output synthesized
# from the cache matches the output created by normal dispatch.
with disable_fake_tensor_cache(self):
# pyrefly: ignore # bad-argument-type
self._crosscheck_cache_output(output, func, types, args, kwargs)
return output
# We don't have a cache entry.
# pyrefly: ignore # bad-argument-type
output = self._dispatch_impl(func, types, args, kwargs)
try:
# pyrefly: ignore # bad-argument-type
self._validate_cache_key(func, args, kwargs)
except _BypassDispatchCache as e:
# We ran "extra" checks on the cache key and determined that it's no
@ -1545,6 +1556,7 @@ class FakeTensorMode(TorchDispatchMode):
return output
try:
# pyrefly: ignore # bad-argument-type
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
except _BypassDispatchCache as e:
# We had trouble making the cache entry. Record the reason and mark
@ -1587,13 +1599,16 @@ class FakeTensorMode(TorchDispatchMode):
if state.known_symbols:
# If there are symbols then include the epoch - this is really more
# of a Shape env var which lives on the FakeTensorMode.
# pyrefly: ignore # bad-argument-type
key_values.append(self.epoch)
# Collect the id_hashed objects to attach a weakref finalize later
id_hashed_objects: list[object] = []
# Translate any FakeTensor args to metadata.
if args:
# pyrefly: ignore # bad-argument-type
self._prep_args_for_hash(key_values, args, state, id_hashed_objects)
if kwargs:
# pyrefly: ignore # bad-argument-type
self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects)
key = _DispatchCacheKey(tuple(key_values))
@ -1909,27 +1924,53 @@ class FakeTensorMode(TorchDispatchMode):
if isinstance(output, tuple):
for out_element in output:
self._validate_output_for_cache_entry(
state, key, func, args, kwargs, out_element
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
out_element,
)
else:
self._validate_output_for_cache_entry(
state, key, func, args, kwargs, output
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
output,
)
if isinstance(output, tuple):
output_infos = [
self._get_output_info_for_cache_entry(
state, key, func, args, kwargs, out_elem
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
out_elem,
)
for out_elem in output
]
return _DispatchCacheValidEntry(
output_infos=tuple(output_infos), is_output_tuple=True
# pyrefly: ignore # bad-argument-type
output_infos=tuple(output_infos),
is_output_tuple=True,
)
else:
output_info = self._get_output_info_for_cache_entry(
state, key, func, args, kwargs, output
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
output,
)
return _DispatchCacheValidEntry(
output_infos=(output_info,), is_output_tuple=False
@ -2472,6 +2513,7 @@ class FakeTensorMode(TorchDispatchMode):
)
with self, maybe_ignore_fresh_unbacked_symbols():
# pyrefly: ignore # index-error
return registered_hop_fake_fns[func](*args, **kwargs)
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
@ -2625,6 +2667,7 @@ class FakeTensorMode(TorchDispatchMode):
# TODO: Is this really needed?
compute_unbacked_bindings(self.shape_env, fake_out, peek=True)
# pyrefly: ignore # bad-return
return fake_out
# Try for fastpath
@ -2906,6 +2949,7 @@ class FakeTensorMode(TorchDispatchMode):
self, e, device or common_device
)
else:
# pyrefly: ignore # bad-return
return e
return tree_map(wrap, r)

View File

@ -81,6 +81,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
# pyrefly: ignore # bad-return
return t.grad
@ -415,6 +416,7 @@ class MetaTensorDescriber:
device=t.device,
size=t.size(),
stride=stride,
# pyrefly: ignore # bad-argument-type
storage_offset=storage_offset,
dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}),
@ -539,7 +541,11 @@ class _FakeTensorViewFunc(ViewFunc["FakeTensor"]):
tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None,
) -> FakeTensor:
return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe(
t, new_base, symint_visitor_fn, tensor_visitor_fn
# pyrefly: ignore # bad-argument-type
t,
new_base,
symint_visitor_fn,
tensor_visitor_fn,
)
@ -1013,6 +1019,7 @@ class MetaConverter(Generic[_TensorT]):
# Morally, the code here is same as transform_subclass, but we've
# written it from scratch to read EmptyCreateSubclass
outer_size = outer_size if outer_size is not None else t.size
# pyrefly: ignore # bad-assignment
outer_stride = outer_stride if outer_stride is not None else t.stride
assert symbolic_context is None or isinstance(
@ -1269,6 +1276,7 @@ class MetaConverter(Generic[_TensorT]):
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
# pyrefly: ignore # bad-return
return None
# NB: visited_t being a Tensor here is very naughty! Should
@ -1399,6 +1407,7 @@ class MetaConverter(Generic[_TensorT]):
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
# pyrefly: ignore # bad-argument-type
r = self._backward_error(r)
elif t.is_nested and not t.is_traceable_wrapper_subclass:
# TODO: Handle this better in Dynamo?
@ -1437,6 +1446,7 @@ class MetaConverter(Generic[_TensorT]):
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
# pyrefly: ignore # bad-argument-type
r = self._backward_error(r)
elif t.is_functorch_wrapped:
if t.is_view:
@ -1533,6 +1543,7 @@ class MetaConverter(Generic[_TensorT]):
)
assert t.data is not None
_safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return r
r = _to_fake_tensor(t)
@ -1682,6 +1693,7 @@ class MetaConverter(Generic[_TensorT]):
not (t.is_batchedtensor or t.is_gradtrackingtensor)
and t.is_functorch_wrapped
) or t.is_legacy_batchedtensor:
# pyrefly: ignore # bad-return
return NotImplemented
(
@ -1728,6 +1740,7 @@ class MetaConverter(Generic[_TensorT]):
# the metadata of the inner tensor.
# So instead, we now have a dedicated fn to set autograd history,
# without inadvertently changing other metadata.
# pyrefly: ignore # bad-argument-type
r = self._backward_error(r)
s = t.storage
@ -1839,6 +1852,7 @@ class MetaConverter(Generic[_TensorT]):
nt_tensor_id=t.nested_int
)
# pyrefly: ignore # bad-argument-type
self.set_tensor_memo(t, r)
return self._checked_get_tensor_memo(t)
@ -1882,11 +1896,13 @@ class MetaConverter(Generic[_TensorT]):
(t._is_view() and t._base is not None and t._base.is_sparse)
):
self.miss += 1
# pyrefly: ignore # bad-return
return NotImplemented
else:
self.hit += 1
elif torch.overrides.is_tensor_like(t):
self.miss += 1
# pyrefly: ignore # bad-return
return NotImplemented
else:
# non-Tensor types don't count as hit or miss

View File

@ -92,6 +92,7 @@ def _make_grads(
is_grads_batched: bool,
) -> tuple[_OptionalTensor, ...]:
new_grads: list[_OptionalTensor] = []
# pyrefly: ignore # no-matching-overload
for out, grad in zip(outputs, grads):
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_size = None
@ -341,6 +342,7 @@ def backward(
Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
)
else:
# pyrefly: ignore # bad-argument-type
tensors = tuple(tensors)
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
@ -440,10 +442,12 @@ def grad(
Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
)
else:
# pyrefly: ignore # bad-argument-type
outputs = tuple(outputs)
if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
else:
# pyrefly: ignore # bad-argument-type
inputs = tuple(inputs)
t_outputs = tuple(i for i in outputs if is_tensor_like(i))
t_inputs = tuple(i for i in inputs if is_tensor_like(i))

View File

@ -15,12 +15,14 @@ class Type(Function):
"please use `torch.tensor.to(dtype=dtype)` instead.",
category=FutureWarning,
)
# pyrefly: ignore # bad-override
def forward(ctx, i, dest_type):
ctx.input_type = type(i)
ctx.input_device = -1 if not i.is_cuda else i.get_device()
return i.type(dest_type)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
if ctx.input_device == -1:
return grad_output.type(ctx.input_type), None
@ -32,6 +34,7 @@ class Type(Function):
# TODO: deprecate this
class Resize(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, tensor, sizes):
ctx.sizes = sizes
ctx.numel = reduce(operator.mul, sizes, 1)
@ -60,6 +63,7 @@ class Resize(Function):
return tensor.contiguous().view(*sizes)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
assert grad_output.numel() == ctx.numel
return grad_output.contiguous().view(ctx.input_sizes), None

View File

@ -9,6 +9,8 @@ from typing_extensions import deprecated
import torch
import torch.testing
# pyrefly: ignore # deprecated
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors

View File

@ -229,6 +229,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
# Note that output_nr default to 0 which is the right value
# for the AccumulateGrad node.
# pyrefly: ignore # bad-argument-type
return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token)
@ -531,6 +532,7 @@ def register_multi_grad_hook(
"expected this hook to be called inside a backward call"
)
count[id] = count.get(id, 0)
# pyrefly: ignore # unsupported-operation
buffer[id] = buffer.get(id, [None] * len_tensors)
with lock:

View File

@ -731,6 +731,7 @@ class profile:
return all_function_events
# pyrefly: ignore # invalid-inheritance
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.
@ -778,7 +779,9 @@ class record_function(_ContextDecorator):
# TODO: TorchScript ignores standard type annotation here
# self.record: Optional["torch.classes.profiler._RecordFunction"] = None
self.record = torch.jit.annotate(
Optional["torch.classes.profiler._RecordFunction"], None
# pyrefly: ignore # not-a-type
Optional["torch.classes.profiler._RecordFunction"],
None,
)
def __enter__(self):

View File

@ -101,12 +101,14 @@ class profile:
records = _disable_profiler_legacy()
parsed_results = _parse_legacy_records(records)
# pyrefly: ignore # bad-assignment
self.function_events = EventList(
parsed_results,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
with_flops=self.with_flops,
)
# pyrefly: ignore # missing-attribute
self.function_events._build_tree()
return False

View File

@ -48,10 +48,14 @@ class EventList(list):
def _remove_dup_nodes(self):
while True:
to_delete = set()
# pyrefly: ignore # bad-assignment
for idx in range(len(self)):
if (
# pyrefly: ignore # index-error
self[idx].cpu_parent is not None
# pyrefly: ignore # index-error
and self[idx].cpu_parent.name == self[idx].name
# pyrefly: ignore # index-error
and len(self[idx].cpu_parent.cpu_children) == 1
):
self[idx].cpu_parent.cpu_children = self[idx].cpu_children
@ -61,8 +65,11 @@ class EventList(list):
to_delete.add(idx)
if len(to_delete) == 0:
break
# pyrefly: ignore # bad-argument-type
new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
# pyrefly: ignore # missing-attribute
self.clear()
# pyrefly: ignore # missing-attribute
self.extend(new_evts)
def _populate_cpu_children(self):
@ -496,7 +503,9 @@ class FunctionEvent(FormattedTimesMixin):
self.id: int = id
self.node_id: int = node_id
self.name: str = name
# pyrefly: ignore # bad-assignment
self.overload_name: str = overload_name
# pyrefly: ignore # bad-assignment
self.trace_name: str = trace_name
self.time_range: Interval = Interval(start_us, end_us)
self.thread: int = thread
@ -505,9 +514,13 @@ class FunctionEvent(FormattedTimesMixin):
self.count: int = 1
self.cpu_children: list[FunctionEvent] = []
self.cpu_parent: Optional[FunctionEvent] = None
# pyrefly: ignore # bad-assignment
self.input_shapes: tuple[int, ...] = input_shapes
# pyrefly: ignore # bad-assignment
self.concrete_inputs: list[Any] = concrete_inputs
# pyrefly: ignore # bad-assignment
self.kwinputs: dict[str, Any] = kwinputs
# pyrefly: ignore # bad-assignment
self.stack: list = stack
self.scope: int = scope
self.use_device: Optional[str] = use_device
@ -732,6 +745,7 @@ class FunctionEventAvg(FormattedTimesMixin):
self.self_device_memory_usage += other.self_device_memory_usage
self.count += other.count
if self.flops is None:
# pyrefly: ignore # bad-assignment
self.flops = other.flops
elif other.flops is not None:
self.flops += other.flops
@ -967,6 +981,7 @@ def _build_table(
"PFLOPs",
]
assert flops > 0
# pyrefly: ignore # no-matching-overload
log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
assert log_flops >= 0 and log_flops < len(flop_headers)
return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])

View File

@ -496,12 +496,14 @@ class cudaStatus:
class CudaError(RuntimeError):
def __init__(self, code: int) -> None:
# pyrefly: ignore # missing-attribute
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
super().__init__(f"{msg} ({code})")
def check_error(res: int) -> None:
r"""Raise an error if the result of a CUDA runtime API call is not success."""
# pyrefly: ignore # missing-attribute
if res != _cudart.cudaError.success:
raise CudaError(res)
@ -601,6 +603,7 @@ def get_device_capability(device: "Device" = None) -> tuple[int, int]:
return prop.major, prop.minor
# pyrefly: ignore # not-a-type
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
r"""Get the properties of a device.
@ -651,6 +654,7 @@ class StreamContext:
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
# pyrefly: ignore # bad-assignment
self.idx = -1
self.src_prev_stream = (
@ -953,7 +957,9 @@ def _device_count_amdsmi() -> int:
if raw_cnt <= 0:
return raw_cnt
# Trim the list up to a maximum available device
# pyrefly: ignore # bad-argument-type
for idx, val in enumerate(visible_devices):
# pyrefly: ignore # redundant-cast
if cast(int, val) >= raw_cnt:
return idx
except OSError:
@ -987,7 +993,9 @@ def _device_count_nvml() -> int:
if raw_cnt <= 0:
return raw_cnt
# Trim the list up to a maximum available device
# pyrefly: ignore # bad-argument-type
for idx, val in enumerate(visible_devices):
# pyrefly: ignore # redundant-cast
if cast(int, val) >= raw_cnt:
return idx
except OSError:
@ -1203,7 +1211,9 @@ def _get_pynvml_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"pynvml does not seem to be installed or it can't be imported."
# pyrefly: ignore # invalid-inheritance
) from _PYNVML_ERR
# pyrefly: ignore # import-error
from pynvml import NVMLError_DriverNotLoaded
try:
@ -1220,6 +1230,7 @@ def _get_amdsmi_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"amdsmi does not seem to be installed or it can't be imported."
# pyrefly: ignore # invalid-inheritance
) from _PYNVML_ERR
try:
amdsmi.amdsmi_init()
@ -1483,6 +1494,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int
return default_generator.get_offset()
# pyrefly: ignore # deprecated
from .memory import * # noqa: F403
from .random import * # noqa: F403
@ -1699,6 +1711,7 @@ def _register_triton_kernels():
def kernel_impl(*args, **kwargs):
from torch.sparse._triton_ops import bsr_dense_mm
# pyrefly: ignore # not-callable
return bsr_dense_mm(*args, skip_checks=True, **kwargs)
@_WrappedTritonKernel

View File

@ -279,6 +279,7 @@ class _CudaModule:
return self._kernels[name]
# Import the CUDA library inside the method
# pyrefly: ignore # missing-module-attribute
from torch.cuda._utils import _get_gpu_runtime_library
libcuda = _get_gpu_runtime_library()

View File

@ -1,3 +1,4 @@
# pyrefly: ignore # deprecated
from .autocast_mode import autocast, custom_bwd, custom_fwd
from .common import amp_definitely_not_available
from .grad_scaler import GradScaler

View File

@ -259,6 +259,7 @@ class graph:
self.cuda_graph.capture_begin(
# type: ignore[misc]
*self.pool,
# pyrefly: ignore # bad-keyword-argument
capture_error_mode=self.capture_error_mode,
)
@ -524,6 +525,7 @@ def make_graphed_callables(
) -> Callable[..., object]:
class Graphed(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
# At this stage, only the user args may (potentially) be new tensors.
for i in range(len_user_args):
@ -535,6 +537,7 @@ def make_graphed_callables(
@staticmethod
@torch.autograd.function.once_differentiable
# pyrefly: ignore # bad-override
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):
@ -548,7 +551,9 @@ def make_graphed_callables(
# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return tuple(
b.detach() if b is not None else b for b in static_grad_inputs
# pyrefly: ignore # bad-argument-type
b.detach() if b is not None else b
for b in static_grad_inputs
)
def functionalized(*user_args: object) -> object:

View File

@ -770,6 +770,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
return "pynvml module not found, please install pynvml"
# pyrefly: ignore # import-error
from pynvml import NVMLError_DriverNotLoaded
try:
@ -852,6 +853,7 @@ def _record_memory_history_legacy(
_C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
enabled,
record_context,
# pyrefly: ignore # bad-argument-type
trace_alloc_max_entries,
trace_alloc_record_context,
record_context_cpp,

View File

@ -53,6 +53,7 @@ def range_start(msg) -> int:
Args:
msg (str): ASCII message to associate with the range.
"""
# pyrefly: ignore # missing-attribute
return _nvtx.rangeStartA(msg)
@ -63,6 +64,7 @@ def range_end(range_id) -> None:
Args:
range_id (int): an unique handle for the start range.
"""
# pyrefly: ignore # missing-attribute
_nvtx.rangeEnd(range_id)
@ -83,6 +85,7 @@ def _device_range_start(msg: str, stream: int = 0) -> object:
msg (str): ASCII message to associate with the range.
stream (int): CUDA stream id.
"""
# pyrefly: ignore # missing-attribute
return _nvtx.deviceRangeStart(msg, stream)
@ -95,6 +98,7 @@ def _device_range_end(range_handle: object, stream: int = 0) -> None:
range_handle: an unique handle for the start range.
stream (int): CUDA stream id.
"""
# pyrefly: ignore # missing-attribute
_nvtx.deviceRangeEnd(range_handle, stream)

View File

@ -436,6 +436,7 @@ def load(
print(ep(torch.randn(5)))
"""
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
extra_files = extra_files or {}

View File

@ -295,6 +295,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.logger.addHandler(self)
self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
# pyrefly: ignore # bad-assignment
torch._logging._internal.GET_DTRACE_STRUCTURED = True
return self
@ -302,6 +303,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.log_record = LogRecord()
self.expression_created_logs = {}
self.logger.removeHandler(self)
# pyrefly: ignore # bad-assignment
torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
self.prev_get_dtrace = False

View File

@ -107,8 +107,11 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
return
if not (
# pyrefly: ignore # missing-attribute
arg.op == "call_function"
# pyrefly: ignore # missing-attribute
and arg.target == operator.getitem
# pyrefly: ignore # missing-attribute
and arg.args[1] == i
):
log.debug(

View File

@ -185,6 +185,7 @@ def _ignore_backend_decomps():
def _disable_custom_triton_op_functional_decomposition():
old = torch._functorch.config.decompose_custom_triton_ops
try:
# pyrefly: ignore # bad-assignment
torch._functorch.config.decompose_custom_triton_ops = False
yield torch._functorch.config.decompose_custom_triton_ops
finally:
@ -365,6 +366,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
return self
def __getitem__(self, idx):
# pyrefly: ignore # bad-argument-type
parts.append(str(idx))
return self
@ -660,6 +662,7 @@ def _rename_constants_nodes(
if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
const_prefix
):
# pyrefly: ignore # bad-argument-type
if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants
c_name = rename_constant(
const_prefix + spec.arg.name[len(buffer_prefix) :]

View File

@ -293,6 +293,7 @@ class _ExportPackage:
if isinstance(fn, torch.nn.Module):
dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type]
else:
# pyrefly: ignore # invalid-param-spec
dynamic_shapes = v(*args, **kwargs)
except AssertionError:
continue
@ -340,6 +341,7 @@ class _ExportPackage:
assert not hasattr(fn, "_define_overload")
_exporter_context._define_overload = _define_overload # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return _exporter_context
@property
@ -376,6 +378,7 @@ class _ExportPackage:
kwargs=ep.example_inputs[1],
options=options,
)
# pyrefly: ignore # unsupported-operation
aoti_files_map[name] = aoti_files
from torch._inductor.package import package

View File

@ -1500,6 +1500,7 @@ class ExportedProgram:
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None
# pyrefly: ignore # missing-attribute
if transformed_gm is self.graph_module and not res.modified:
return self
@ -1578,6 +1579,7 @@ class ExportedProgram:
verifiers=self.verifiers,
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
# pyrefly: ignore # missing-attribute
transformed_ep.graph_module.meta.update(res.graph_module.meta)
return transformed_ep

View File

@ -81,6 +81,7 @@ def move_to_device_pass(
and node.target == torch.ops.aten.to.device
):
args = list(node.args)
# pyrefly: ignore # unsupported-operation
args[1] = _get_new_device(args[1], location)
node.args = tuple(args)

View File

@ -172,8 +172,10 @@ class PT2ArchiveWriter:
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
)
for file_path in file_paths:
# pyrefly: ignore # no-matching-overload
filename = os.path.relpath(file_path, folder_dir)
archive_path = os.path.join(archive_dir, filename)
# pyrefly: ignore # bad-argument-type
self.write_file(archive_path, file_path)
def close(self) -> None:
@ -593,6 +595,7 @@ def package_pt2(
if not (
(isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
# pyrefly: ignore # no-matching-overload
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2"))
):
@ -604,8 +607,10 @@ def package_pt2(
)
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
# pyrefly: ignore # bad-argument-type
with PT2ArchiveWriter(f) as archive_writer:
_package_exported_programs(
archive_writer, exported_programs, pickle_protocol=pickle_protocol
@ -620,6 +625,7 @@ def package_pt2(
if isinstance(f, (io.IOBase, IO)):
f.seek(0)
# pyrefly: ignore # bad-return
return f
@ -992,6 +998,7 @@ def load_pt2(
if not (
(isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
# pyrefly: ignore # no-matching-overload
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
):
# TODO: turn this into an error in 2.9
@ -1002,10 +1009,12 @@ def load_pt2(
)
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
weights = {}
weight_maps = {}
# pyrefly: ignore # bad-argument-type
with PT2ArchiveReader(f) as archive_reader:
version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
if version != ARCHIVE_VERSION_VALUE:
@ -1070,7 +1079,12 @@ def load_pt2(
else:
aoti_runners = {
model_name: _load_aoti(
f, model_name, run_single_threaded, num_runners, device_index
# pyrefly: ignore # bad-argument-type
f,
model_name,
run_single_threaded,
num_runners,
device_index,
)
for model_name in aoti_model_names
}

View File

@ -937,6 +937,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
]
target = node.target if node.op in ("call_function", "get_attr") else ""
# pyrefly: ignore # bad-argument-type
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
nodes_idx[id(node)] = i
return "\n".join(ret)
@ -1473,6 +1474,7 @@ class _ModuleFrame:
self.seen_attrs[self.child_fqn].add(node.target)
self.copy_node(node)
# pyrefly: ignore # unsupported-operation
node_idx += 1

View File

@ -483,6 +483,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
raise IndexError(
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
)
# pyrefly: ignore # bad-argument-type
dims.append(d % ndim)
return tuple(sorted(dims))
@ -641,6 +642,7 @@ def _sparse_coo_scatter_reduction_helper(
# promote dtype if specified
if values.dtype != output_dtype:
# pyrefly: ignore # no-matching-overload
values = values.to(output_dtype)
if keepdim:
@ -765,6 +767,7 @@ def _sparse_csr_segment_reduction_helper(
# promote dtype if specified
if values.dtype != output_dtype:
# pyrefly: ignore # no-matching-overload
values = values.to(output_dtype)
if len(dims) == 0:
@ -1015,6 +1018,7 @@ def _combine_input_and_mask(
class Combine(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input, mask):
"""Return input with masked-out elements eliminated for the given operations."""
ctx.save_for_backward(mask)
@ -1025,6 +1029,7 @@ def _combine_input_and_mask(
return helper(input, mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
(mask,) = ctx.saved_tensors
grad_data = (
@ -1399,15 +1404,18 @@ elements, have ``nan`` values.
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
# pyrefly: ignore # no-matching-overload
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=keepdim,
)
# pyrefly: ignore # no-matching-overload
total = sum(input, dim, keepdim=keepdim, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=bool(keepdim))
# pyrefly: ignore # no-matching-overload
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
return total / count
elif input.layout == torch.sparse_csr:
@ -1618,15 +1626,18 @@ def _std_var(
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
# pyrefly: ignore # no-matching-overload
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=True,
)
# pyrefly: ignore # no-matching-overload
sample_total = sum(input, dim, keepdim=True, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=True)
# pyrefly: ignore # no-matching-overload
sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
# TODO: replace torch.subtract/divide/square/maximum with
# masked subtract/divide/square/maximum when these will be
@ -1634,6 +1645,7 @@ def _std_var(
sample_mean = torch.divide(sample_total, count)
x = torch.subtract(input, sample_mean)
if mask is None:
# pyrefly: ignore # no-matching-overload
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(

View File

@ -47,6 +47,7 @@ def _check_args_kwargs_length(
class _MaskedContiguous(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
@ -60,12 +61,14 @@ class _MaskedContiguous(torch.autograd.Function):
return MaskedTensor(data.contiguous(), mask.contiguous())
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output
class _MaskedToDense(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
@ -80,6 +83,7 @@ class _MaskedToDense(torch.autograd.Function):
return MaskedTensor(data.to_dense(), mask.to_dense())
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
layout = ctx.layout
@ -94,6 +98,7 @@ class _MaskedToDense(torch.autograd.Function):
class _MaskedToSparse(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
@ -110,12 +115,14 @@ class _MaskedToSparse(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedToSparseCsr(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
@ -136,18 +143,21 @@ class _MaskedToSparseCsr(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedWhere(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, cond, self, other):
ctx.mark_non_differentiable(cond)
ctx.save_for_backward(cond)
return torch.ops.aten.where(cond, self, other)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
(cond,) = ctx.saved_tensors

View File

@ -174,6 +174,7 @@ class MaskedTensor(torch.Tensor):
UserWarning,
stacklevel=2,
)
# pyrefly: ignore # bad-argument-type
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
def _preprocess_data(self, data, mask):
@ -243,10 +244,12 @@ class MaskedTensor(torch.Tensor):
class Constructor(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, data, mask):
return MaskedTensor(data, mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output, None
@ -333,10 +336,12 @@ class MaskedTensor(torch.Tensor):
def get_data(self):
class GetData(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, self):
return self._masked_data.detach()
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
if is_masked_tensor(grad_output):
return grad_output