Pyrefly suppressions 7/n (#164913)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines from the project-excludes field in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions and clean up any unused suppressions (the suppression format is sketched below)
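
For reference, each suppression is a comment placed directly above the line pyrefly flags,
naming the error code being silenced. A minimal, illustrative sketch of the pattern (loosely
based on one of the test hunks further down; not an exact excerpt):

    import torch

    x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6])
    # pyrefly: ignore # missing-attribute
    print(torch.all(x))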
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
 INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
Author: Maggie Moss
Date: 2025-10-08 07:27:14 +00:00
Committed by: PyTorch MergeBot
Parent: 12d2ef557f
Commit: c855f8632e
89 changed files with 626 additions and 67 deletions

View File

@ -22,14 +22,16 @@ project-excludes = [
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/**",
"torch/distributed/**",
"torch/nn/**",
"torch/_dynamo/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
"torch/jit/annotations.py",
"torch/utils/data/datapipes/_typing.py",
"torch/nn/functional.py",
"torch/_export/utils.py",
"torch/fx/experimental/unification/multipledispatch/__init__.py",
"torch/nn/modules/__init__.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

View File

@ -59,7 +59,6 @@ class TestBundledInputs(TestCase):
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)
@ -67,15 +66,12 @@ class TestBundledInputs(TestCase):
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
for idx, inp in enumerate(inflated):
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.
@ -144,7 +140,6 @@ class TestBundledInputs(TestCase):
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")
def test_multiple_methods_with_inputs(self):
@ -192,7 +187,6 @@ class TestBundledInputs(TestCase):
# Check running and size helpers
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
@ -426,7 +420,6 @@ class TestBundledInputs(TestCase):
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))
loaded = save_and_load(sm)

View File

@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
# pyrefly: ignore # missing-attribute
self.assertTrue(torch.all(x))
@dtypes(*complex_types())
@ -57,7 +57,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
# pyrefly: ignore # missing-attribute
self.assertFalse(torch.any(x))
@onlyCPU

View File

@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")

View File

@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.float64.to_complex(), torch.complex128)
@ -136,7 +136,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.complex128.to_real(), torch.double)

View File

@ -53,6 +53,8 @@ from .eval_frame import (
OptimizedModule,
reset_code,
)
# pyrefly: ignore # deprecated
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .pgo import reset_code_state

View File

@ -95,6 +95,7 @@ class ModIndex(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
# pyrefly: ignore # bad-override
def forward(x: Tensor, indices: list[Tensor]) -> Tensor:
return torch.ops.aten.index(x, indices)
@ -242,6 +243,7 @@ def _trace_wrapped_functionalized(ctx: Any, *args: Any, **kwargs: Any) -> Any:
def autograd_function_backward_rewritten(original_backward: Any) -> Any:
def new_backward(ctx: Any, *grads: Any) -> Any:
# pyrefly: ignore # bad-assignment
grads = [g.contiguous() for g in grads]
return original_backward(ctx, *grads)

View File

@ -89,6 +89,7 @@ class AOTCompiledFunction:
**import_sources,
self._artifacts.backend_id: self._artifacts.compiled_fn,
}
# pyrefly: ignore # read-only
self.fn = types.FunctionType(
self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
)

View File

@ -206,6 +206,7 @@ def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any])
assert manager is not None
def fn(inputs: list[Any]) -> Any:
# pyrefly: ignore # missing-attribute
manager.set_to_running_backward()
return aot_model(inputs)

View File

@ -77,16 +77,19 @@ def tvm(
opt_level = options.get("opt_level", 3)
if scheduler == "auto_scheduler":
# pyrefly: ignore # import-error
from tvm import auto_scheduler
log_file = tempfile.NamedTemporaryFile()
# pyrefly: ignore # bad-argument-type
if not os.path.exists(log_file):
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target
)
if len(tasks) != 0:
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
# pyrefly: ignore # bad-argument-type
if not os.path.exists(log_file):
assert trials > 0
tune_option = auto_scheduler.TuningOptions(
@ -97,7 +100,9 @@ def tvm(
try:
tuner.tune(tune_option)
except Exception:
# pyrefly: ignore # bad-argument-type
if os.path.exists(log_file):
# pyrefly: ignore # bad-argument-type
os.unlink(log_file)
raise
@ -107,6 +112,7 @@ def tvm(
):
lib = relay.build(mod, target=target, params=params)
elif scheduler == "meta_schedule":
# pyrefly: ignore # import-error
from tvm import meta_schedule as ms
with tempfile.TemporaryDirectory() as work_dir:

View File

@ -37,6 +37,7 @@ if sys.version_info >= (3, 11):
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
else:
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
# pyrefly: ignore # unsupported-operation
if (3, 12) <= sys.version_info < (3, 14):
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
if sys.version_info >= (3, 13):

View File

@ -903,6 +903,7 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None:
inst.arg = abs(
int(target.offset - inst.offset - instruction_size(inst))
)
# pyrefly: ignore # unsupported-operation
inst.arg //= 2
inst.argval = target.offset
inst.argrepr = f"to {target.offset}"
@ -1354,6 +1355,7 @@ def update_offsets(instructions: Sequence[Instruction]) -> None:
offset = 0
for inst in instructions:
inst.offset = offset
# pyrefly: ignore # unsupported-operation
offset += instruction_size(inst)

View File

@ -464,6 +464,7 @@ def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
try:
prof.enable()
start_ts = time.time()
# pyrefly: ignore # bad-argument-type
retval = prof.runcall(func, *args, **kwargs)
profile_latency = time.time() - start_ts
prof.disable()
@ -957,6 +958,7 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
if isinstance(mod, torch.nn.Module):
mod = mod.forward
if hasattr(mod, "__self__"):
# pyrefly: ignore # missing-attribute
return mod.__func__, mod.__self__
elif inspect.isfunction(mod):
return mod, None
@ -1096,6 +1098,7 @@ def _fullgraph_capture_frame(
while cur_exn.__cause__ is not None:
cur_exn.__cause__.with_traceback(None)
cur_exn = cur_exn.__cause__
# pyrefly: ignore # invalid-inheritance
raise e.with_traceback(None) from e.__cause__ # User compiler error
return CaptureOutput(
@ -1119,6 +1122,7 @@ def compile_frame( # type: ignore[return]
frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
distributed_state: Optional[DistributedState] = None,
package: Optional[CompilePackage] = None,
# pyrefly: ignore # bad-return
) -> DynamoOutput:
"""
A helper function taking a frame and backend, then return the generated bytecode

View File

@ -20,6 +20,7 @@ allowed to compute gradients on).
class TracableCreateParameter(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
assert not tensor.requires_grad
return placeholder.set_(tensor)

View File

@ -879,6 +879,7 @@ def aot_graph_input_parser(
data_type, shape_str = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
# pyrefly: ignore # bad-argument-type
kwargs[param] = gen_tensor(shape, dtype)
match = re.search(sym_shape_regex, annotation)
@ -892,6 +893,7 @@ def aot_graph_input_parser(
attr_name, data_type, shape_str, _ = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
# pyrefly: ignore # bad-argument-type
setattr(container, attr_name, gen_tensor(shape, dtype))
return kwargs

View File

@ -95,6 +95,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return nonrecursive_disable_wrapper
if fn is None:
@ -306,6 +307,7 @@ def forbid_in_graph(fn: Any) -> Any:
if isinstance(fn, (list, tuple)):
return [forbid_in_graph(x) for x in fn]
assert callable(fn), "forbid_in_graph applies only to callables"
# pyrefly: ignore # missing-attribute
fn._dynamo_forbidden = True
return fn
@ -653,21 +655,28 @@ def mark_dynamic(
if isinstance(index, int):
if not hasattr(t, "_dynamo_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_indices = set()
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_range = set()
# pyrefly: ignore # missing-attribute
t._dynamo_hint_overrides = {}
if not hasattr(t, "_specialize_on"):
# pyrefly: ignore # missing-attribute
t._specialize_on = {}
if hint_override:
# pyrefly: ignore # missing-attribute
t._dynamo_hint_overrides[index] = hint_override
# TODO(voz): Should we bounds check?
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_indices.add(index)
t._dynamo_dynamic_range.add(_DimRange(index, min, max)) # type: ignore[arg-type]
# FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
# TypeError: 'Attribute' object does not support item assignment
# pyrefly: ignore # missing-attribute
if isinstance(t._specialize_on, dict):
t._specialize_on[index] = specialize_on if specialize_on is not None else []
@ -692,8 +701,10 @@ def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None
if isinstance(index, int):
if not hasattr(t, "_dynamo_weak_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices = set()
# TODO(voz): Should we bounds check?
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices.add(index)
return
@ -745,8 +756,11 @@ def mark_static(
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
# pyrefly: ignore # bad-argument-type
if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
# pyrefly: ignore # missing-attribute
t._dynamo_marked_static = True
# pyrefly: ignore # bad-return
return t
if not isinstance(t, torch.Tensor):

View File

@ -205,6 +205,7 @@ class CudaInterface(DeviceInterface):
Event = torch.cuda.Event # type: ignore[assignment]
Stream = torch.cuda.Stream # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:
@ -240,6 +241,7 @@ class CudaInterface(DeviceInterface):
set_device = staticmethod(torch.cuda.set_device)
device_count = staticmethod(torch.cuda.device_count)
stream = staticmethod(torch.cuda.stream) # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.cuda.current_stream)
set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment]
@ -300,6 +302,7 @@ class MtiaInterface(DeviceInterface):
Event = torch.mtia.Event # type: ignore[assignment]
Stream = torch.mtia.Stream # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:
@ -335,6 +338,7 @@ class MtiaInterface(DeviceInterface):
set_device = staticmethod(torch.mtia.set_device) # type: ignore[assignment]
device_count = staticmethod(torch.mtia.device_count)
stream = staticmethod(torch.mtia.stream) # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.mtia.current_stream)
set_stream = staticmethod(torch.mtia.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id) # type: ignore[assignment]
@ -381,6 +385,7 @@ class XpuInterface(DeviceInterface):
Event = torch.xpu.Event # type: ignore[assignment]
Stream = torch.xpu.Stream # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:
@ -416,6 +421,7 @@ class XpuInterface(DeviceInterface):
set_device = staticmethod(torch.xpu.set_device)
device_count = staticmethod(torch.xpu.device_count) # type: ignore[has-type]
stream = staticmethod(torch.xpu.stream) # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.xpu.current_stream)
set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment]
@ -458,6 +464,7 @@ class CpuDeviceProperties:
class CpuInterface(DeviceInterface):
# pyrefly: ignore # bad-override
class Event(torch.Event):
def __init__(self, enable_timing: bool = True) -> None:
self.time = 0.0
@ -468,6 +475,7 @@ class CpuInterface(DeviceInterface):
def record(self, stream: Any = None) -> None:
self.time = time.perf_counter()
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def get_device_properties(
@ -543,6 +551,7 @@ class MpsInterface(DeviceInterface):
def synchronize(device: torch.types.Device = None) -> None:
torch.mps.synchronize()
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def get_device_properties(device: torch.types.Device = None) -> Any:

View File

@ -484,6 +484,7 @@ class OptimizedModule(torch.nn.Module):
self._initialize()
@property
# pyrefly: ignore # bad-override
def training(self) -> bool:
return self._orig_mod.training
@ -892,6 +893,7 @@ class _TorchDynamoContext:
while cur_exn.__cause__ is not None:
cur_exn.__cause__.with_traceback(None)
cur_exn = cur_exn.__cause__
# pyrefly: ignore # invalid-inheritance
raise e.with_traceback(None) from e.__cause__ # User compiler error
except ShortenTraceback as e:
# Failures in the backend likely don't have useful
@ -1020,7 +1022,10 @@ class OptimizeContext(_TorchDynamoContext):
assert rebuild_ctx is not None
compiler_fn = rebuild_ctx()
ctx = torch._dynamo.compiled_autograd._enable(
compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
compiler_fn,
# pyrefly: ignore # bad-argument-type
dynamic=_dynamic,
ignore_active_disable_ctx=False,
)
ctx.__enter__()
return functools.partial(ctx.__exit__, None, None, None)
@ -1083,6 +1088,7 @@ class DisableContext(_TorchDynamoContext):
cls_obj.__call__ = self(cls_obj.__call__)
if issubclass(cls_obj, torch.nn.Module):
# NN module variable tracker directly inlines the _call_impl. Disable it.
# pyrefly: ignore # missing-attribute
cls_obj._call_impl = self(cls_obj._call_impl)
return cls_obj
@ -1988,6 +1994,7 @@ def export(
path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
) -> Any:
if isinstance(t, torch.Tensor):
# pyrefly: ignore # missing-attribute
return ambient_fake_mode.from_tensor(t, static_shapes=True)
elif isinstance(t, _IntWrapper):
if (
@ -2068,8 +2075,11 @@ def export(
)
and not trace_rules.check(call_to_inspect)
):
# pyrefly: ignore # unbound-name
dim_constraints.solve()
# pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
# pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
original_signature,
dynamic_shapes,
@ -2090,9 +2100,11 @@ def export(
)
# Error if we have any constraints on static values
# pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
# pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

View File

@ -369,6 +369,7 @@ def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedExc
observed_exception_map[exc_type] = type( # type: ignore[assignment]
f"Observed{name}Error", (ObservedException,), {}
)
# pyrefly: ignore # index-error
return observed_exception_map[exc_type]

View File

@ -96,7 +96,9 @@ def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
args, kwargs = pytree.tree_map_only(
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
)
# pyrefly: ignore # invalid-param-spec
out = f(*args, **kwargs)
# pyrefly: ignore # missing-attribute
return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
return wrap

View File

@ -250,6 +250,7 @@ class DynamoGraphTransformer(torch.fx.Transformer):
else:
placeholder.node.meta["val"] = self.flat_inputs[i]
# pyrefly: ignore # unsupported-operation
self.new_input_nodes[i] = placeholder
def _create_placeholder_mapping(self) -> None:
@ -324,12 +325,18 @@ class DynamoGraphTransformer(torch.fx.Transformer):
# Copy module metadata like the original implementation
if hasattr(self.module, "meta"):
# pyrefly: ignore # unsupported-operation
if "dynamo_flat_name_to_original_fqn" in self.module.meta:
# pyrefly: ignore # index-error
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
# pyrefly: ignore # index-error
"dynamo_flat_name_to_original_fqn"
]
# pyrefly: ignore # unsupported-operation
if "dynamo_compile_id" in self.module.meta:
# pyrefly: ignore # index-error
result_gm.meta["dynamo_compile_id"] = self.module.meta[
# pyrefly: ignore # index-error
"dynamo_compile_id"
]
@ -361,8 +368,11 @@ def _suggest_or_raise_constraint_violation(
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
)
):
# pyrefly: ignore # unbound-name
dim_constraints.solve()
# pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
# pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
inspect.signature(orig_callable), # type: ignore[attr-defined]
dynamic_shapes,
@ -383,9 +393,11 @@ def _suggest_or_raise_constraint_violation(
)
# Error if we have any constraints on static values
# pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
# pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

View File

@ -320,6 +320,7 @@ class GraphRegionTracker:
if len(group) > 1:
region_group = []
min_rank = math.inf
# pyrefly: ignore # bad-assignment
for node in group:
# some nodes aren't in the topo ranking?
if node in topological_ranking:

View File

@ -640,6 +640,7 @@ class GuardManagerWrapper:
if isinstance(guard, RelationalGuard):
if guard not in self.printed_relational_guards:
self.printed_relational_guards.add(guard)
# pyrefly: ignore # bad-argument-type
body.writelines(self.get_guard_lines(guard))
else:
body.writelines(
@ -700,6 +701,7 @@ class GuardManagerWrapper:
for guard in mgr.get_leaf_guards():
if isinstance(guard, RelationalGuard):
if guard not in relational_guards_seen:
# pyrefly: ignore # bad-argument-type
self.code_parts.extend(get_code_parts(guard))
relational_guards_seen.add(guard)
else:
@ -716,6 +718,7 @@ def from_numpy(a: Any) -> torch.Tensor:
# Re-enable torch function since we disable it on leaf guards
# we need it to properly construct the tensor if a default device is set
with torch.overrides._enable_torch_function():
# pyrefly: ignore # missing-attribute
return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a
@ -729,6 +732,7 @@ def uninteresting_files() -> set[str]:
from torch._dynamo.polyfills.loader import POLYFILLED_MODULES
# pyrefly: ignore # bad-argument-type
mods.extend(POLYFILLED_MODULES)
return {inspect.getfile(m) for m in mods}
@ -2205,6 +2209,7 @@ class GuardBuilder(GuardBuilderBase):
return
# Python math library doesn't support complex nan, so we need to use numpy
# pyrefly: ignore # missing-attribute
if istype(val, complex) and np.isnan(val):
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
self._set_guard_export_info(guard, code)
@ -2495,6 +2500,7 @@ class GuardBuilder(GuardBuilderBase):
# sources for the corresponding tensor dimension.
return [
TensorPropertySource(source, TensorProperty.SIZE, dim)
# pyrefly: ignore # missing-attribute
for source in output_graph.tracked_fakes_id_to_source[t_id]
]
@ -2531,6 +2537,7 @@ class GuardBuilder(GuardBuilderBase):
equalities_inputs = None
def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
# pyrefly: ignore # missing-attribute
return output_graph.shape_env.produce_guards_verbose(
[a.fake for a in fs], # type: ignore[misc]
[a.source for a in fs],
@ -2538,6 +2545,7 @@ class GuardBuilder(GuardBuilderBase):
equalities_inputs=equalities_inputs,
source_ref=self.source_ref,
# Export keeps static.
# pyrefly: ignore # missing-attribute
ignore_static=(not output_graph.export),
langs=langs,
)
@ -2599,7 +2607,9 @@ class GuardBuilder(GuardBuilderBase):
if not python_fallback:
assert cpp_code_parts # type: ignore[possibly-undefined]
code_parts, source_to_symbol = (
# pyrefly: ignore # unbound-name
cpp_code_parts.exprs,
# pyrefly: ignore # unbound-name, missing-attribute
cpp_code_parts.source_to_symbol,
)
@ -2630,7 +2640,9 @@ class GuardBuilder(GuardBuilderBase):
assert cpp_code_parts # type: ignore[possibly-undefined]
code_parts, source_to_symbol = (
# pyrefly: ignore # unbound-name
cpp_code_parts.exprs,
# pyrefly: ignore # unbound-name, missing-attribute
cpp_code_parts.source_to_symbol,
)
@ -3240,6 +3252,7 @@ class GuardsStatePickler(pickle.Pickler):
assert _.__closure__ is not None
return _.__closure__[0]
# pyrefly: ignore # bad-override
def reducer_override(
self, obj: Any
) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
@ -4072,10 +4085,13 @@ class CheckFunctionManager:
and (cache_entry := self.guard_manager.cache_entry) is not None
and (extra_state := self.guard_manager.extra_state) is not None
):
# pyrefly: ignore # unbound-name
assert isinstance(cache_entry, CacheEntry)
# pyrefly: ignore # unbound-name
assert isinstance(extra_state, ExtraState)
reason = f"Cache line invalidated because {obj_str} got deallocated"
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
# pyrefly: ignore # unbound-name
extra_state.invalidate(cache_entry, deleted_guard_manager)
self.guard_manager = deleted_guard_manager

View File

@ -707,6 +707,7 @@ class OutputGraph(OutputGraphCommon):
self.backward_state_proxy: Optional[torch.fx.Proxy] = None
self.backward_state_var: Optional[str] = None
# pyrefly: ignore # bad-override
self.name_of_builtins_dict_key_in_fglobals: str = (
self.install_builtins_dict_in_fglobals()
)
@ -1146,6 +1147,7 @@ class OutputGraph(OutputGraphCommon):
vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
assert "tensor_dict" not in vt.as_proxy().node.meta
# pyrefly: ignore # bad-argument-type
vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)
return vt
@ -1157,6 +1159,7 @@ class OutputGraph(OutputGraphCommon):
install_guard(source.make_guard(GuardBuilder.NN_MODULE))
def wrap_name(module_key: str) -> VariableTracker:
# pyrefly: ignore # bad-argument-type
return NNModuleVariable(type(target), module_key, target, **options)
else:
@ -1970,7 +1973,9 @@ class OutputGraph(OutputGraphCommon):
tx = self.root_tx
assert tx is not None
if (ds := tx.distributed_state) is not None and ds.all_states is None:
# pyrefly: ignore # unbound-name
compile_pg = ds.compile_pg
# pyrefly: ignore # unbound-name
log.info("compiler_collective %s", ds.local_state)
torch._logging.trace_structured(
"artifact",
@ -1978,6 +1983,7 @@ class OutputGraph(OutputGraphCommon):
"name": "compiler_collective",
"encoding": "string",
},
# pyrefly: ignore # unbound-name
payload_fn=lambda: ds.local_state.render(),
)
device_types = compile_pg._device_types
@ -1991,7 +1997,9 @@ class OutputGraph(OutputGraphCommon):
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
):
all_states: list[Any] = [None] * compile_pg.size()
# pyrefly: ignore # unbound-name
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
# pyrefly: ignore # unbound-name
ds.all_states = all_states
# Clear speculation log, because are tracing may diverge due to
# this information from the compiler collective
@ -2321,6 +2329,7 @@ class OutputGraph(OutputGraphCommon):
},
)
# pyrefly: ignore # unbound-name
return compiled_fn
def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
@ -2375,6 +2384,7 @@ class OutputGraph(OutputGraphCommon):
isinstance(b, torch.SymBool)
and (r := b.node.maybe_as_bool()) is not None
):
# pyrefly: ignore # unbound-name
return r
# TODO: We can also technically remove all cases when the input
# doesn't have unbacked inputs, since it's all in the ShapeEnv
@ -2740,6 +2750,7 @@ def check_pt2_compliant_op(
hints=[],
)
# pyrefly: ignore # unbound-name
op = getattr(target, overload)
if torch.Tag.pt2_compliant_tag in op.tags:
encountered_compliant_op(op)
@ -2747,6 +2758,7 @@ def check_pt2_compliant_op(
encountered_non_compliant_op(
op,
f"Encountered the torch.ops.OpOverloadPacket {target} "
# pyrefly: ignore # unbound-name
f"which resolves to the overload ({overload}) that is "
f"not PT2 compliant.",
)
@ -2767,6 +2779,7 @@ class LazyProxy:
**kwargs: P.kwargs,
) -> None:
self.tracer = tracer
# pyrefly: ignore # invalid-type-var
self.fn = fn
self.args = args
self.kwargs = kwargs

View File

@ -319,6 +319,7 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
code_source = _find_code_source(toplevel)
if code_source is None:
_raise_resolution_error(code, toplevel)
# pyrefly: ignore # missing-attribute
return toplevel.__qualname__, code_source.strip(".")
@ -593,9 +594,11 @@ class CompilePackage:
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
)
# pyrefly: ignore # bad-assignment
self._source_info = dynamo.source_info
main, *codes = dynamo.codes
# pyrefly: ignore # bad-assignment
self._codes = {self._innermost_fn.__code__: main}
for code in codes:
self._codes[SerializedCode.to_code_object(code.python_code)] = code
@ -603,6 +606,7 @@ class CompilePackage:
self._add_function(
self._innermost_fn.__code__, self._innermost_fn.__module__
)
# pyrefly: ignore # bad-assignment
self._initialized = True
def _add_function(
@ -746,6 +750,7 @@ class CompilePackage:
for name in names:
module.__dict__.pop(name)
# pyrefly: ignore # bad-assignment
self._installed_globals = {}
_reset_precompile_entries(self._innermost_fn.__code__)

View File

@ -167,6 +167,7 @@ class CodeId:
@dataclasses.dataclass
class CodeState:
automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
# pyrefly: ignore # unbound-name
default_factory=lambda: defaultdict(FrameStateSizeEntry)
)
@ -851,6 +852,7 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
not _CODE_STATE
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
):
# pyrefly: ignore # unbound-name
extra_read_key = get_extra_cache_key(sticky_read)
if extra_read_key is not None:
get_extra_remote_code_state(extra_read_key)

View File

@ -196,6 +196,7 @@ def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
/,
@ -205,6 +206,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
iter2: Iterable[_T2],
@ -213,6 +215,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
iter2: Iterable[_T2],
@ -223,6 +226,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T],
iter2: Iterable[_T],
@ -233,6 +237,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T],
iter2: Iterable[_T],

View File

@ -30,10 +30,12 @@ _Us = TypeVarTuple("_Us")
@overload
# pyrefly: ignore # inconsistent-overload
def attrgetter(attr: str, /) -> Callable[[Any], _U]: ...
@overload
# pyrefly: ignore # inconsistent-overload
def attrgetter(
attr1: str, attr2: str, /, *attrs: str
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...
@ -68,10 +70,12 @@ def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:
@overload
# pyrefly: ignore # inconsistent-overload
def itemgetter(item: _T, /) -> Callable[[Any], _U]: ...
@overload
# pyrefly: ignore # inconsistent-overload
def itemgetter(
item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...

View File

@ -17,6 +17,7 @@ __all__ = ["fspath"]
@substitute_in_graph(os.fspath, can_constant_fold_through=True)
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
if isinstance(path, (str, bytes)):
# pyrefly: ignore # bad-return
return path
path_type = type(path)

View File

@ -171,6 +171,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
# pyrefly: ignore # bad-return
return treespec._unflatten_func(
treespec._metadata,
children_representations,

View File

@ -49,8 +49,11 @@ class ProfileMetrics:
if isinstance(other, int):
other = ProfileMetrics(other, other, other)
return ProfileMetrics(
# pyrefly: ignore # no-matching-overload
self.microseconds / max(1, other.microseconds),
# pyrefly: ignore # bad-argument-type
self.operators / max(1, other.operators),
# pyrefly: ignore # bad-argument-type
self.fusions / max(1, other.fusions),
)

View File

@ -370,13 +370,16 @@ isolate_fails_code_str = None
try:
if isinstance(kernel, Autotuner):
# pyrefly: ignore # missing-attribute
if isinstance(kernel.fn, Heuristics):
model_str += "ERROR: Repro will not work as intended, "
model_str += "triton.runtime.autotuner.Heuristics is not currently supported\n"
break
config_strs = []
# pyrefly: ignore # missing-attribute
for kernel_config in kernel.configs:
# pyrefly: ignore # bad-argument-type
config_strs.append(f"""triton.Config(
{str(kernel_config.kwargs)},
num_warps={kernel_config.num_warps},
@ -394,8 +397,10 @@ isolate_fails_code_str = None
""").strip()
model_str += "\n@triton.jit\n"
# pyrefly: ignore # missing-attribute
src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
fn_name = (
# pyrefly: ignore # missing-attribute
kernel._fn_name
if isinstance(kernel, JITFunction)
else kernel.fn._fn_name
@ -409,7 +414,9 @@ isolate_fails_code_str = None
model_str += "ERROR: Repro will not work as intended, "
model_str += f"User defined triton kernel exception: {e}\n"
# pyrefly: ignore # unbound-name
if len(kernel_side_table.constant_args) > 0:
# pyrefly: ignore # unbound-name
model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"
model_str += NNModuleToString.convert(gm)
@ -420,8 +427,10 @@ isolate_fails_code_str = None
# Extract from graph placeholders and their corresponding arguments
placeholder_targets = fx_placeholder_targets(gm)
for placeholder, arg in zip(placeholder_targets, args):
# pyrefly: ignore # unbound-name
if isinstance(arg, (int, torch.SymInt)):
writer.symint(placeholder, arg)
# pyrefly: ignore # unbound-name
elif isinstance(arg, torch.Tensor):
# TODO: improve these names with FQN
writer.tensor(placeholder, arg)
@ -431,16 +440,20 @@ isolate_fails_code_str = None
writer.unsupported(placeholder, arg)
# Extract symbolic variables from the same arguments
# pyrefly: ignore # unbound-name
if isinstance(arg, torch.SymInt):
sym_name = str(arg.node)
if arg.node.hint is not None:
used_syms[sym_name] = arg.node.hint
# pyrefly: ignore # unbound-name
elif isinstance(arg, torch.Tensor):
# Extract symbolic variables from tensor shapes and strides
for dim in arg.shape:
# pyrefly: ignore # unbound-name
if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
used_syms[str(dim.node)] = dim.node.hint
for stride in arg.stride():
# pyrefly: ignore # unbound-name
if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
used_syms[str(stride.node)] = stride.node.hint
@ -758,6 +771,7 @@ def repro_common(
# TODO: speed this up
mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args)
# pyrefly: ignore # bad-assignment
torch._inductor.config.generate_intermediate_hooks = True
return mod, args

View File

@ -301,6 +301,7 @@ def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]:
def repro_common(
options: Any, exported_program: ExportedProgram
) -> tuple[torch.fx.GraphModule, Any, Any]:
# pyrefly: ignore # bad-assignment
torch._inductor.config.generate_intermediate_hooks = True
mod = exported_program.module(check_guards=False)
args, kwargs = exported_program.example_inputs
@ -422,6 +423,7 @@ def repro_minify(
) -> bool:
# Need to export first so the in_spec and out_spec are populated
tuple_inputs = tuple(flat_example_inputs)
# pyrefly: ignore # bad-assignment
gm = export_for_aoti_minifier(
gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error
)

View File

@ -102,6 +102,7 @@ def _bytecode_from_template_with_split(
def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
# NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
# on torch._dynamo.utils.
# pyrefly: ignore # unknown-name
global __import_torch_dot__dynamo_dot_utils
try:
dummy
@ -555,6 +556,7 @@ class ContinueExecutionCache:
# remap original instructions' exception table entries
if old_hook_target_remap:
# pyrefly: ignore # unbound-name
assert is_py311_plus
for inst in instructions:
if (

View File

@ -696,6 +696,7 @@ class SideEffects:
cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
elif var.source is None:
# pyrefly: ignore # bad-assignment
var.source = LocalCellSource(var.local_name)
elif isinstance(var, variables.TensorVariable):
# NOTE: for historical reasons we never assigned local sources
@ -732,6 +733,7 @@ class SideEffects:
if isinstance(var, variables.UserDefinedObjectVariable):
def load_new_method() -> None:
# pyrefly: ignore # missing-attribute
assert var.base_cls_vt is not None
cg(var.base_cls_vt) # type: ignore[attr-defined]
cg.extend_output([cg.create_load_attr("__new__")])
@ -978,7 +980,9 @@ class SideEffects:
elif self.is_attribute_mutation(var):
if isinstance(
var, variables.UserDefinedDictVariable
var,
variables.UserDefinedDictVariable,
# pyrefly: ignore # bad-argument-type
) and self.is_modified(var._dict_vt):
# Do dict related update manually here. The store_attr
# mutations will be applied later.
@ -1011,6 +1015,7 @@ class SideEffects:
]
)
# pyrefly: ignore # bad-argument-type
cg(var._dict_vt, allow_cache=False) # Don't codegen via source
cg.extend_output(
[
@ -1031,7 +1036,9 @@ class SideEffects:
]
)
elif isinstance(
var, variables.UserDefinedListVariable
var,
variables.UserDefinedListVariable,
# pyrefly: ignore # bad-argument-type
) and self.is_modified(var._list_vt):
# Update the list to the updated items. Be careful in
# calling the list methods and not the overridden methods.
@ -1048,6 +1055,7 @@ class SideEffects:
]
)
# pyrefly: ignore # bad-argument-type
cg(var._list_vt, allow_cache=False) # Don't codegen via source
cg.extend_output(
[

View File

@ -563,6 +563,7 @@ def log_graph_break(
)
else:
user_stack = get_stack_above_dynamo() + user_stack # type: ignore[assignment]
# pyrefly: ignore # bad-argument-type
user_stack = collapse_resume_frames(user_stack)
user_stack_formatted = "".join(traceback.format_list(user_stack))
user_stack_trace = (
@ -1040,6 +1041,7 @@ class BytecodeDispatchTableMeta(type):
op: getattr(cls, opname, functools.partial(_missing, opname))
for opname, op in dis.opmap.items()
}
# pyrefly: ignore # missing-attribute
cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]
@ -1788,13 +1790,17 @@ class InstructionTranslatorBase(
source = self.import_source(module_name)
if self.exec_recorder:
# pyrefly: ignore # unbound-name
self.exec_recorder.add_local_mod(recorded_name, value)
# pyrefly: ignore # unbound-name
if istype(value, (types.ModuleType, DummyModule)):
# pyrefly: ignore # unbound-name
self.push(PythonModuleVariable(value, source=source))
else:
unimplemented_v2(
gb_type="Bad import result",
# pyrefly: ignore # unbound-name
context=typestr(value),
explanation="Import result is not a Python module.",
hints=[],
@ -1873,6 +1879,7 @@ class InstructionTranslatorBase(
exit, exc = self.popn(2)
assert exc is None
self.push(exc)
# pyrefly: ignore # bad-argument-type
self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))
def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
@ -2294,7 +2301,9 @@ class InstructionTranslatorBase(
):
return True
elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
exc_instance.fn, expected_type.fn
exc_instance.fn,
# pyrefly: ignore # missing-attribute
expected_type.fn,
):
return True
@ -2354,26 +2363,37 @@ class InstructionTranslatorBase(
assert isinstance(null, NullVariable)
if not isinstance(
argsvars, BaseListVariable
# pyrefly: ignore # unbound-name
argsvars,
BaseListVariable,
# pyrefly: ignore # unbound-name
) and argsvars.has_force_unpack_var_sequence(self):
# pyrefly: ignore # unbound-name
argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
# Unpack for cases like fn(**obj) where obj is a map
# pyrefly: ignore # unbound-name
if isinstance(kwargsvars, UserDefinedObjectVariable):
kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type]
# pyrefly: ignore # unbound-name
if not isinstance(argsvars, BaseListVariable) or not isinstance(
kwargsvars, ConstDictVariable
# pyrefly: ignore # unbound-name
kwargsvars,
ConstDictVariable,
):
unimplemented_v2(
gb_type="Variadic function call with bad args/kwargs type",
# pyrefly: ignore # unbound-name
context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
explanation="Expected args to be a list and kwargs to be a dict",
hints=[*graph_break_hints.USER_ERROR],
)
# Map to a dictionary of str -> VariableTracker
# pyrefly: ignore # unbound-name, missing-attribute
kwargsvars = kwargsvars.keys_as_python_constant()
# pyrefly: ignore # unbound-name, missing-attribute
self.call_function(fn, argsvars.items, kwargsvars)
@break_graph_if_unsupported(push=1)
@ -2437,6 +2457,7 @@ class InstructionTranslatorBase(
def LOAD_ATTR(self, inst: Instruction) -> None:
if sys.version_info >= (3, 12):
# pyrefly: ignore # unsupported-operation
if inst.arg % 2:
self.LOAD_METHOD(inst)
return
@ -3029,14 +3050,17 @@ class InstructionTranslatorBase(
"(i.e. `a, b, c = d`).",
hints=[*graph_break_hints.USER_ERROR],
)
# pyrefly: ignore # unbound-name
if len(val) != inst.argval:
unimplemented_v2(
gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
# pyrefly: ignore # unbound-name
context=f"expected length: {inst.argval}, actual: {len(val)}",
explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
"(i.e. `a, b, c = d`) with unexpected length.",
hints=[*graph_break_hints.DYNAMO_BUG],
)
# pyrefly: ignore # unbound-name
for i in reversed(val):
self.push(i)
@ -3409,9 +3433,13 @@ class InstructionTranslatorBase(
args = [contents[1]]
if kw_names:
# pyrefly: ignore # bad-argument-type
args = args + contents[2 : -len(kw_names)]
# pyrefly: ignore # bad-argument-type
kwargs_list = contents[-len(kw_names) :]
# pyrefly: ignore # no-matching-overload
kwargs = dict(zip(kw_names, kwargs_list))
# pyrefly: ignore # bad-argument-type
assert len(kwargs) == len(kw_names)
else:
args = args + contents[2:]
@ -4118,6 +4146,7 @@ class InstructionTranslator(InstructionTranslatorBase):
and isinstance(tos, LocalGeneratorObjectVariable)
):
self.stack[-1] = ListIteratorVariable(
# pyrefly: ignore # unbound-name
tos.force_unpack_var_sequence(self),
mutation_type=ValueMutationNew(),
)
@ -4188,6 +4217,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
"""Trace and inline a called method"""
symbolic_result: Optional[VariableTracker]
# pyrefly: ignore # bad-override
parent: InstructionTranslatorBase
@classmethod
@ -4231,6 +4261,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
# trace through.
if (
hasattr(getattr(func, "fn", None), "_origin")
# pyrefly: ignore # missing-attribute
and func.fn._origin is produce_trampoline_autograd_apply
):
# Known sound
@ -4305,12 +4336,14 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
tracing_ctx.previously_inlined_functions[code] = result
try:
# pyrefly: ignore # missing-attribute
sub_locals = func.bind_args(parent, args, kwargs)
except TypeError as e:
# Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
raise ArgsMismatchError( # noqa: B904
"{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
reason=str(e),
# pyrefly: ignore # missing-attribute
func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
args=[arg.python_type() for arg in args],
kwargs=kwargs,
@ -4394,6 +4427,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
# pyrefly: ignore # bad-argument-type
func,
)
return tracer

View File

@ -153,7 +153,9 @@ class CPythonTestCase(TestCase):
assertTupleEqual = unittest.TestCase.assertTupleEqual
assertSetEqual = unittest.TestCase.assertSetEqual
assertDictEqual = polyfills.assert_dict_equal
# pyrefly: ignore # bad-override
assertRaises = unittest.TestCase.assertRaises
# pyrefly: ignore # bad-override
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
assertWarns = unittest.TestCase.assertWarns
assertWarnsRegex = unittest.TestCase.assertWarnsRegex
@ -169,8 +171,10 @@ class CPythonTestCase(TestCase):
) -> Callable[..., Any]:
# We want to compile only the test function, excluding any setup code
# from unittest
method = getattr(self, self._testMethodName)
method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
setattr(self, self._testMethodName, method)
return fn

View File

@ -207,6 +207,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
with open(launch_file) as f:
launch_code = f.read()
self.assertTrue(os.path.exists(launch_file))
args = ["python3", launch_file, "minify", *minifier_args]
@ -218,6 +219,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
stderr = launch_proc.stderr.decode("utf-8")
print("minifier stderr:", stderr)
self.assertNotIn("Input graph did not fail the tester", stderr)
return launch_proc, launch_code
@ -230,6 +232,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
with open(repro_file) as f:
repro_code = f.read()
self.assertTrue(os.path.exists(repro_file))
repro_proc = self._maybe_subprocess_run(
@ -296,11 +299,14 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
if expected_error is None:
# Just check that there was no error
self.assertEqual(test_proc.returncode, 0)
self.assertIsNone(repro_dir)
return None
# NB: Intentionally do not test return code; we only care about
# actually generating the repro, we don't have to crash
self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
self.assertIsNotNone(repro_dir)
print("running minifier", file=sys.stderr)
_minifier_proc, minifier_code = self._run_minifier_launcher(
@ -311,6 +317,7 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
)
print("running repro", file=sys.stderr)
repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
self.assertNotEqual(repro_proc.returncode, 0)
return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)

View File

@ -496,6 +496,7 @@ def make_test_cls_with_patches(
def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 11):
return fn
# pyrefly: ignore # bad-return, bad-argument-type
return unittest.skip(fn)

View File

@ -3005,6 +3005,7 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
obj = torch_dir + k[len("torch/") :]
if obj is not None:
if is_annotate_wrapped_function(obj):
# pyrefly: ignore # missing-attribute
obj = obj.__wrapped__
if is_lru_cache_wrapped_function(obj):
obj = obj.__wrapped__

View File

@ -295,11 +295,13 @@ def increment_op_count(cnt: int) -> None:
def calculate_time_spent() -> dict[str, float]:
total_by_key = {}
for phase, timing in cumulative_time_spent_ns.items():
# pyrefly: ignore # unsupported-operation
total_by_key[phase] = timing / 1e9
total_by_key["total_wall_time"] = total_by_key.get(
"entire_frame_compile", 0
) + total_by_key.get("entire_backward_compile", 0)
# pyrefly: ignore # bad-return
return total_by_key
@ -798,6 +800,7 @@ def compile_times(repr: Literal["str"], aggregate: bool = False) -> str: ...
@overload
# pyrefly: ignore # inconsistent-overload
def compile_times(
repr: Literal["csv"], aggregate: bool = False
) -> tuple[list[str], list[object]]: ...
@ -1463,6 +1466,7 @@ class CompilationMetrics:
compile_id = all_metrics.get("compile_id")
all_metrics["compile_id"] = str(compile_id) if compile_id else None
# pyrefly: ignore # bad-argument-type
return cls(**all_metrics)
@ -2253,6 +2257,7 @@ def is_jit_model(
Union[
torch.jit._trace.TopLevelTracedModule,
torch.jit._script.RecursiveScriptModule,
# pyrefly: ignore # invalid-param-spec
torch.jit.ScriptFunction[Any, Any],
torch.jit.ScriptModule,
]
@ -2361,6 +2366,7 @@ def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]:
cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
saved_state = [
(param, param._version, torch.clone(param))
# pyrefly: ignore # bad-argument-type
for param in itertools.chain(gm.parameters(), gm.buffers())
]
@ -2626,13 +2632,16 @@ def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]:
if istype(obj, (dict, OrderedDict)):
return obj.items()
elif isinstance(obj, OrderedDict):
# pyrefly: ignore # bad-argument-type
return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)]
else:
# pyrefly: ignore # bad-argument-type
return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)]
def nn_module_new(cls: Any) -> Any:
obj = object_new(cls)
# pyrefly: ignore # bad-argument-type
torch.nn.Module.__init__(obj)
return obj
@ -2679,6 +2688,7 @@ def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any:
dict_class = dict
if isinstance(d, OrderedDict):
dict_class = OrderedDict
# pyrefly: ignore # bad-argument-type
return next(itertools.islice(dict_class.keys(d), n, n + 1))
@ -3222,8 +3232,10 @@ def format_func_info(code: CodeType) -> str:
@contextlib.contextmanager
def disable_cache_limit() -> Generator[None, None, None]:
prior = config.recompile_limit
# pyrefly: ignore # bad-assignment
config.recompile_limit = sys.maxsize
prior_acc_limit = config.accumulated_recompile_limit
# pyrefly: ignore # bad-assignment
config.accumulated_recompile_limit = sys.maxsize
try:
@ -3958,6 +3970,7 @@ class numpy_operator_wrapper(Generic[_P, R]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
assert not kwargs
# pyrefly: ignore # bad-assignment
args = (
tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
)
@ -4157,6 +4170,7 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
# (x) + (y)
# ~~^~~~~~~
while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
# pyrefly: ignore # unbound-name
if ch in "\\#":
cur_lineno, cur_col = nextline(cur_lineno, cur_col)
else:
@ -4507,6 +4521,7 @@ class GmWrapper(torch.nn.Module):
self.unflatten_fn = unflatten_fn
def forward(self, *args: Any) -> Any:
# pyrefly: ignore # annotation-mismatch
args: list[Any] = list(args)
return self.gm(*self.unflatten_fn(args))

View File

@ -1028,6 +1028,7 @@ class BuiltinVariable(VariableTracker):
def call_self_handler(tx: "InstructionTranslator", args, kwargs):
try:
# pyrefly: ignore # not-callable
result = self_handler(tx, *args, **kwargs)
if result is not None:
return result
@ -1035,6 +1036,7 @@ class BuiltinVariable(VariableTracker):
# Check if binding is bad. inspect signature bind is expensive.
# So check only when handler call fails.
try:
# pyrefly: ignore # bad-argument-type
inspect.signature(self_handler).bind(tx, *args, **kwargs)
except TypeError as e:
has_constant_handler = obj.has_constant_handler(args, kwargs)
@ -1087,6 +1089,7 @@ class BuiltinVariable(VariableTracker):
hints=[*graph_break_hints.DYNAMO_BUG],
from_exc=exc,
)
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, res)
else:
@ -1115,6 +1118,7 @@ class BuiltinVariable(VariableTracker):
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, res)
handlers.append(constant_fold_handler)
@ -1437,6 +1441,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
# pyrefly: ignore # missing-attribute
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
@ -1445,6 +1450,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in set_methods:
if isinstance(args[0], variables.UserDefinedSetVariable):
# pyrefly: ignore # missing-attribute
return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.SetVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
@ -1533,10 +1539,12 @@ class BuiltinVariable(VariableTracker):
if type(arg.value).__str__ is object.__str__:
# Rely on the object str method
try:
# pyrefly: ignore # unbound-name
return variables.ConstantVariable.create(value=str_method())
except AttributeError:
# Graph break
return
# pyrefly: ignore # unbound-name
elif is_wrapper_or_member_descriptor(str_method):
unimplemented_v2(
gb_type="Attempted to a str() method implemented in C/C++",
@ -1653,8 +1661,10 @@ class BuiltinVariable(VariableTracker):
else:
raw_b = b.raw_value
if self.fn is max:
# pyrefly: ignore # missing-attribute
raw_res = max(a.raw_value, raw_b)
else:
# pyrefly: ignore # missing-attribute
raw_res = min(a.raw_value, raw_b)
need_unwrap = any(
@ -2106,6 +2116,7 @@ class BuiltinVariable(VariableTracker):
)
if isinstance(arg, variables.UserDefinedExceptionClassVariable):
# pyrefly: ignore # unbound-name
return ConstantVariable.create(isinstance(arg_type, isinstance_type))
isinstance_type_tuple: tuple[type, ...]
@ -2138,8 +2149,10 @@ class BuiltinVariable(VariableTracker):
# through it. This is a limitation of the current implementation.
# Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
# might not be a big issue and we trade off it for performance.
# pyrefly: ignore # unbound-name
val = issubclass(arg_type, isinstance_type_tuple)
except TypeError:
# pyrefly: ignore # unbound-name
val = arg_type in isinstance_type_tuple
return variables.ConstantVariable.create(val)
@ -2161,6 +2174,7 @@ class BuiltinVariable(VariableTracker):
# WARNING: This might run arbitrary user code `__subclasscheck__`.
# See the comment in call_isinstance above.
# pyrefly: ignore # unbound-name
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))
def call_super(self, tx: "InstructionTranslator", a, b):
@ -2206,7 +2220,9 @@ class BuiltinVariable(VariableTracker):
value = getattr(self.fn, name)
except AttributeError:
raise_observed_exception(AttributeError, tx)
# pyrefly: ignore # unbound-name
if not callable(value):
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, value, source)
return variables.GetAttrVariable(self, name, source=source)

View File

@ -651,6 +651,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
def handle_use_deterministic_algorithms(
self, tx: "InstructionTranslator", mode, warn_only=False
):
# pyrefly: ignore # missing-attribute
if warn_only and warn_only.as_python_constant():
unimplemented_v2(
gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)",
@ -1035,6 +1036,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else:
raise torch._dynamo.exc.Unsupported("branch not supported")
return variables.ConstantVariable.create(
# pyrefly: ignore # bad-argument-type
torch.fx.experimental.symbolic_shapes.guard_scalar(val)
)
@ -1081,6 +1083,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
return
return variables.ConstantVariable.create(
# pyrefly: ignore # bad-argument-type
torch.fx.experimental.symbolic_shapes.has_static_value(val)
)
@ -1212,13 +1215,17 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
)
# need to guard only on no-arg get_device_module
# pyrefly: ignore # unbound-name
if device is None:
source = CallFunctionNoArgsSource(self.source)
install_guard(source.make_guard(GuardBuilder.ID_MATCH))
# assumes `module` is in the form `torch.xyz`
new_source = AttrSource(
TorchSource(), module.__name__.rsplit(".", maxsplit=1)[-1]
TorchSource(),
# pyrefly: ignore # unbound-name
module.__name__.rsplit(".", maxsplit=1)[-1],
)
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, module, new_source)
@register(torch.set_default_device)
@ -1373,9 +1380,12 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
f"{fn.__name__}_spec", f_spec
)
input_spec_proxy = tx.output.register_static_attr_and_return_proxy(
fn.__name__ + "_input_spec", input_spec
fn.__name__ + "_input_spec",
# pyrefly: ignore # unbound-name
input_spec,
)
f_spec_proxy.node.type = type(f_spec)
# pyrefly: ignore # unbound-name
input_spec_proxy.node.type = type(input_spec)
all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args)
@ -1716,6 +1726,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
)
# this results in cleaner graphs, but only works for inputs
# pyrefly: ignore # missing-attribute
if data.source:
return cls._nn_param_via_prefix_insert(tx, data, requires_grad)
@ -1734,7 +1745,9 @@ For now, dynamo will explicitly graph break when it encounters user code with th
# TODO[@lucaskabela]: Remove the behavior below since it is deprecated
if isinstance(
data, TensorWithTFOverrideVariable
data,
TensorWithTFOverrideVariable,
# pyrefly: ignore # missing-attribute
) or is_traceable_wrapper_subclass_type(data.class_type):
unimplemented_v2(
gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass",
@ -1757,8 +1770,11 @@ For now, dynamo will explicitly graph break when it encounters user code with th
)
try:
# pyrefly: ignore # missing-attribute
shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
# pyrefly: ignore # missing-attribute
dtype = data.var_getattr(tx, "dtype").as_python_constant()
# pyrefly: ignore # missing-attribute
device = data.var_getattr(tx, "device").as_python_constant()
except NotImplementedError as e:
unimplemented_v2(
@ -1773,9 +1789,13 @@ For now, dynamo will explicitly graph break when it encounters user code with th
)
placeholder = tx.output.synthetic_graph_input(
new_parameter_placeholder, [shape, dtype, device, requires_grad]
new_parameter_placeholder,
# pyrefly: ignore # unbound-name
[shape, dtype, device, requires_grad],
)
# pyrefly: ignore # missing-attribute
if data.requires_grad:
# pyrefly: ignore # missing-attribute
data = data.call_method(tx, "detach", [], {})
from .builder import wrap_fx_proxy
@ -1785,6 +1805,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
tx.output.create_proxy(
"call_function",
tracable_create_parameter,
# pyrefly: ignore # missing-attribute
(data.as_proxy(), placeholder.as_proxy()),
{},
),
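
The hunks above all use the same mechanism: a "# pyrefly: ignore # <error-code>" comment placed on the line directly above the statement the checker flags, so only that statement is exempted. A minimal sketch of the pattern, not taken from this commit, reusing the missing-attribute case that shows up in a later hunk (calling offsets() on what the checker sees as a plain Tensor):

import torch


def first_offsets(t: torch.Tensor) -> torch.Tensor:
    # offsets() only exists on nested tensors, so the checker reports
    # missing-attribute for this call; the comment silences just this line.
    # The function expects a nested (jagged) tensor at runtime.
    # pyrefly: ignore # missing-attribute
    return t.offsets()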

View File

@ -646,7 +646,6 @@ def update_schema():
assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
# pyrefly: ignore # import-error
from yaml import load, Loader
dst = load(content, Loader=Loader)

View File

@ -34,7 +34,7 @@ from torch.fx._pytree import (
_deregister_pytree_flatten_spec,
register_pytree_flatten_spec,
)
from torch.utils._pytree import ( # pyrefly: ignore # deprecated
from torch.utils._pytree import (
_deregister_pytree_node,
_register_pytree_node,
Context,

View File

@ -198,7 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
to a file. The yaml string can be loaded back into an operator profile
structure using `read_profiles_from_yaml`.
"""
# pyrefly: ignore # import-error
import yaml
from torch._export.serde.serialize import (
@ -263,7 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
"""
Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
"""
# pyrefly: ignore # import-error
import yaml
from torch._export.serde.serialize import (

View File

@ -1,5 +1,5 @@
from .core import dispatch
from .dispatcher import ( # pyrefly: ignore # deprecated
from .dispatcher import (
Dispatcher,
halt_ordering,
MDNotImplementedError,

View File

@ -153,6 +153,7 @@ class CausalBias(torch.Tensor):
diagonal=diagonal_offset,
)
# pyrefly: ignore # bad-return
def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
"""
Materializes the causal bias into a tensor form.

View File

@ -84,6 +84,7 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
# pyrefly: ignore # invalid-inheritance
class FlexKernelOptions(TypedDict, total=False):
"""Options for controlling the behavior of FlexAttention kernels.
@ -127,76 +128,93 @@ class FlexKernelOptions(TypedDict, total=False):
"""
# Performance tuning options
# pyrefly: ignore # invalid-annotation
num_warps: NotRequired[int]
"""Number of warps to use in the CUDA kernel. Higher values may improve performance
but increase register pressure. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
num_stages: NotRequired[int]
"""Number of pipeline stages in the CUDA kernel. Higher values may improve performance
but increase shared memory usage. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_M: NotRequired[int]
"""Thread block size for the sequence length dimension of Q in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_N: NotRequired[int]
"""Thread block size for the sequence length dimension of K/V in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
# Backward-specific block sizes (when prefixed with 'bwd_')
# pyrefly: ignore # invalid-annotation
BLOCK_M1: NotRequired[int]
"""Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_N1: NotRequired[int]
"""Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_M2: NotRequired[int]
"""Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_N2: NotRequired[int]
"""Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
PRESCALE_QK: NotRequired[bool]
"""Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
may have more numerical error. Default: False."""
# pyrefly: ignore # invalid-annotation
ROWS_GUARANTEED_SAFE: NotRequired[bool]
"""If True, guarantees that at least one value in each row is not masked out.
Allows skipping safety checks for better performance. Only set this if you are certain
your mask guarantees this property. For example, causal attention is guaranteed safe
because each query has at least 1 key-value to attend to. Default: False."""
# pyrefly: ignore # invalid-annotation
BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
"""If True, guarantees that all blocks in the mask are contiguous.
Allows optimizing block traversal. For example, causal masks would satisfy this,
but prefix_lm + sliding window would not. Default: False."""
# pyrefly: ignore # invalid-annotation
WRITE_DQ: NotRequired[bool]
"""Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
Setting this to False will force this to happen in the DK loop which depending on your
specific score_mod and mask_mod might be faster. Default: True."""
# pyrefly: ignore # invalid-annotation
FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
"""If True, forces the use of the flex attention kernel instead of potentially using
the more optimized flex-decoding kernel for short sequences. This can be a helpful
option for debugging. Default: False."""
# pyrefly: ignore # invalid-annotation
USE_TMA: NotRequired[bool]
"""Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
This is experimental and may not work on all hardware, currently specific
to NVIDIA GPUs Hopper+. Default: False."""
# ROCm-specific options
# pyrefly: ignore # invalid-annotation
kpack: NotRequired[int]
"""ROCm-specific kernel packing parameter."""
# pyrefly: ignore # invalid-annotation
matrix_instr_nonkdim: NotRequired[int]
"""ROCm-specific matrix instruction non-K dimension."""
# pyrefly: ignore # invalid-annotation
waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit."""
@ -581,6 +599,7 @@ class BlockMask:
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
seq_lengths = (self.seq_lengths,) # type: ignore[assignment]
# pyrefly: ignore # not-iterable
return (
*seq_lengths,
self.kv_num_blocks,
@ -753,6 +772,7 @@ class BlockMask:
partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
if self.full_kv_num_blocks is not None:
assert self.full_kv_indices is not None
# pyrefly: ignore # bad-return
return partial_dense | _ordered_to_dense(
self.full_kv_num_blocks, self.full_kv_indices
)
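
In the FlexKernelOptions hunk above, the checker flags both the TypedDict inheritance and every NotRequired[...] field annotation, so the class carries one invalid-inheritance suppression and each field carries its own invalid-annotation suppression. A minimal sketch of that shape (TuningOptions is a hypothetical class, and NotRequired/TypedDict are assumed to be importable from typing_extensions):

from typing_extensions import NotRequired, TypedDict


# pyrefly: ignore # invalid-inheritance
class TuningOptions(TypedDict, total=False):
    # pyrefly: ignore # invalid-annotation
    num_warps: NotRequired[int]
    """Number of warps to launch; omit the key to let autotuning decide."""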

View File

@ -78,6 +78,7 @@ class ModuleWrapper(nn.Module):
# nn.Module defines training as a boolean
@property # type: ignore[override]
# pyrefly: ignore # bad-override
def training(self):
return self.cpp_module.training

View File

@ -1266,6 +1266,7 @@ def adaptive_max_pool2d_with_indices(
output_size,
return_indices=return_indices,
)
# pyrefly: ignore # bad-argument-type
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool2d(input, output_size)
@ -1323,6 +1324,7 @@ def adaptive_max_pool3d_with_indices(
output_size,
return_indices=return_indices,
)
# pyrefly: ignore # bad-argument-type
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool3d(input, output_size)
@ -1381,6 +1383,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
@ -1396,6 +1399,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
@ -2431,6 +2435,7 @@ def _no_grad_embedding_renorm_(
input: Tensor,
max_norm: float,
norm_type: float,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]:
torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
@ -2684,6 +2689,7 @@ def embedding_bag(
if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
include_last_offset = True
# pyrefly: ignore # missing-attribute
offsets = input.offsets()
input = input.values().reshape(-1)
if per_sample_weights is not None:
@ -2818,6 +2824,7 @@ def batch_norm(
eps=eps,
)
if training:
# pyrefly: ignore # bad-argument-type
_verify_batch_size(input.size())
return torch.batch_norm(
@ -2873,6 +2880,7 @@ def instance_norm(
eps=eps,
)
if use_input_stats:
# pyrefly: ignore # bad-argument-type
_verify_spatial_size(input.size())
return torch.instance_norm(
input,
@ -2998,11 +3006,13 @@ def local_response_norm(
div = input.mul(input)
if dim == 3:
div = div.unsqueeze(1)
# pyrefly: ignore # bad-argument-type
div = pad(div, (0, 0, size // 2, (size - 1) // 2))
div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
else:
sizes = input.size()
div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
# pyrefly: ignore # bad-argument-type
div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
div = div.view(sizes)
@ -3151,7 +3161,12 @@ def nll_loss(
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)
return torch._C._nn.nll_loss_nd(
input, target, weight, _Reduction.get_enum(reduction), ignore_index
input,
target,
weight,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
ignore_index,
)
@ -3296,6 +3311,7 @@ def gaussian_nll_loss(
var.clamp_(min=eps)
# Calculate the loss
# pyrefly: ignore # unsupported-operation
loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
if full:
loss += 0.5 * math.log(2 * math.pi)
@ -3471,6 +3487,7 @@ def cross_entropy(
input,
target,
weight,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
ignore_index,
label_smoothing,
@ -3535,6 +3552,7 @@ def binary_cross_entropy(
new_size = _infer_size(target.size(), weight.size())
weight = weight.expand(new_size)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
@ -3663,11 +3681,18 @@ def smooth_l1_loss(
if beta == 0.0:
return torch._C._nn.l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction)
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
)
else:
return torch._C._nn.smooth_l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction), beta
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
beta,
)
@ -3725,7 +3750,11 @@ def huber_loss(
if weight is None:
# Use the optimized C++ backend for standard Huber loss
return torch._C._nn.huber_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
delta,
)
else:
if weight.size() != input.size():
@ -3733,7 +3762,11 @@ def huber_loss(
# Calculate the unweighted loss first
unweighted_loss = torch._C._nn.huber_loss(
expanded_input, expanded_target, _Reduction.get_enum("none"), delta
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum("none"),
delta,
)
# Apply weight to the unweighted loss
@ -3820,7 +3853,10 @@ def l1_loss(
)
else:
return torch._C._nn.l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction)
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
)
@ -3895,7 +3931,10 @@ def mse_loss(
)
else:
return torch._C._nn.mse_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction)
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
)
@ -4032,6 +4071,7 @@ def multilabel_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
@ -4073,6 +4113,7 @@ def soft_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
@ -4237,7 +4278,13 @@ def multi_margin_loss(
raise ValueError("weight must be one-dimensional")
return torch._C._nn.multi_margin_loss(
input, target, p, margin, weight, reduction_enum
input,
target,
p,
margin,
weight,
# pyrefly: ignore # bad-argument-type
reduction_enum,
)
@ -4383,6 +4430,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950
pass
@ -4394,6 +4442,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950
pass
@ -4496,6 +4545,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950
pass
@ -4509,6 +4559,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950
pass
@ -4522,6 +4573,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950
pass
@ -4535,6 +4587,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor:
pass
@ -4709,6 +4762,7 @@ def interpolate( # noqa: F811
(
torch.floor(
(
# pyrefly: ignore # missing-attribute
input.size(i + 2).float()
* torch.tensor(scale_factors[i], dtype=torch.float32)
).float()
@ -4733,21 +4787,28 @@ def interpolate( # noqa: F811
)
if input.dim() == 3 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
if input.dim() == 3 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
if input.dim() == 3 and mode == "area":
assert output_size is not None
# pyrefly: ignore # bad-argument-type
return adaptive_avg_pool1d(input, output_size)
if input.dim() == 4 and mode == "area":
assert output_size is not None
@ -4759,13 +4820,21 @@ def interpolate( # noqa: F811
if input.dim() == 3 and mode == "linear":
assert align_corners is not None
return torch._C._nn.upsample_linear1d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
if input.dim() == 4 and mode == "bilinear":
assert align_corners is not None
if antialias:
return torch._C._nn._upsample_bilinear2d_aa(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
# Two levels are necessary to prevent TorchScript from touching
# are_deterministic_algorithms_enabled.
@ -4778,7 +4847,11 @@ def interpolate( # noqa: F811
"torch._decomp.decompositions"
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_bilinear2d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
if input.dim() == 5 and mode == "trilinear":
assert align_corners is not None
@ -4793,16 +4866,28 @@ def interpolate( # noqa: F811
"torch._decomp.decompositions"
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_trilinear3d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
if input.dim() == 4 and mode == "bicubic":
assert align_corners is not None
if antialias:
return torch._C._nn._upsample_bicubic2d_aa(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
return torch._C._nn.upsample_bicubic2d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
if input.dim() == 3 and mode == "bilinear":
@ -4834,6 +4919,7 @@ def upsample_nearest( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass
@ -4843,6 +4929,7 @@ def upsample_nearest( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass
@ -4884,6 +4971,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass
@ -4893,6 +4981,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass
@ -4902,6 +4991,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass
@ -4911,6 +5001,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass
@ -5717,6 +5808,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
# pyrefly: ignore # bad-return
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention
@ -5735,6 +5827,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
# pyrefly: ignore # bad-return
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)
@ -5742,6 +5835,7 @@ def _in_projection_packed(
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
# pyrefly: ignore # bad-return
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
@ -6372,8 +6466,10 @@ def multi_head_attention_forward(
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
@ -6382,8 +6478,10 @@ def multi_head_attention_forward(
#
# reshape q, k, v for multihead attention and make them batch first
#
# pyrefly: ignore # no-matching-overload
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
# pyrefly: ignore # no-matching-overload
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
@ -6395,6 +6493,7 @@ def multi_head_attention_forward(
)
k = static_k
if static_v is None:
# pyrefly: ignore # no-matching-overload
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
@ -6410,14 +6509,20 @@ def multi_head_attention_forward(
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
# pyrefly: ignore # no-matching-overload
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
dim=1,
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
# pyrefly: ignore # no-matching-overload
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
dim=1,
)
if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
@ -6467,6 +6572,7 @@ def multi_head_attention_forward(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = (
# pyrefly: ignore # no-matching-overload
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
@ -6493,13 +6599,16 @@ def multi_head_attention_forward(
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
# pyrefly: ignore # no-matching-overload
k = k.view(bsz, num_heads, src_len, head_dim)
# pyrefly: ignore # no-matching-overload
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
# pyrefly: ignore # no-matching-overload
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
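
Several call sites in the hunk above were reflowed from one line into one argument per line so that the flagged argument (typically _Reduction.get_enum(reduction)) sits alone beneath its bad-argument-type suppression and the remaining arguments stay unsuppressed. A small self-contained sketch of the same idea, using toy functions that are not part of the commit:

def scale(values: list[float], factor: float, bias: float) -> list[float]:
    return [v * factor + bias for v in values]


def demo(raw: object) -> list[float]:
    # raw is statically typed as object, so the checker reports a type
    # mismatch for the bias argument; isolating it on its own line keeps
    # the suppression from covering the other arguments.
    return scale(
        [1.0, 2.0, 3.0],
        2.0,
        # pyrefly: ignore # bad-argument-type
        raw,  # assumed to hold a float at runtime
    )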

View File

@ -500,6 +500,7 @@ def xavier_normal_(
def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
# pyrefly: ignore # bad-assignment
mode = mode.lower()
valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes:

View File

@ -6,6 +6,7 @@ from torch.autograd.function import Function
class SyncBatchNorm(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
self,
input,
@ -210,6 +211,7 @@ class SyncBatchNorm(Function):
class CrossMapLRN2d(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
ctx.size = size
ctx.alpha = alpha
@ -265,6 +267,7 @@ class CrossMapLRN2d(Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = grad_output.new()
@ -306,6 +309,7 @@ class CrossMapLRN2d(Function):
class BackwardHookFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, *args):
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
return args

View File

@ -72,6 +72,7 @@ class _NormBase(Module):
torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)
@ -221,6 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
# affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten.
@ -234,22 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
# pyrefly: ignore # bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
if self.track_running_stats:
# pyrefly: ignore # bad-argument-type
self.running_mean = UninitializedBuffer(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
)
def reset_parameters(self) -> None:
# pyrefly: ignore # bad-argument-type
if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters()
def initialize_parameters(self, input) -> None: # type: ignore[override]
# pyrefly: ignore # bad-argument-type
if self.has_uninitialized_params():
self.num_features = input.shape[1]
if self.affine:

View File

@ -109,6 +109,7 @@ class Sequential(Module):
def __init__(self, *args: Module) -> None: ...
@overload
# pyrefly: ignore # inconsistent-overload
def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
def __init__(self, *args):
@ -472,6 +473,7 @@ class ModuleList(Module):
return self
def pop(self, key: Union[int, slice]) -> Module:
# pyrefly: ignore # index-error
v = self[key]
del self[key]
return v
@ -623,9 +625,11 @@ class ModuleDict(Module):
"ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(m).__name__
)
# pyrefly: ignore # bad-argument-type
if not len(m) == 2:
raise ValueError(
"ModuleDict update sequence element "
# pyrefly: ignore # bad-argument-type
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
)
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
@ -684,6 +688,7 @@ class ParameterList(Module):
def __getitem__(self, idx: int) -> Any: ...
@overload
# pyrefly: ignore # inconsistent-overload
def __getitem__(self: T, idx: slice) -> T: ...
def __getitem__(self, idx):
@ -769,9 +774,11 @@ class ParameterList(Module):
size_str,
device_str,
)
# pyrefly: ignore # bad-argument-type
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
# pyrefly: ignore # bad-argument-type
" (" + str(k) + "): Object of type: " + type(p).__name__
)
@ -979,9 +986,11 @@ class ParameterDict(Module):
"ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(p).__name__
)
# pyrefly: ignore # bad-argument-type
if not len(p) == 2:
raise ValueError(
"ParameterDict update sequence element "
# pyrefly: ignore # bad-argument-type
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
)
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
@ -1002,9 +1011,11 @@ class ParameterDict(Module):
size_str,
device_str,
)
# pyrefly: ignore # bad-argument-type
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
# pyrefly: ignore # bad-argument-type
" (" + str(k) + "): Object of type: " + type(p).__name__
)
tmpstr = "\n".join(child_lines)

View File

@ -363,6 +363,7 @@ class Conv1d(_ConvNd):
self.dilation,
self.groups,
)
# pyrefly: ignore # no-matching-overload
return F.conv1d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)
@ -540,6 +541,7 @@ class Conv2d(_ConvNd):
self.dilation,
self.groups,
)
# pyrefly: ignore # no-matching-overload
return F.conv2d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)
@ -709,6 +711,7 @@ class Conv3d(_ConvNd):
self.dilation,
self.groups,
)
# pyrefly: ignore # no-matching-overload
return F.conv3d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)
@ -1494,6 +1497,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,
@ -1508,9 +1512,11 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1563,6 +1569,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,
@ -1577,9 +1584,11 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1633,6 +1642,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,
@ -1647,9 +1657,11 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1701,6 +1713,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,
@ -1716,9 +1729,11 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1770,6 +1785,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,
@ -1785,9 +1801,11 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1839,6 +1857,7 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,
@ -1854,9 +1873,11 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:

View File

@ -172,7 +172,9 @@ class LazyModuleMixin:
def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs) # type: ignore[misc]
# pyrefly: ignore # read-only
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
# pyrefly: ignore # read-only
self._initialize_hook = self.register_forward_pre_hook(
self._infer_parameters, with_kwargs=True
)

View File

@ -286,6 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""
cls_to_become = Linear # type: ignore[assignment]
# pyrefly: ignore # bad-override
weight: UninitializedParameter
bias: UninitializedParameter # type: ignore[assignment]
@ -295,16 +296,20 @@ class LazyLinear(LazyModuleMixin, Linear):
factory_kwargs = {"device": device, "dtype": dtype}
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
# pyrefly: ignore # bad-argument-type
super().__init__(0, 0, False)
# pyrefly: ignore # bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_features = out_features
if bias:
# pyrefly: ignore # bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
def reset_parameters(self) -> None:
"""
Resets parameters based on their initialization used in ``__init__``.
"""
# pyrefly: ignore # bad-argument-type
if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters()
@ -312,6 +317,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""
Infers ``in_features`` based on ``input`` and initializes parameters.
"""
# pyrefly: ignore # bad-argument-type
if self.has_uninitialized_params():
with torch.no_grad():
self.in_features = input.shape[-1]

View File

@ -38,11 +38,13 @@ T = TypeVar("T", bound="Module")
class _IncompatibleKeys(
# pyrefly: ignore # invalid-inheritance
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
):
__slots__ = ()
def __repr__(self) -> str:
# pyrefly: ignore # missing-attribute
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()
@ -91,6 +93,7 @@ class _WrappedHook:
def __getstate__(self) -> dict:
result = {"hook": self.hook, "with_module": self.with_module}
if self.with_module:
# pyrefly: ignore # unsupported-operation
result["module"] = self.module()
return result
@ -976,7 +979,9 @@ class Module:
# Decrement use count of the gradient by setting to None
param.grad = None
param_applied = torch.nn.Parameter(
param_applied, requires_grad=param.requires_grad
# pyrefly: ignore # bad-argument-type
param_applied,
requires_grad=param.requires_grad,
)
torch.utils.swap_tensors(param, param_applied)
except Exception as e:
@ -987,11 +992,13 @@ class Module:
) from e
out_param = param
elif p_should_use_set_data:
# pyrefly: ignore # bad-assignment
param.data = param_applied
out_param = param
else:
assert isinstance(param, Parameter)
assert param.is_leaf
# pyrefly: ignore # bad-argument-type
out_param = Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param
@ -2253,6 +2260,7 @@ class Module:
if destination is None:
destination = OrderedDict()
# pyrefly: ignore # missing-attribute
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
@ -2402,7 +2410,9 @@ class Module:
if k not in self._non_persistent_buffers_set
}
local_name_params = itertools.chain(
self._parameters.items(), persistent_buffers.items()
self._parameters.items(),
# pyrefly: ignore # bad-argument-type
persistent_buffers.items(),
)
local_state = {k: v for k, v in local_name_params if v is not None}
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)

View File

@ -84,6 +84,7 @@ class CircularPad1d(_CircularPadNd):
[5., 6., 7., 4., 5., 6., 7., 4.]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
@ -144,6 +145,7 @@ class CircularPad2d(_CircularPadNd):
[8., 6., 7., 8., 6.]]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
@ -194,6 +196,7 @@ class CircularPad3d(_CircularPadNd):
>>> output = m(input)
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
@ -265,6 +268,7 @@ class ConstantPad1d(_ConstantPadNd):
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int]
def __init__(self, padding: _size_2_t, value: float) -> None:
@ -316,6 +320,7 @@ class ConstantPad2d(_ConstantPadNd):
"""
__constants__ = ["padding", "value"]
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t, value: float) -> None:
@ -356,6 +361,7 @@ class ConstantPad3d(_ConstantPadNd):
>>> output = m(input)
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t, value: float) -> None:
@ -409,6 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd):
[7., 6., 5., 4., 5., 6., 7., 6.]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
@ -462,6 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd):
[7., 6., 7., 8., 7.]]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
@ -517,6 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd):
[1., 0., 1., 0.]]]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
@ -570,6 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd):
[4., 4., 4., 4., 5., 6., 7., 7.]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
@ -623,6 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd):
[6., 6., 7., 8., 8.]]]])
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
@ -665,6 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd):
>>> output = m(input)
"""
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:

View File

@ -111,6 +111,7 @@ class RNNBase(Module):
if (
not isinstance(dropout, numbers.Number)
# pyrefly: ignore # unsupported-operation
or not 0 <= dropout <= 1
or isinstance(dropout, bool)
):
@ -119,6 +120,7 @@ class RNNBase(Module):
"representing the probability of an element being "
"zeroed"
)
# pyrefly: ignore # unsupported-operation
if dropout > 0 and num_layers == 1:
warnings.warn(
"dropout option adds dropout after all but last "
@ -639,15 +641,22 @@ class RNN(RNNBase):
@overload
@torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward(
self, input: Tensor, hx: Optional[Tensor] = None
self,
input: Tensor,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]:
pass
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None
self,
input: PackedSequence,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, Tensor]:
pass
@ -772,7 +781,11 @@ class RNN(RNNBase):
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
@ -996,6 +1009,7 @@ class LSTM(RNNBase):
# In the future, we should prevent mypy from applying contravariance rules here.
# See torch/nn/modules/module.py::_forward_unimplemented
# pyrefly: ignore # bad-override
def check_forward_args(
self,
input: Tensor,
@ -1029,8 +1043,12 @@ class LSTM(RNNBase):
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
@overload # type: ignore[override]
@torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward(
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
self,
input: Tensor,
hx: Optional[tuple[Tensor, Tensor]] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
pass
@ -1038,7 +1056,10 @@ class LSTM(RNNBase):
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None
self,
input: PackedSequence,
hx: Optional[tuple[Tensor, Tensor]] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
pass
@ -1152,7 +1173,11 @@ class LSTM(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
@ -1318,15 +1343,22 @@ class GRU(RNNBase):
@overload # type: ignore[override]
@torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward(
self, input: Tensor, hx: Optional[Tensor] = None
self,
input: Tensor,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]: # noqa: F811
pass
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None
self,
input: PackedSequence,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, Tensor]: # noqa: F811
pass
@ -1420,7 +1452,11 @@ class GRU(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:

View File

@ -135,7 +135,11 @@ class Transformer(Module):
**factory_kwargs,
)
encoder_norm = LayerNorm(
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
)
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
@ -157,7 +161,11 @@ class Transformer(Module):
**factory_kwargs,
)
decoder_norm = LayerNorm(
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
)
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, decoder_norm
@ -760,7 +768,9 @@ class TransformerEncoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
@ -1052,8 +1062,11 @@ class TransformerDecoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)

View File

@ -36,6 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
import torch
if isinstance(out_size, (int, torch.SymInt)):
# pyrefly: ignore # bad-return
return out_size
if len(defaults) <= len(out_size):
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")

View File

@ -43,6 +43,7 @@ def broadcast(tensor, devices=None, *, out=None):
devices = [_get_device_index(d) for d in devices]
return torch._C._broadcast(tensor, devices)
else:
# pyrefly: ignore # bad-argument-type
return torch._C._broadcast_out(tensor, out)
@ -200,6 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=
"""
tensor = _handle_complex(tensor)
if out is None:
# pyrefly: ignore # not-iterable
devices = [_get_device_index(d) for d in devices]
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
else:

View File

@ -160,6 +160,7 @@ class DataParallel(Module, Generic[T]):
self.module = module
self.device_ids = [_get_device_index(x, True) for x in device_ids]
self.output_device = _get_device_index(output_device, True)
# pyrefly: ignore # read-only
self.src_device_obj = torch.device(device_type, self.device_ids[0])
if device_type == "cuda":
@ -173,6 +174,7 @@ class DataParallel(Module, Generic[T]):
if not self.device_ids:
return self.module(*inputs, **kwargs)
# pyrefly: ignore # bad-argument-type
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
@ -259,8 +261,10 @@ def data_parallel(
device_ids = [_get_device_index(x, True) for x in device_ids]
output_device = _get_device_index(output_device, True)
# pyrefly: ignore # no-matching-overload
src_device_obj = torch.device(device_type, device_ids[0])
# pyrefly: ignore # bad-argument-type
for t in chain(module.parameters(), module.buffers()):
if t.device != src_device_obj:
raise RuntimeError(

View File

@ -241,6 +241,7 @@ class _BufferCommHook:
# is completed.
class _DDPSink(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, ddp_weakref, *inputs):
# set_materialize_grads(False) will ensure that None gradients stay as
# None and are not filled with zeros.
@ -691,6 +692,7 @@ class DistributedDataParallel(Module, Joinable):
elif process_group is None and device_mesh is None:
self.process_group = _get_default_group()
elif device_mesh is None:
# pyrefly: ignore # bad-assignment
self.process_group = process_group
else:
if device_mesh.ndim != 1:
@ -779,11 +781,13 @@ class DistributedDataParallel(Module, Joinable):
self.device_ids = None
self.output_device = None
else:
# pyrefly: ignore # bad-assignment
self.device_ids = [_get_device_index(x, True) for x in device_ids]
if output_device is None:
output_device = device_ids[0]
# pyrefly: ignore # bad-assignment
self.output_device = _get_device_index(output_device, True)
self.static_graph = False
@ -933,6 +937,7 @@ class DistributedDataParallel(Module, Joinable):
# enabled.
self._accum_grad_hooks: list[RemovableHandle] = []
if self._use_python_reducer:
# pyrefly: ignore # bad-assignment
torch._inductor.config._fuse_ddp_communication = True
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
# Directly adding this to the trace rule will disturb the users

View File

@ -56,12 +56,16 @@ def scatter(inputs, target_gpus, dim=0):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
# pyrefly: ignore # no-matching-overload
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [list(i) for i in zip(*map(scatter_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [obj for _ in target_gpus]
@ -123,9 +127,12 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
if isinstance(out, dict):
if not all(len(out) == len(d) for d in outputs):
raise ValueError("All dicts must have the same number of keys")
# pyrefly: ignore # not-callable
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
if _is_namedtuple(out):
# pyrefly: ignore # no-matching-overload
return type(out)._make(map(gather_map, zip(*outputs)))
# pyrefly: ignore # no-matching-overload
return type(out)(map(gather_map, zip(*outputs)))
# Recursive function calls like this create reference cycles.

View File

@ -81,6 +81,7 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
memo[id(self)] = result
return result
# pyrefly: ignore # bad-override
def __repr__(self):
return "Parameter containing:\n" + super().__repr__()
@ -143,6 +144,7 @@ class UninitializedTensorMixin:
if dtype is None:
dtype = self.data.dtype
self.data = torch.empty(shape, device=device, dtype=dtype)
# pyrefly: ignore # bad-override, missing-attribute
self.__class__ = self.cls_to_become
@property
@ -166,6 +168,7 @@ class UninitializedTensorMixin:
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
# pyrefly: ignore # missing-attribute
return (self.__class__, (self.requires_grad,))
@classmethod
@ -175,6 +178,7 @@ class UninitializedTensorMixin:
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
if kwargs is None:
kwargs = {}
# pyrefly: ignore # missing-attribute
return super().__torch_function__(func, types, args, kwargs)
raise ValueError(
f"Attempted to use an uninitialized parameter in {func}. "
@ -216,6 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
# pyrefly: ignore # bad-return
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __deepcopy__(self, memo):
@ -261,7 +266,9 @@ class Buffer(torch.Tensor, metaclass=_BufferMeta):
data = torch.empty(0)
t = data.detach().requires_grad_(data.requires_grad)
# pyrefly: ignore # missing-attribute
t.persistent = persistent
# pyrefly: ignore # missing-attribute
t._is_buffer = True
return t
@ -292,6 +299,9 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
# pyrefly: ignore # missing-attribute
ret.persistent = persistent
# pyrefly: ignore # missing-attribute
ret._is_buffer = True
# pyrefly: ignore # bad-return
return ret

View File

@ -1,5 +1,5 @@
from . import parametrizations, parametrize, rnn, stateless
from .clip_grad import (
from .clip_grad import ( # pyrefly: ignore # deprecated
_clip_grads_with_norm_ as clip_grads_with_norm_,
_get_total_norm as get_total_norm,
clip_grad_norm,

View File

@ -24,6 +24,7 @@ from .expanded_weights_utils import forward_helper
@implements_per_sample_grads(F.conv3d)
class ConvPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any,
kwarg_names: list[str],
@ -56,6 +57,7 @@ class ConvPerSampleGrad(torch.autograd.Function):
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
)
# pyrefly: ignore # invalid-type-var
ctx.conv_fn = conv_fn
ctx.batch_size = orig_input.shape[0]

View File

@ -237,6 +237,7 @@ def conv_unfold_weight_grad_sample(
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
# rearrange the above tensor and extract diagonals.
# pyrefly: ignore # no-matching-overload
weight_grad_sample = weight_grad_sample.view(
n,
groups,

View File

@ -14,6 +14,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
) -> torch.Tensor:
@ -34,6 +35,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:

View File

@ -131,7 +131,9 @@ class ExpandedWeight(torch.Tensor):
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = isinstance(
args[2], list
# pyrefly: ignore # index-error
args[2],
list,
) # data variant uses a list here
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]

View File

@ -8,6 +8,7 @@ from .expanded_weights_impl import ExpandedWeight
def is_batch_first(expanded_args_and_kwargs):
batch_first = None
# pyrefly: ignore # bad-assignment
for arg in expanded_args_and_kwargs:
if not isinstance(arg, ExpandedWeight):
continue

View File

@ -18,6 +18,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@ -46,6 +47,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, num_groups = ctx.input, ctx.num_groups
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
@ -94,7 +96,9 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
set_grad_sample_if_exists(
weight,
lambda _: torch.einsum(
"ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output
"ni...->ni",
# pyrefly: ignore # unsupported-operation
F.group_norm(input, num_groups, eps=eps) * grad_output,
),
)
if hasattr(ctx, "bias"):

View File

@ -17,6 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.instance_norm)
class InstanceNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
expanded_args, expanded_kwargs = standard_kwargs(
@ -36,6 +37,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps

View File

@ -17,6 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.layer_norm)
class LayerNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@ -42,6 +43,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
def weight_per_sample_grad(weight):
return sum_over_all_but_batch_and_last_n(

View File

@ -16,6 +16,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, _, __, *expanded_args_and_kwargs):
if len(expanded_args_and_kwargs[0].shape) <= 1:
raise RuntimeError(
@ -35,6 +36,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, weight = ctx.args
bias = ctx.kwargs["bias"]

View File

@ -77,6 +77,7 @@ def swap_tensor(
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
# pyrefly: ignore # bad-return
return orig_tensor

View File

@ -41,9 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
def _no_grad_wrapper(*args, **kwargs):
with torch.no_grad():
# pyrefly: ignore # invalid-param-spec
return func(*args, **kwargs)
functools.update_wrapper(_no_grad_wrapper, func)
# pyrefly: ignore # bad-return
return _no_grad_wrapper

View File

@ -84,6 +84,7 @@ def convert_conv2d_weight_memory_format(
)
for child in module.children():
convert_conv2d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
return module
@ -163,6 +164,7 @@ def convert_conv3d_weight_memory_format(
)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
return module

View File

@ -98,6 +98,7 @@ class _Orthogonal(Module):
)
# Q is now orthogonal (or unitary) of size (..., n, n)
if n != k:
# pyrefly: ignore # unbound-name
Q = Q[..., :k]
# Q is now the size of the X (albeit perhaps transposed)
else:
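
Note: unbound-name is reported because Q is only assigned inside earlier branches, so the checker cannot prove it is bound at this use even though the surrounding logic guarantees it. A minimal sketch of the same situation, with a hypothetical _orthogonalize helper that is not part of this PR:

    import torch

    def _orthogonalize(x: torch.Tensor, use_qr: bool, k: int) -> torch.Tensor:
        if use_qr:
            Q = torch.linalg.qr(x).Q
        elif x.shape[-1] == x.shape[-2]:
            Q = torch.matrix_exp(x - x.mT)  # exp of a skew-symmetric matrix
        # Q is bound on every path callers actually take, but the checker
        # cannot see that, so the use below is flagged.
        # pyrefly: ignore  # unbound-name
        return Q[..., :k]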

View File

@ -179,23 +179,28 @@ class ParametrizationList(ModuleList):
# Register the tensor(s)
if self.is_tensor:
# pyrefly: ignore # missing-attribute
if original.dtype != new.dtype:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
f"original.dtype: {original.dtype}\n"
# pyrefly: ignore # missing-attribute
f"right_inverse(original).dtype: {new.dtype}"
)
# pyrefly: ignore # missing-attribute
if original.device != new.device:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the device.\n"
f"original.device: {original.device}\n"
# pyrefly: ignore # missing-attribute
f"right_inverse(original).device: {new.device}"
)
# Set the original to original so that the user does not need to re-register the parameter
# manually in the optimiser
with torch.no_grad():
# pyrefly: ignore # bad-argument-type
_maybe_set(original, new)
_register_parameter_or_buffer(self, "original", original)
else:
@ -396,6 +401,7 @@ def _inject_property(module: Module, tensor_name: str) -> None:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
parametrization = self.parametrizations[tensor_name]
# pyrefly: ignore # redundant-condition
if _cache_enabled:
if torch.jit.is_scripting():
# Scripting
@ -695,6 +701,7 @@ def remove_parametrizations(
# Fetch the original tensor
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
parametrizations = module.parametrizations[tensor_name]
# pyrefly: ignore # invalid-argument
if parametrizations.is_tensor:
original = parametrizations.original
assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
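
Note: the hunks above touch ParametrizationList (the right_inverse dtype and device checks), the cached property injection, and remove_parametrizations. A usage sketch of the public API these back, using standard torch.nn.utils.parametrize calls; the Symmetric module is a made-up example, not part of this PR:

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Symmetric(nn.Module):
        def forward(self, X):
            # Build a symmetric matrix from the upper triangle.
            return X.triu() + X.triu(1).transpose(-1, -2)

    layer = nn.Linear(4, 4)
    parametrize.register_parametrization(layer, "weight", Symmetric())
    assert parametrize.is_parametrized(layer, "weight")
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)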

View File

@ -274,8 +274,11 @@ class PruningContainer(BasePruningMethod):
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)
# pyrefly: ignore # bad-argument-type
elif len(args) == 1: # only 1 item in a tuple
# pyrefly: ignore # index-error
self._tensor_name = args[0]._tensor_name
# pyrefly: ignore # index-error
self.add_pruning_method(args[0])
else: # manual construction from list or other iterable (or no args)
for method in args:
@ -1097,6 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
# flatten importance scores to consider them all at once in global pruning
relevant_importance_scores = torch.nn.utils.parameters_to_vector(
# pyrefly: ignore # bad-argument-type
[
importance_scores.get((module, name), getattr(module, name))
for (module, name) in parameters
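
Note: as the comment in the hunk says, global_unstructured flattens the importance scores of all listed (module, name) pairs into a single vector so the pruning threshold is computed globally. A usage sketch with standard torch.nn.utils.prune calls; the toy model is made up, not part of this PR:

    import torch.nn as nn
    import torch.nn.utils.prune as prune

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    parameters_to_prune = [
        (model[0], "weight"),
        (model[2], "weight"),
    ]
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.2,  # prune 20% of the weights across both layers
    )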

View File

@ -332,6 +332,7 @@ def spectral_norm(
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
# pyrefly: ignore # bad-return
return module

View File

@ -41,10 +41,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
)
# pyrefly: ignore # import-error
import optree
# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec # direct import for type annotations

View File

@ -679,7 +679,7 @@ class FlopCounterMode:
import tabulate
# pyrefly: ignore # bad-assignment
tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOP", "% Total"]
values = []
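
Note: the bad-assignment suppression covers the assignment to tabulate.PRESERVE_WHITESPACE, which the checker presumably rejects given tabulate's type information. A usage sketch for FlopCounterMode itself, using its standard context-manager API; the toy model is made up, not part of this PR:

    import torch
    from torch.utils.flop_counter import FlopCounterMode

    model = torch.nn.Linear(64, 64)
    x = torch.randn(8, 64)
    with FlopCounterMode() as flop_counter:
        model(x)
    print(flop_counter.get_total_flops())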

View File

@ -9,7 +9,7 @@ from typing import Any, Optional
import torch
import numpy as np
# pyrefly: ignore # import-error
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import (

View File

@ -956,7 +956,7 @@ class SummaryWriter:
)
self._projector_config.embeddings.extend([embedding_info])
# pyrefly: ignore # import-error
from google.protobuf import text_format
config_pbtxt = text_format.MessageToString(self._projector_config)
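
Note: add_embedding serializes the projector config with protobuf's text_format, hence the import-error suppression when the checker cannot resolve google.protobuf. A usage sketch for the embedding API itself, using the standard SummaryWriter interface; the tensors and log directory are made up, not part of this PR:

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir="runs/example")
    features = torch.randn(100, 16)             # 100 points, 16-dim each
    labels = [str(i % 10) for i in range(100)]  # one metadata row per point
    writer.add_embedding(features, metadata=labels, tag="toy_embedding")
    writer.close()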