Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add pyrefly suppressions (3/n) (#164588)
Adds suppressions so pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
  dmypy restart && python3 scripts/lintrunner.py -a
  pyrefly check

Steps:
  1. uncomment lines in the pyrefly.toml file
  2. run pyrefly check
  3. add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d
After: 0 errors (1,970 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
Committed by: PyTorch MergeBot
Parent: e438db2546
Commit: f414aa8e0d
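All of the changes below follow the same pattern: a `# pyrefly: ignore # <error-code>` comment is added either on its own line directly above the statement that pyrefly flags or as a trailing comment on that statement, and the corresponding directories are removed from `project-excludes` in pyrefly.toml so the checker actually visits them. The sketch below is a minimal illustration of that comment placement only; the function itself is hypothetical and not part of the commit, while the error codes shown are ones that appear in this diff:

```python
import torch


def scale(t: torch.Tensor, factor: float) -> torch.Tensor:
    # Standalone form: the suppression sits on the line above the flagged
    # statement. "bad-argument-type" is one of the pyrefly error codes used
    # in this commit; whether it would actually fire here is illustrative.
    # pyrefly: ignore # bad-argument-type
    scaled = t * factor

    # Trailing form, which also appears in a few hunks of this diff:
    return scaled.contiguous()  # pyrefly: ignore # missing-attribute


print(scale(torch.ones(2, 2), 3.0))
```

Per the test plan above, the suppressions are then verified by re-running `pyrefly check` (via lintrunner) until it reports 0 errors.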
@@ -34,12 +34,6 @@ project-excludes = [
"torch/jit/**",
"torch/optim/**",
"torch/_higher_order_ops/**",
"torch/_functorch/**",
"torch/masked/**",
"torch/_subclasses/**",
"torch/autograd/**",
"torch/cuda/**",
"torch/export/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
@@ -58,20 +58,20 @@ class TestBundledInputs(TestCase):
# Make sure the model only grew a little bit,
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
for idx, inp in enumerate(inflated):
self.assertIsInstance(inp, tuple) # pyrefly: ignore # missing-attribute
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.

@@ -139,7 +139,7 @@ class TestBundledInputs(TestCase):
loaded = save_and_load(sm)
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")
def test_multiple_methods_with_inputs(self):

@@ -186,7 +186,7 @@ class TestBundledInputs(TestCase):
self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())
# Check running and size helpers
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

@@ -419,7 +419,7 @@ class TestBundledInputs(TestCase):
)
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))
loaded = save_and_load(sm)

@@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
self.assertTrue(torch.all(x)) # pyrefly: ignore # missing-attribute
self.assertTrue(torch.all(x))
@dtypes(*complex_types())
def test_any(self, device, dtype):

@@ -56,7 +56,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
self.assertFalse(torch.any(x)) # pyrefly: ignore # missing-attribute
self.assertFalse(torch.any(x))
@onlyCPU
@dtypes(*complex_types())

@@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")

@@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.float64.to_complex(), torch.complex128)
self.assertEqual(torch.float32.to_complex(), torch.complex64)

@@ -135,7 +135,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.complex128.to_real(), torch.double)
self.assertEqual(torch.complex64.to_real(), torch.float32)
@@ -2653,6 +2653,7 @@ def compile(
dynamic=dynamic,
disable=disable,
guard_filter_fn=guard_filter_fn,
# pyrefly: ignore # bad-argument-type
)(model)(*args, **kwargs)
return export_wrapped_fn

@@ -384,6 +384,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
class AOTAutogradCachePickler(FxGraphCachePickler):
def __init__(self, gm: torch.fx.GraphModule):
super().__init__(gm)
# pyrefly: ignore # bad-override
self.dispatch_table: dict
self.dispatch_table.update(
{

@@ -86,8 +86,10 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
memory_format = MemoryFormatMeta.from_tensor(out)
# pyrefly: ignore # missing-attribute
if memory_format.memory_format is not None:
was = out
# pyrefly: ignore # bad-argument-type
out = out.contiguous(memory_format=memory_format.memory_format)
updated = was is not out

@@ -117,6 +119,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
out = out.__coerce_tangent_metadata__()  # type: ignore[attr-defined]
if is_subclass:
# pyrefly: ignore # missing-attribute
attrs = out.__tensor_flatten__()[0]
for attr in attrs:

@@ -126,6 +129,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
new_elem_memory_format,
elem_updated,
) = coerce_tangent_and_suggest_memory_format(elem)
# pyrefly: ignore # missing-attribute
out_memory_format.append(new_elem_memory_format)
if elem_updated:
setattr(out, attr, new_elem)

@@ -492,6 +496,7 @@ def run_functionalized_fw_and_collect_metadata(
curr_storage in inp_storage_refs
and not functional_tensor_storage_changed
):
# pyrefly: ignore # index-error
base_idx = inp_storage_refs[curr_storage]
is_input_tensor = id(o) in inp_tensor_ids
num_aliased_outs = out_tensor_alias_counts[curr_storage]

@@ -699,6 +704,7 @@ from a multi-output view call"
# Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
# are *regenerated* later, and not used directly in the autograd graph
def _plain_fake_tensor_like_subclass(x):
# pyrefly: ignore # bad-context-manager
with detect_fake_mode():
return torch.empty(
x.shape, dtype=x.dtype, device=x.device, layout=x.layout

@@ -78,6 +78,7 @@ def get_all_input_and_grad_nodes(
continue
if isinstance(desc, SubclassGetAttrAOTInput):
_raise_autograd_subclass_not_implemented(n, desc)
# pyrefly: ignore # unsupported-operation
input_index[desc] = (n, None)
elif n.op == "output":
assert "desc" in n.meta, (n, n.meta)

@@ -129,6 +130,7 @@ def get_all_output_and_tangent_nodes(
continue
if isinstance(sub_d, SubclassGetAttrAOTOutput):
_raise_autograd_subclass_not_implemented(sub_n, sub_d)
# pyrefly: ignore # unsupported-operation
output_index[sub_d] = (sub_n, None)
for n in g.nodes:
if n.op == "placeholder":

@@ -1305,10 +1305,12 @@ def aot_dispatch_subclass(
# See Note: [Partitioner handling for Subclasses, Part 2] for more info.
meta_updated = run_functionalized_fw_and_collect_metadata(
without_output_descs(metadata_fn),
# pyrefly: ignore # bad-argument-type
flat_args_descs=primals_unwrapped_descs,
static_input_indices=remapped_static_indices,
keep_input_mutations=meta.keep_input_mutations,
is_train=meta.is_train,
# pyrefly: ignore # not-iterable
)(*primals_unwrapped)
subclass_meta.fw_metadata = meta_updated
@@ -425,6 +425,7 @@ def collect_fw_donated_buffer_idxs(
"""
storage_refs = set()
# pyrefly: ignore # bad-assignment
for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
# Only access storage if a tensor has storage (not sparse)
if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t):

@@ -494,6 +495,7 @@ def collect_bw_donated_buffer_idxs(
fw_ins,
user_fw_outs,
bw_outs,
# pyrefly: ignore # bad-argument-type
saved_tensors,
)

@@ -1762,6 +1764,7 @@ def aot_stage2_autograd(
# (2408448, 1, 21504, 192). The solution mentioned will
# decide a stride of (802816, 1, 7168, 64) for this
# tensor which is wrong.
# pyrefly: ignore # bad-argument-type
placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)
compiled_bw_func = None

@@ -225,6 +225,7 @@ def make_output_handler(info, runtime_metadata, trace_joint):
# not sure why AOTDispatcher needs to manually set this
def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]):
if hasattr(t, "_dynamo_weak_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices |= dims
else:
t._dynamo_weak_dynamic_indices = dims.copy()  # type: ignore[attr-defined]

@@ -1142,6 +1143,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
f_args_inner = []
# pyrefly: ignore # not-iterable
for inner_idx_or_tuple in synthetic_base_info:
if isinstance(inner_idx_or_tuple, int):
f_args_inner.append(primals[inner_idx_or_tuple])

@@ -2112,6 +2114,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
return (ctx._autograd_function_id, *ctx.symints)
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, *deduped_flat_tensor_args):
args = deduped_flat_tensor_args
if backward_state_indices:

@@ -2148,6 +2151,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
# in the fw output order.
fw_outs = call_func_at_runtime_with_args(
CompiledFunction.compiled_fw,
# pyrefly: ignore # bad-argument-type
args,
disable_amp=disable_amp,
)

@@ -2343,6 +2347,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
_aot_id = aot_config.aot_id
@staticmethod
# pyrefly: ignore # bad-override
def forward(double_ctx, *unused_args):
return impl_fn(double_ctx)

@@ -1231,7 +1231,9 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
output_code_ty: type[TOutputCode],
compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
):
# pyrefly: ignore # invalid-type-var
self.output_code_ty = output_code_ty
# pyrefly: ignore # invalid-type-var
self.compiler_fn = compiler_fn
def __call__(

@@ -90,6 +90,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
"""
for name, tensor in itertools.chain(
list(module.named_parameters(recurse=False)),
# pyrefly: ignore # no-matching-overload
list(module.named_buffers(recurse=False)),
):
if is_traceable_wrapper_subclass(tensor):
@@ -232,11 +232,13 @@ def unwrap_tensor_subclasses(
attrs, _ = t.__tensor_flatten__()
# pyrefly: ignore # bad-assignment
for attr in attrs:
inner_tensor = getattr(t, attr)
n_desc: Any = (
SubclassGetAttrAOTInput(desc, attr)
if isinstance(desc, AOTInput)
# pyrefly: ignore # bad-argument-type
else SubclassGetAttrAOTOutput(desc, attr)
)
flatten_subclass(inner_tensor, n_desc, out=out)

@@ -257,6 +259,7 @@ def unwrap_tensor_subclasses(
descs_inner: list[AOTDescriptor] = []
for x, desc in zip(wrapped_args, wrapped_args_descs):
# pyrefly: ignore # bad-argument-type
flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner))
return xs_inner, descs_inner

@@ -281,6 +284,7 @@ def runtime_unwrap_tensor_subclasses(
for attr in attrs:
inner_tensor = getattr(x, attr)
# pyrefly: ignore # missing-attribute
inner_meta = meta.attrs.get(attr)
flatten_subclass(inner_tensor, inner_meta, out=out)

@@ -310,6 +314,7 @@ def runtime_unwrap_tensor_subclasses(
for idx, x in enumerate(wrapped_args):
if not is_traceable_wrapper_subclass(x):
# pyrefly: ignore # bad-argument-type
xs_inner.append(x)
continue

@@ -328,6 +328,7 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
and out.args[1] == 0
and out.args[0] in with_effect_nodes
):
# pyrefly: ignore # missing-attribute
output_token_nodes.append(out)
else:
other_output_nodes.append(out)

@@ -529,8 +530,10 @@ def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]:
@wraps(f)
@simple_wraps(f)
def inner(*args, **kwargs):
# pyrefly: ignore # invalid-param-spec
return f(*args, **kwargs)[0]
# pyrefly: ignore # bad-return
return inner

@@ -753,6 +753,7 @@ class AutogradFunctionApply(HigherOrderOperator):
class ApplyTemplate(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, *args):
nonlocal saved_values
output, saved_values = fwd(None, *fwd_args)

@@ -42,7 +42,9 @@ def create_names_map(
This function creates a mapping from the names in named_params to the
names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
"""
# pyrefly: ignore # no-matching-overload
named_params = dict(named_params)
# pyrefly: ignore # no-matching-overload
tied_named_params = dict(tied_named_params)
tensors_dict_keys = set(named_params.keys())

@@ -51,9 +53,11 @@ def create_names_map(
tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {}
for key, tensor in named_params.items():
# pyrefly: ignore # unsupported-operation
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
assert tensor in tensor_to_mapping
# pyrefly: ignore # bad-argument-type
tensor_to_mapping[tensor][1].append(key)
return dict(tensor_to_mapping.values())
@@ -1174,6 +1174,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
# critical path first.
cur_nodes += node.all_input_nodes
# pyrefly: ignore # bad-assignment
insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
for node in insertable_nodes:
env[node] = new_graph.node_copy(node, lambda x: env[x])

@@ -2849,6 +2850,7 @@ def min_cut_rematerialization_partition(
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
# pyrefly: ignore # bad-argument-type
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,

@@ -131,6 +131,7 @@ class VmapInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CVmapInterpreterPtr(self._cdata)

@@ -170,6 +171,7 @@ class GradInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CGradInterpreterPtr(self._cdata)

@@ -207,6 +209,7 @@ class JvpInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CJvpInterpreterPtr(self._cdata)

@@ -243,6 +246,7 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
def _cptr(self):
return CFunctionalizeInterpreterPtr(self._cdata)

@@ -279,7 +279,6 @@ def out_wrapper(
TensorLikeType
if is_tensor
else NamedTuple(
# pyrefly: ignore # bad-argument-count
f"return_types_{fn.__name__}",
# pyrefly: ignore # bad-argument-count
[(o, TensorLikeType) for o in out_names],

@@ -33,7 +33,12 @@ class _DeconstructedSymNode:
@staticmethod
def from_node(node: SymNode) -> _DeconstructedSymNode:
return _DeconstructedSymNode(
node._expr, node.pytype, node._hint, node.constant, node.fx_node
node._expr,
node.pytype,
node._hint,
node.constant,
# pyrefly: ignore # bad-argument-type
node.fx_node,
)
def extract(self, shape_env: ShapeEnv) -> SymNode:
@@ -404,7 +404,9 @@ class FakeTensorConverter:
with no_dispatch():
return FakeTensor(
fake_mode,
# pyrefly: ignore # bad-argument-type
make_meta_t(),
# pyrefly: ignore # bad-argument-type
device,
# TODO: callback might be used in recursive contexts, in
# which case using t is wrong! BUG!

@@ -679,6 +681,7 @@ class FakeTensor(Tensor):
_mode_key = torch._C._TorchDispatchModeKey.FAKE
@property
# pyrefly: ignore # bad-override
def device(self) -> torch.device:
if self.fake_mode.in_kernel_invocation:
return torch.device("meta")

@@ -706,6 +709,7 @@ class FakeTensor(Tensor):
# We don't support named tensors; graph break
@property
# pyrefly: ignore # bad-override
def names(self) -> list[str]:
raise UnsupportedFakeTensorException(
"torch.compile doesn't support named tensors"

@@ -764,6 +768,7 @@ class FakeTensor(Tensor):
)
else:
device = torch.device(f"{device.type}:0")
# pyrefly: ignore # read-only
self.fake_device = device
self.fake_mode = fake_mode
self.constant = constant

@@ -1493,6 +1498,7 @@ class FakeTensorMode(TorchDispatchMode):
# Do this dispatch outside the above except handler so if it
# generates its own exception there won't be a __context__ caused by
# the caching mechanism.
# pyrefly: ignore # bad-argument-type
return self._dispatch_impl(func, types, args, kwargs)
assert state is not None

@@ -1510,22 +1516,27 @@ class FakeTensorMode(TorchDispatchMode):
# This represents a negative cache entry - we already saw that the
# output is uncachable. Compute it from first principals.
FakeTensorMode.cache_bypasses[entry.reason] += 1
# pyrefly: ignore # bad-argument-type
return self._dispatch_impl(func, types, args, kwargs)
# We have a cache entry.
# pyrefly: ignore # bad-argument-type
output = self._output_from_cache_entry(state, entry, key, func, args)
FakeTensorMode.cache_hits += 1
if self.cache_crosscheck_enabled:
# For debugging / testing: Validate that the output synthesized
# from the cache matches the output created by normal dispatch.
with disable_fake_tensor_cache(self):
# pyrefly: ignore # bad-argument-type
self._crosscheck_cache_output(output, func, types, args, kwargs)
return output
# We don't have a cache entry.
# pyrefly: ignore # bad-argument-type
output = self._dispatch_impl(func, types, args, kwargs)
try:
# pyrefly: ignore # bad-argument-type
self._validate_cache_key(func, args, kwargs)
except _BypassDispatchCache as e:
# We ran "extra" checks on the cache key and determined that it's no

@@ -1545,6 +1556,7 @@ class FakeTensorMode(TorchDispatchMode):
return output
try:
# pyrefly: ignore # bad-argument-type
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
except _BypassDispatchCache as e:
# We had trouble making the cache entry. Record the reason and mark

@@ -1587,13 +1599,16 @@ class FakeTensorMode(TorchDispatchMode):
if state.known_symbols:
# If there are symbols then include the epoch - this is really more
# of a Shape env var which lives on the FakeTensorMode.
# pyrefly: ignore # bad-argument-type
key_values.append(self.epoch)
# Collect the id_hashed objects to attach a weakref finalize later
id_hashed_objects: list[object] = []
# Translate any FakeTensor args to metadata.
if args:
# pyrefly: ignore # bad-argument-type
self._prep_args_for_hash(key_values, args, state, id_hashed_objects)
if kwargs:
# pyrefly: ignore # bad-argument-type
self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects)
key = _DispatchCacheKey(tuple(key_values))
@@ -1909,27 +1924,53 @@ class FakeTensorMode(TorchDispatchMode):
if isinstance(output, tuple):
for out_element in output:
self._validate_output_for_cache_entry(
state, key, func, args, kwargs, out_element
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
out_element,
)
else:
self._validate_output_for_cache_entry(
state, key, func, args, kwargs, output
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
output,
)
if isinstance(output, tuple):
output_infos = [
self._get_output_info_for_cache_entry(
state, key, func, args, kwargs, out_elem
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
out_elem,
)
for out_elem in output
]
return _DispatchCacheValidEntry(
output_infos=tuple(output_infos), is_output_tuple=True
# pyrefly: ignore # bad-argument-type
output_infos=tuple(output_infos),
is_output_tuple=True,
)
else:
output_info = self._get_output_info_for_cache_entry(
state, key, func, args, kwargs, output
state,
key,
# pyrefly: ignore # bad-argument-type
func,
args,
kwargs,
output,
)
return _DispatchCacheValidEntry(
output_infos=(output_info,), is_output_tuple=False

@@ -2472,6 +2513,7 @@ class FakeTensorMode(TorchDispatchMode):
)
with self, maybe_ignore_fresh_unbacked_symbols():
# pyrefly: ignore # index-error
return registered_hop_fake_fns[func](*args, **kwargs)
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)

@@ -2625,6 +2667,7 @@ class FakeTensorMode(TorchDispatchMode):
# TODO: Is this really needed?
compute_unbacked_bindings(self.shape_env, fake_out, peek=True)
# pyrefly: ignore # bad-return
return fake_out
# Try for fastpath

@@ -2906,6 +2949,7 @@ class FakeTensorMode(TorchDispatchMode):
self, e, device or common_device
)
else:
# pyrefly: ignore # bad-return
return e
return tree_map(wrap, r)
@@ -81,6 +81,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
# pyrefly: ignore # bad-return
return t.grad

@@ -415,6 +416,7 @@ class MetaTensorDescriber:
device=t.device,
size=t.size(),
stride=stride,
# pyrefly: ignore # bad-argument-type
storage_offset=storage_offset,
dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}),

@@ -539,7 +541,11 @@ class _FakeTensorViewFunc(ViewFunc["FakeTensor"]):
tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None,
) -> FakeTensor:
return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe(
t, new_base, symint_visitor_fn, tensor_visitor_fn
# pyrefly: ignore # bad-argument-type
t,
new_base,
symint_visitor_fn,
tensor_visitor_fn,
)

@@ -1013,6 +1019,7 @@ class MetaConverter(Generic[_TensorT]):
# Morally, the code here is same as transform_subclass, but we've
# written it from scratch to read EmptyCreateSubclass
outer_size = outer_size if outer_size is not None else t.size
# pyrefly: ignore # bad-assignment
outer_stride = outer_stride if outer_stride is not None else t.stride
assert symbolic_context is None or isinstance(

@@ -1269,6 +1276,7 @@ class MetaConverter(Generic[_TensorT]):
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
# pyrefly: ignore # bad-return
return None
# NB: visited_t being a Tensor here is very naughty! Should

@@ -1399,6 +1407,7 @@ class MetaConverter(Generic[_TensorT]):
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
# pyrefly: ignore # bad-argument-type
r = self._backward_error(r)
elif t.is_nested and not t.is_traceable_wrapper_subclass:
# TODO: Handle this better in Dynamo?

@@ -1437,6 +1446,7 @@ class MetaConverter(Generic[_TensorT]):
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
# pyrefly: ignore # bad-argument-type
r = self._backward_error(r)
elif t.is_functorch_wrapped:
if t.is_view:

@@ -1533,6 +1543,7 @@ class MetaConverter(Generic[_TensorT]):
)
assert t.data is not None
_safe_copy(r.real_tensor, t.data)  # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return r
r = _to_fake_tensor(t)

@@ -1682,6 +1693,7 @@ class MetaConverter(Generic[_TensorT]):
not (t.is_batchedtensor or t.is_gradtrackingtensor)
and t.is_functorch_wrapped
) or t.is_legacy_batchedtensor:
# pyrefly: ignore # bad-return
return NotImplemented
(

@@ -1728,6 +1740,7 @@ class MetaConverter(Generic[_TensorT]):
# the metadata of the inner tensor.
# So instead, we now have a dedicated fn to set autograd history,
# without inadvertently changing other metadata.
# pyrefly: ignore # bad-argument-type
r = self._backward_error(r)
s = t.storage

@@ -1839,6 +1852,7 @@ class MetaConverter(Generic[_TensorT]):
nt_tensor_id=t.nested_int
)
# pyrefly: ignore # bad-argument-type
self.set_tensor_memo(t, r)
return self._checked_get_tensor_memo(t)

@@ -1882,11 +1896,13 @@ class MetaConverter(Generic[_TensorT]):
(t._is_view() and t._base is not None and t._base.is_sparse)
):
self.miss += 1
# pyrefly: ignore # bad-return
return NotImplemented
else:
self.hit += 1
elif torch.overrides.is_tensor_like(t):
self.miss += 1
# pyrefly: ignore # bad-return
return NotImplemented
else:
# non-Tensor types don't count as hit or miss
@@ -92,6 +92,7 @@ def _make_grads(
is_grads_batched: bool,
) -> tuple[_OptionalTensor, ...]:
new_grads: list[_OptionalTensor] = []
# pyrefly: ignore # no-matching-overload
for out, grad in zip(outputs, grads):
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_size = None

@@ -341,6 +342,7 @@ def backward(
Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
)
else:
# pyrefly: ignore # bad-argument-type
tensors = tuple(tensors)
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))

@@ -440,10 +442,12 @@ def grad(
Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
)
else:
# pyrefly: ignore # bad-argument-type
outputs = tuple(outputs)
if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
else:
# pyrefly: ignore # bad-argument-type
inputs = tuple(inputs)
t_outputs = tuple(i for i in outputs if is_tensor_like(i))
t_inputs = tuple(i for i in inputs if is_tensor_like(i))

@@ -15,12 +15,14 @@ class Type(Function):
"please use `torch.tensor.to(dtype=dtype)` instead.",
category=FutureWarning,
)
# pyrefly: ignore # bad-override
def forward(ctx, i, dest_type):
ctx.input_type = type(i)
ctx.input_device = -1 if not i.is_cuda else i.get_device()
return i.type(dest_type)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
if ctx.input_device == -1:
return grad_output.type(ctx.input_type), None

@@ -32,6 +34,7 @@ class Type(Function):
# TODO: deprecate this
class Resize(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, tensor, sizes):
ctx.sizes = sizes
ctx.numel = reduce(operator.mul, sizes, 1)

@@ -60,6 +63,7 @@ class Resize(Function):
return tensor.contiguous().view(*sizes)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
assert grad_output.numel() == ctx.numel
return grad_output.contiguous().view(ctx.input_sizes), None

@@ -9,6 +9,8 @@ from typing_extensions import deprecated
import torch
import torch.testing
# pyrefly: ignore # deprecated
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors

@@ -229,6 +229,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
# Note that output_nr default to 0 which is the right value
# for the AccumulateGrad node.
# pyrefly: ignore # bad-argument-type
return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token)

@@ -531,6 +532,7 @@ def register_multi_grad_hook(
"expected this hook to be called inside a backward call"
)
count[id] = count.get(id, 0)
# pyrefly: ignore # unsupported-operation
buffer[id] = buffer.get(id, [None] * len_tensors)
with lock:
@@ -731,6 +731,7 @@ class profile:
return all_function_events
# pyrefly: ignore # invalid-inheritance
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.

@@ -778,7 +779,9 @@ class record_function(_ContextDecorator):
# TODO: TorchScript ignores standard type annotation here
# self.record: Optional["torch.classes.profiler._RecordFunction"] = None
self.record = torch.jit.annotate(
Optional["torch.classes.profiler._RecordFunction"], None
# pyrefly: ignore # not-a-type
Optional["torch.classes.profiler._RecordFunction"],
None,
)
def __enter__(self):

@@ -101,12 +101,14 @@ class profile:
records = _disable_profiler_legacy()
parsed_results = _parse_legacy_records(records)
# pyrefly: ignore # bad-assignment
self.function_events = EventList(
parsed_results,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
with_flops=self.with_flops,
)
# pyrefly: ignore # missing-attribute
self.function_events._build_tree()
return False

@@ -48,10 +48,14 @@ class EventList(list):
def _remove_dup_nodes(self):
while True:
to_delete = set()
# pyrefly: ignore # bad-assignment
for idx in range(len(self)):
if (
# pyrefly: ignore # index-error
self[idx].cpu_parent is not None
# pyrefly: ignore # index-error
and self[idx].cpu_parent.name == self[idx].name
# pyrefly: ignore # index-error
and len(self[idx].cpu_parent.cpu_children) == 1
):
self[idx].cpu_parent.cpu_children = self[idx].cpu_children

@@ -61,8 +65,11 @@ class EventList(list):
to_delete.add(idx)
if len(to_delete) == 0:
break
# pyrefly: ignore # bad-argument-type
new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
# pyrefly: ignore # missing-attribute
self.clear()
# pyrefly: ignore # missing-attribute
self.extend(new_evts)
def _populate_cpu_children(self):

@@ -496,7 +503,9 @@ class FunctionEvent(FormattedTimesMixin):
self.id: int = id
self.node_id: int = node_id
self.name: str = name
# pyrefly: ignore # bad-assignment
self.overload_name: str = overload_name
# pyrefly: ignore # bad-assignment
self.trace_name: str = trace_name
self.time_range: Interval = Interval(start_us, end_us)
self.thread: int = thread

@@ -505,9 +514,13 @@ class FunctionEvent(FormattedTimesMixin):
self.count: int = 1
self.cpu_children: list[FunctionEvent] = []
self.cpu_parent: Optional[FunctionEvent] = None
# pyrefly: ignore # bad-assignment
self.input_shapes: tuple[int, ...] = input_shapes
# pyrefly: ignore # bad-assignment
self.concrete_inputs: list[Any] = concrete_inputs
# pyrefly: ignore # bad-assignment
self.kwinputs: dict[str, Any] = kwinputs
# pyrefly: ignore # bad-assignment
self.stack: list = stack
self.scope: int = scope
self.use_device: Optional[str] = use_device
@@ -732,6 +745,7 @@ class FunctionEventAvg(FormattedTimesMixin):
self.self_device_memory_usage += other.self_device_memory_usage
self.count += other.count
if self.flops is None:
# pyrefly: ignore # bad-assignment
self.flops = other.flops
elif other.flops is not None:
self.flops += other.flops

@@ -967,6 +981,7 @@ def _build_table(
"PFLOPs",
]
assert flops > 0
# pyrefly: ignore # no-matching-overload
log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
assert log_flops >= 0 and log_flops < len(flop_headers)
return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])

@@ -496,12 +496,14 @@ class cudaStatus:
class CudaError(RuntimeError):
def __init__(self, code: int) -> None:
# pyrefly: ignore # missing-attribute
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
super().__init__(f"{msg} ({code})")
def check_error(res: int) -> None:
r"""Raise an error if the result of a CUDA runtime API call is not success."""
# pyrefly: ignore # missing-attribute
if res != _cudart.cudaError.success:
raise CudaError(res)

@@ -601,6 +603,7 @@ def get_device_capability(device: "Device" = None) -> tuple[int, int]:
return prop.major, prop.minor
# pyrefly: ignore # not-a-type
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
r"""Get the properties of a device.

@@ -651,6 +654,7 @@ class StreamContext:
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
# pyrefly: ignore # bad-assignment
self.idx = -1
self.src_prev_stream = (

@@ -953,7 +957,9 @@ def _device_count_amdsmi() -> int:
if raw_cnt <= 0:
return raw_cnt
# Trim the list up to a maximum available device
# pyrefly: ignore # bad-argument-type
for idx, val in enumerate(visible_devices):
# pyrefly: ignore # redundant-cast
if cast(int, val) >= raw_cnt:
return idx
except OSError:

@@ -987,7 +993,9 @@ def _device_count_nvml() -> int:
if raw_cnt <= 0:
return raw_cnt
# Trim the list up to a maximum available device
# pyrefly: ignore # bad-argument-type
for idx, val in enumerate(visible_devices):
# pyrefly: ignore # redundant-cast
if cast(int, val) >= raw_cnt:
return idx
except OSError:

@@ -1203,7 +1211,9 @@ def _get_pynvml_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"pynvml does not seem to be installed or it can't be imported."
# pyrefly: ignore # invalid-inheritance
) from _PYNVML_ERR
# pyrefly: ignore # import-error
from pynvml import NVMLError_DriverNotLoaded
try:

@@ -1220,6 +1230,7 @@ def _get_amdsmi_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"amdsmi does not seem to be installed or it can't be imported."
# pyrefly: ignore # invalid-inheritance
) from _PYNVML_ERR
try:
amdsmi.amdsmi_init()

@@ -1483,6 +1494,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int
return default_generator.get_offset()
# pyrefly: ignore # deprecated
from .memory import *  # noqa: F403
from .random import *  # noqa: F403

@@ -1699,6 +1711,7 @@ def _register_triton_kernels():
def kernel_impl(*args, **kwargs):
from torch.sparse._triton_ops import bsr_dense_mm
# pyrefly: ignore # not-callable
return bsr_dense_mm(*args, skip_checks=True, **kwargs)
@_WrappedTritonKernel

@@ -279,6 +279,7 @@ class _CudaModule:
return self._kernels[name]
# Import the CUDA library inside the method
# pyrefly: ignore # missing-module-attribute
from torch.cuda._utils import _get_gpu_runtime_library
libcuda = _get_gpu_runtime_library()
@@ -1,3 +1,4 @@
# pyrefly: ignore # deprecated
from .autocast_mode import autocast, custom_bwd, custom_fwd
from .common import amp_definitely_not_available
from .grad_scaler import GradScaler

@@ -259,6 +259,7 @@ class graph:
self.cuda_graph.capture_begin(
# type: ignore[misc]
*self.pool,
# pyrefly: ignore # bad-keyword-argument
capture_error_mode=self.capture_error_mode,
)

@@ -524,6 +525,7 @@ def make_graphed_callables(
) -> Callable[..., object]:
class Graphed(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
# At this stage, only the user args may (potentially) be new tensors.
for i in range(len_user_args):

@@ -535,6 +537,7 @@ def make_graphed_callables(
@staticmethod
@torch.autograd.function.once_differentiable
# pyrefly: ignore # bad-override
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):

@@ -548,7 +551,9 @@ def make_graphed_callables(
# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return tuple(
b.detach() if b is not None else b for b in static_grad_inputs
# pyrefly: ignore # bad-argument-type
b.detach() if b is not None else b
for b in static_grad_inputs
)
def functionalized(*user_args: object) -> object:

@@ -770,6 +770,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
import pynvml  # type: ignore[import]
except ModuleNotFoundError:
return "pynvml module not found, please install pynvml"
# pyrefly: ignore # import-error
from pynvml import NVMLError_DriverNotLoaded
try:

@@ -852,6 +853,7 @@ def _record_memory_history_legacy(
_C._cuda_record_memory_history_legacy(  # type: ignore[call-arg]
enabled,
record_context,
# pyrefly: ignore # bad-argument-type
trace_alloc_max_entries,
trace_alloc_record_context,
record_context_cpp,

@@ -53,6 +53,7 @@ def range_start(msg) -> int:
Args:
msg (str): ASCII message to associate with the range.
"""
# pyrefly: ignore # missing-attribute
return _nvtx.rangeStartA(msg)

@@ -63,6 +64,7 @@ def range_end(range_id) -> None:
Args:
range_id (int): an unique handle for the start range.
"""
# pyrefly: ignore # missing-attribute
_nvtx.rangeEnd(range_id)

@@ -83,6 +85,7 @@ def _device_range_start(msg: str, stream: int = 0) -> object:
msg (str): ASCII message to associate with the range.
stream (int): CUDA stream id.
"""
# pyrefly: ignore # missing-attribute
return _nvtx.deviceRangeStart(msg, stream)

@@ -95,6 +98,7 @@ def _device_range_end(range_handle: object, stream: int = 0) -> None:
range_handle: an unique handle for the start range.
stream (int): CUDA stream id.
"""
# pyrefly: ignore # missing-attribute
_nvtx.deviceRangeEnd(range_handle, stream)
@@ -436,6 +436,7 @@ def load(
print(ep(torch.randn(5)))
"""
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
extra_files = extra_files or {}

@@ -295,6 +295,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.logger.addHandler(self)
self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
# pyrefly: ignore # bad-assignment
torch._logging._internal.GET_DTRACE_STRUCTURED = True
return self

@@ -302,6 +303,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.log_record = LogRecord()
self.expression_created_logs = {}
self.logger.removeHandler(self)
# pyrefly: ignore # bad-assignment
torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
self.prev_get_dtrace = False

@@ -107,8 +107,11 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
return
if not (
# pyrefly: ignore # missing-attribute
arg.op == "call_function"
# pyrefly: ignore # missing-attribute
and arg.target == operator.getitem
# pyrefly: ignore # missing-attribute
and arg.args[1] == i
):
log.debug(

@@ -185,6 +185,7 @@ def _ignore_backend_decomps():
def _disable_custom_triton_op_functional_decomposition():
old = torch._functorch.config.decompose_custom_triton_ops
try:
# pyrefly: ignore # bad-assignment
torch._functorch.config.decompose_custom_triton_ops = False
yield torch._functorch.config.decompose_custom_triton_ops
finally:

@@ -365,6 +366,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
return self
def __getitem__(self, idx):
# pyrefly: ignore # bad-argument-type
parts.append(str(idx))
return self

@@ -660,6 +662,7 @@ def _rename_constants_nodes(
if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
const_prefix
):
# pyrefly: ignore # bad-argument-type
if spec.arg.name.startswith(buffer_prefix):  # map from buffer to constants
c_name = rename_constant(
const_prefix + spec.arg.name[len(buffer_prefix) :]

@@ -293,6 +293,7 @@ class _ExportPackage:
if isinstance(fn, torch.nn.Module):
dynamic_shapes = v(fn, *args, **kwargs)  # type: ignore[arg-type]
else:
# pyrefly: ignore # invalid-param-spec
dynamic_shapes = v(*args, **kwargs)
except AssertionError:
continue

@@ -340,6 +341,7 @@ class _ExportPackage:
assert not hasattr(fn, "_define_overload")
_exporter_context._define_overload = _define_overload  # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return _exporter_context
@property

@@ -376,6 +378,7 @@ class _ExportPackage:
kwargs=ep.example_inputs[1],
options=options,
)
# pyrefly: ignore # unsupported-operation
aoti_files_map[name] = aoti_files
from torch._inductor.package import package
@@ -1500,6 +1500,7 @@ class ExportedProgram:
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None
# pyrefly: ignore # missing-attribute
if transformed_gm is self.graph_module and not res.modified:
return self

@@ -1578,6 +1579,7 @@ class ExportedProgram:
verifiers=self.verifiers,
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
# pyrefly: ignore # missing-attribute
transformed_ep.graph_module.meta.update(res.graph_module.meta)
return transformed_ep

@@ -81,6 +81,7 @@ def move_to_device_pass(
and node.target == torch.ops.aten.to.device
):
args = list(node.args)
# pyrefly: ignore # unsupported-operation
args[1] = _get_new_device(args[1], location)
node.args = tuple(args)

@@ -172,8 +172,10 @@ class PT2ArchiveWriter:
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
)
for file_path in file_paths:
# pyrefly: ignore # no-matching-overload
filename = os.path.relpath(file_path, folder_dir)
archive_path = os.path.join(archive_dir, filename)
# pyrefly: ignore # bad-argument-type
self.write_file(archive_path, file_path)
def close(self) -> None:

@@ -593,6 +595,7 @@ def package_pt2(
if not (
(isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
# pyrefly: ignore # no-matching-overload
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2"))
):

@@ -604,8 +607,10 @@ def package_pt2(
)
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
# pyrefly: ignore # bad-argument-type
with PT2ArchiveWriter(f) as archive_writer:
_package_exported_programs(
archive_writer, exported_programs, pickle_protocol=pickle_protocol

@@ -620,6 +625,7 @@ def package_pt2(
if isinstance(f, (io.IOBase, IO)):
f.seek(0)
# pyrefly: ignore # bad-return
return f

@@ -992,6 +998,7 @@ def load_pt2(
if not (
(isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
# pyrefly: ignore # no-matching-overload
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
):
# TODO: turn this into an error in 2.9

@@ -1002,10 +1009,12 @@ def load_pt2(
)
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
weights = {}
weight_maps = {}
# pyrefly: ignore # bad-argument-type
with PT2ArchiveReader(f) as archive_reader:
version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
if version != ARCHIVE_VERSION_VALUE:

@@ -1070,7 +1079,12 @@ def load_pt2(
else:
aoti_runners = {
model_name: _load_aoti(
f, model_name, run_single_threaded, num_runners, device_index
# pyrefly: ignore # bad-argument-type
f,
model_name,
run_single_threaded,
num_runners,
device_index,
)
for model_name in aoti_model_names
}
@@ -937,6 +937,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
]
target = node.target if node.op in ("call_function", "get_attr") else ""
# pyrefly: ignore # bad-argument-type
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
nodes_idx[id(node)] = i
return "\n".join(ret)

@@ -1473,6 +1474,7 @@ class _ModuleFrame:
self.seen_attrs[self.child_fqn].add(node.target)
self.copy_node(node)
# pyrefly: ignore # unsupported-operation
node_idx += 1

@@ -483,6 +483,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
raise IndexError(
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
)
# pyrefly: ignore # bad-argument-type
dims.append(d % ndim)
return tuple(sorted(dims))

@@ -641,6 +642,7 @@ def _sparse_coo_scatter_reduction_helper(
# promote dtype if specified
if values.dtype != output_dtype:
# pyrefly: ignore # no-matching-overload
values = values.to(output_dtype)
if keepdim:

@@ -765,6 +767,7 @@ def _sparse_csr_segment_reduction_helper(
# promote dtype if specified
if values.dtype != output_dtype:
# pyrefly: ignore # no-matching-overload
values = values.to(output_dtype)
if len(dims) == 0:

@@ -1015,6 +1018,7 @@ def _combine_input_and_mask(
class Combine(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input, mask):
"""Return input with masked-out elements eliminated for the given operations."""
ctx.save_for_backward(mask)

@@ -1025,6 +1029,7 @@ def _combine_input_and_mask(
return helper(input, mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
(mask,) = ctx.saved_tensors
grad_data = (

@@ -1399,15 +1404,18 @@ elements, have ``nan`` values.
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
# pyrefly: ignore # no-matching-overload
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=keepdim,
)
# pyrefly: ignore # no-matching-overload
total = sum(input, dim, keepdim=keepdim, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=bool(keepdim))
# pyrefly: ignore # no-matching-overload
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
return total / count
elif input.layout == torch.sparse_csr:

@@ -1618,15 +1626,18 @@ def _std_var(
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
# pyrefly: ignore # no-matching-overload
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=True,
)
# pyrefly: ignore # no-matching-overload
sample_total = sum(input, dim, keepdim=True, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=True)
# pyrefly: ignore # no-matching-overload
sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
# TODO: replace torch.subtract/divide/square/maximum with
# masked subtract/divide/square/maximum when these will be

@@ -1634,6 +1645,7 @@ def _std_var(
sample_mean = torch.divide(sample_total, count)
x = torch.subtract(input, sample_mean)
if mask is None:
# pyrefly: ignore # no-matching-overload
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(
@@ -47,6 +47,7 @@ def _check_args_kwargs_length(
class _MaskedContiguous(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")

@@ -60,12 +61,14 @@ class _MaskedContiguous(torch.autograd.Function):
return MaskedTensor(data.contiguous(), mask.contiguous())
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output
class _MaskedToDense(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")

@@ -80,6 +83,7 @@ class _MaskedToDense(torch.autograd.Function):
return MaskedTensor(data.to_dense(), mask.to_dense())
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
layout = ctx.layout

@@ -94,6 +98,7 @@ class _MaskedToDense(torch.autograd.Function):
class _MaskedToSparse(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")

@@ -110,12 +115,14 @@ class _MaskedToSparse(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedToSparseCsr(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")

@@ -136,18 +143,21 @@ class _MaskedToSparseCsr(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedWhere(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, cond, self, other):
ctx.mark_non_differentiable(cond)
ctx.save_for_backward(cond)
return torch.ops.aten.where(cond, self, other)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
(cond,) = ctx.saved_tensors

@@ -174,6 +174,7 @@ class MaskedTensor(torch.Tensor):
UserWarning,
stacklevel=2,
)
# pyrefly: ignore # bad-argument-type
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
def _preprocess_data(self, data, mask):

@@ -243,10 +244,12 @@ class MaskedTensor(torch.Tensor):
class Constructor(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, data, mask):
return MaskedTensor(data, mask)
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output, None

@@ -333,10 +336,12 @@ class MaskedTensor(torch.Tensor):
def get_data(self):
class GetData(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, self):
return self._masked_data.detach()
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
if is_masked_tensor(grad_output):
return grad_output
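A note on reading the hunks above: the scraped diff has lost its +/- markers, so several hunks show the same statement twice, once with a trailing suppression and once without, next to a standalone comment line. One plausible reading (an assumption, not confirmed by the extracted page) is that trailing suppressions are being replaced by standalone ones. The sketch below only contrasts the two placements, using a statement adapted from the test_complex hunk:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])

# Trailing placement (the form that appears paired with a plain copy of the
# same line in several hunks above):
print(torch.all(x))  # pyrefly: ignore # missing-attribute

# Standalone placement (the form most hunks in this commit add):
# pyrefly: ignore # missing-attribute
print(torch.all(x))
```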