Pyrefly suppressions 7/n (#164913)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete the relevant entries from the project-excludes field in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions and clean up unused suppressions (suppression format sketched below)
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
 INFO 0 errors (6,884 ignored)
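
For reference, the suppressions added in step 3 all follow the same inline pattern seen in the hunks below: a "# pyrefly: ignore # <error-code>" comment placed on the line above the expression pyrefly flags, with the text after the second "#" naming the error class being suppressed. A minimal sketch of the placement (the helper function is hypothetical; the comment format and the missing-attribute code come from this diff):

    import torch


    def is_marked_static(module: torch.nn.Module) -> bool:
        # Placing the ignore comment on the preceding line suppresses the
        # pyrefly diagnostic for the attribute access below, which pyrefly
        # cannot see on nn.Module.
        # pyrefly: ignore # missing-attribute
        return module._dynamo_marked_static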

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
Authored by Maggie Moss on 2025-10-08 07:27:14 +00:00; committed by PyTorch MergeBot
parent 12d2ef557f, commit c855f8632e
89 changed files with 626 additions and 67 deletions

View File

@@ -22,14 +22,16 @@ project-excludes = [
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/**",
"torch/distributed/**",
"torch/nn/**",
"torch/_dynamo/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
"torch/jit/annotations.py",
"torch/utils/data/datapipes/_typing.py",
"torch/nn/functional.py",
"torch/_export/utils.py",
"torch/fx/experimental/unification/multipledispatch/__init__.py",
"torch/nn/modules/__init__.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

View File

@@ -59,7 +59,6 @@ class TestBundledInputs(TestCase):
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)
@@ -67,15 +66,12 @@ class TestBundledInputs(TestCase):
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
for idx, inp in enumerate(inflated):
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.
@@ -144,7 +140,6 @@ class TestBundledInputs(TestCase):
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")
def test_multiple_methods_with_inputs(self):
@@ -192,7 +187,6 @@ class TestBundledInputs(TestCase):
# Check running and size helpers
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
@@ -426,7 +420,6 @@ class TestBundledInputs(TestCase):
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))
loaded = save_and_load(sm)

View File

@@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
# pyrefly: ignore # missing-attribute
self.assertTrue(torch.all(x))
@dtypes(*complex_types())
@@ -57,7 +57,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
# pyrefly: ignore # missing-attribute
self.assertFalse(torch.any(x))
@onlyCPU

View File

@@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")

View File

@@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.float64.to_complex(), torch.complex128)
@@ -136,7 +136,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.complex128.to_real(), torch.double)

View File

@@ -53,6 +53,8 @@ from .eval_frame import (
OptimizedModule,
reset_code,
)
# pyrefly: ignore # deprecated
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .pgo import reset_code_state

View File

@@ -95,6 +95,7 @@ class ModIndex(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
# pyrefly: ignore # bad-override
def forward(x: Tensor, indices: list[Tensor]) -> Tensor:
return torch.ops.aten.index(x, indices)
@@ -242,6 +243,7 @@ def _trace_wrapped_functionalized(ctx: Any, *args: Any, **kwargs: Any) -> Any:
def autograd_function_backward_rewritten(original_backward: Any) -> Any:
def new_backward(ctx: Any, *grads: Any) -> Any:
# pyrefly: ignore # bad-assignment
grads = [g.contiguous() for g in grads]
return original_backward(ctx, *grads)

View File

@@ -89,6 +89,7 @@ class AOTCompiledFunction:
**import_sources,
self._artifacts.backend_id: self._artifacts.compiled_fn,
}
# pyrefly: ignore # read-only
self.fn = types.FunctionType(
self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
)

View File

@@ -206,6 +206,7 @@ def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any])
assert manager is not None
def fn(inputs: list[Any]) -> Any:
# pyrefly: ignore # missing-attribute
manager.set_to_running_backward()
return aot_model(inputs)

View File

@@ -77,16 +77,19 @@ def tvm(
opt_level = options.get("opt_level", 3)
if scheduler == "auto_scheduler":
# pyrefly: ignore # import-error
from tvm import auto_scheduler
log_file = tempfile.NamedTemporaryFile()
# pyrefly: ignore # bad-argument-type
if not os.path.exists(log_file):
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target
)
if len(tasks) != 0:
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
# pyrefly: ignore # bad-argument-type
if not os.path.exists(log_file):
assert trials > 0
tune_option = auto_scheduler.TuningOptions(
@@ -97,7 +100,9 @@ def tvm(
try:
tuner.tune(tune_option)
except Exception:
# pyrefly: ignore # bad-argument-type
if os.path.exists(log_file):
# pyrefly: ignore # bad-argument-type
os.unlink(log_file)
raise
@@ -107,6 +112,7 @@ def tvm(
):
lib = relay.build(mod, target=target, params=params)
elif scheduler == "meta_schedule":
# pyrefly: ignore # import-error
from tvm import meta_schedule as ms
with tempfile.TemporaryDirectory() as work_dir:

View File

@@ -37,6 +37,7 @@ if sys.version_info >= (3, 11):
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
else:
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
# pyrefly: ignore # unsupported-operation
if (3, 12) <= sys.version_info < (3, 14):
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
if sys.version_info >= (3, 13):

View File

@@ -903,6 +903,7 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None:
inst.arg = abs(
int(target.offset - inst.offset - instruction_size(inst))
)
# pyrefly: ignore # unsupported-operation
inst.arg //= 2
inst.argval = target.offset
inst.argrepr = f"to {target.offset}"
@@ -1354,6 +1355,7 @@ def update_offsets(instructions: Sequence[Instruction]) -> None:
offset = 0
for inst in instructions:
inst.offset = offset
# pyrefly: ignore # unsupported-operation
offset += instruction_size(inst)

View File

@@ -464,6 +464,7 @@ def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
try:
prof.enable()
start_ts = time.time()
# pyrefly: ignore # bad-argument-type
retval = prof.runcall(func, *args, **kwargs)
profile_latency = time.time() - start_ts
prof.disable()
@@ -957,6 +958,7 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
if isinstance(mod, torch.nn.Module):
mod = mod.forward
if hasattr(mod, "__self__"):
# pyrefly: ignore # missing-attribute
return mod.__func__, mod.__self__
elif inspect.isfunction(mod):
return mod, None
@@ -1096,6 +1098,7 @@ def _fullgraph_capture_frame(
while cur_exn.__cause__ is not None:
cur_exn.__cause__.with_traceback(None)
cur_exn = cur_exn.__cause__
# pyrefly: ignore # invalid-inheritance
raise e.with_traceback(None) from e.__cause__ # User compiler error
return CaptureOutput(
@@ -1119,6 +1122,7 @@ def compile_frame( # type: ignore[return]
frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
distributed_state: Optional[DistributedState] = None,
package: Optional[CompilePackage] = None,
# pyrefly: ignore # bad-return
) -> DynamoOutput:
"""
A helper function taking a frame and backend, then return the generated bytecode

View File

@@ -20,6 +20,7 @@ allowed to compute gradients on).
class TracableCreateParameter(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
assert not tensor.requires_grad
return placeholder.set_(tensor)

View File

@@ -879,6 +879,7 @@ def aot_graph_input_parser(
data_type, shape_str = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
# pyrefly: ignore # bad-argument-type
kwargs[param] = gen_tensor(shape, dtype)
match = re.search(sym_shape_regex, annotation)
@@ -892,6 +893,7 @@ def aot_graph_input_parser(
attr_name, data_type, shape_str, _ = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
# pyrefly: ignore # bad-argument-type
setattr(container, attr_name, gen_tensor(shape, dtype))
return kwargs

View File

@@ -95,6 +95,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return nonrecursive_disable_wrapper
if fn is None:
@@ -306,6 +307,7 @@ def forbid_in_graph(fn: Any) -> Any:
if isinstance(fn, (list, tuple)):
return [forbid_in_graph(x) for x in fn]
assert callable(fn), "forbid_in_graph applies only to callables"
# pyrefly: ignore # missing-attribute
fn._dynamo_forbidden = True
return fn
@@ -653,21 +655,28 @@ def mark_dynamic(
if isinstance(index, int):
if not hasattr(t, "_dynamo_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_indices = set()
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_range = set()
# pyrefly: ignore # missing-attribute
t._dynamo_hint_overrides = {}
if not hasattr(t, "_specialize_on"):
# pyrefly: ignore # missing-attribute
t._specialize_on = {}
if hint_override:
# pyrefly: ignore # missing-attribute
t._dynamo_hint_overrides[index] = hint_override
# TODO(voz): Should we bounds check?
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_indices.add(index)
t._dynamo_dynamic_range.add(_DimRange(index, min, max)) # type: ignore[arg-type]
# FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
# TypeError: 'Attribute' object does not support item assignment
# pyrefly: ignore # missing-attribute
if isinstance(t._specialize_on, dict):
t._specialize_on[index] = specialize_on if specialize_on is not None else []
@@ -692,8 +701,10 @@ def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None
if isinstance(index, int):
if not hasattr(t, "_dynamo_weak_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices = set()
# TODO(voz): Should we bounds check?
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices.add(index)
return
@@ -745,8 +756,11 @@ def mark_static(
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
# pyrefly: ignore # bad-argument-type
if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
# pyrefly: ignore # missing-attribute
t._dynamo_marked_static = True
# pyrefly: ignore # bad-return
return t
if not isinstance(t, torch.Tensor):

View File

@@ -205,6 +205,7 @@ class CudaInterface(DeviceInterface):
Event = torch.cuda.Event # type: ignore[assignment]
Stream = torch.cuda.Stream # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:
@@ -240,6 +241,7 @@ class CudaInterface(DeviceInterface):
set_device = staticmethod(torch.cuda.set_device)
device_count = staticmethod(torch.cuda.device_count)
stream = staticmethod(torch.cuda.stream) # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.cuda.current_stream)
set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment]
@@ -300,6 +302,7 @@ class MtiaInterface(DeviceInterface):
Event = torch.mtia.Event # type: ignore[assignment]
Stream = torch.mtia.Stream # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:
@@ -335,6 +338,7 @@ class MtiaInterface(DeviceInterface):
set_device = staticmethod(torch.mtia.set_device) # type: ignore[assignment]
device_count = staticmethod(torch.mtia.device_count)
stream = staticmethod(torch.mtia.stream) # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.mtia.current_stream)
set_stream = staticmethod(torch.mtia.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id) # type: ignore[assignment]
@@ -381,6 +385,7 @@ class XpuInterface(DeviceInterface):
Event = torch.xpu.Event # type: ignore[assignment]
Stream = torch.xpu.Stream # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:
@@ -416,6 +421,7 @@ class XpuInterface(DeviceInterface):
set_device = staticmethod(torch.xpu.set_device)
device_count = staticmethod(torch.xpu.device_count) # type: ignore[has-type]
stream = staticmethod(torch.xpu.stream) # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.xpu.current_stream)
set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment]
@@ -458,6 +464,7 @@ class CpuDeviceProperties:
class CpuInterface(DeviceInterface):
# pyrefly: ignore # bad-override
class Event(torch.Event):
def __init__(self, enable_timing: bool = True) -> None:
self.time = 0.0
@@ -468,6 +475,7 @@ class CpuInterface(DeviceInterface):
def record(self, stream: Any = None) -> None:
self.time = time.perf_counter()
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def get_device_properties(
@@ -543,6 +551,7 @@ class MpsInterface(DeviceInterface):
def synchronize(device: torch.types.Device = None) -> None:
torch.mps.synchronize()
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def get_device_properties(device: torch.types.Device = None) -> Any:

View File

@@ -484,6 +484,7 @@ class OptimizedModule(torch.nn.Module):
self._initialize()
@property
# pyrefly: ignore # bad-override
def training(self) -> bool:
return self._orig_mod.training
@@ -892,6 +893,7 @@ class _TorchDynamoContext:
while cur_exn.__cause__ is not None:
cur_exn.__cause__.with_traceback(None)
cur_exn = cur_exn.__cause__
# pyrefly: ignore # invalid-inheritance
raise e.with_traceback(None) from e.__cause__ # User compiler error
except ShortenTraceback as e:
# Failures in the backend likely don't have useful
@@ -1020,7 +1022,10 @@ class OptimizeContext(_TorchDynamoContext):
assert rebuild_ctx is not None
compiler_fn = rebuild_ctx()
ctx = torch._dynamo.compiled_autograd._enable(
compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
compiler_fn,
# pyrefly: ignore # bad-argument-type
dynamic=_dynamic,
ignore_active_disable_ctx=False,
)
ctx.__enter__()
return functools.partial(ctx.__exit__, None, None, None)
@@ -1083,6 +1088,7 @@ class DisableContext(_TorchDynamoContext):
cls_obj.__call__ = self(cls_obj.__call__)
if issubclass(cls_obj, torch.nn.Module):
# NN module variable tracker directly inlines the _call_impl. Disable it.
# pyrefly: ignore # missing-attribute
cls_obj._call_impl = self(cls_obj._call_impl)
return cls_obj
@@ -1988,6 +1994,7 @@ def export(
path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
) -> Any:
if isinstance(t, torch.Tensor):
# pyrefly: ignore # missing-attribute
return ambient_fake_mode.from_tensor(t, static_shapes=True)
elif isinstance(t, _IntWrapper):
if (
@@ -2068,8 +2075,11 @@ def export(
)
and not trace_rules.check(call_to_inspect)
):
# pyrefly: ignore # unbound-name
dim_constraints.solve()
# pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
# pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
original_signature,
dynamic_shapes,
@@ -2090,9 +2100,11 @@ def export(
)
# Error if we have any constraints on static values
# pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
# pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

View File

@@ -369,6 +369,7 @@ def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedExc
observed_exception_map[exc_type] = type( # type: ignore[assignment]
f"Observed{name}Error", (ObservedException,), {}
)
# pyrefly: ignore # index-error
return observed_exception_map[exc_type]

View File

@@ -96,7 +96,9 @@ def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
args, kwargs = pytree.tree_map_only(
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
)
# pyrefly: ignore # invalid-param-spec
out = f(*args, **kwargs)
# pyrefly: ignore # missing-attribute
return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
return wrap

View File

@@ -250,6 +250,7 @@ class DynamoGraphTransformer(torch.fx.Transformer):
else:
placeholder.node.meta["val"] = self.flat_inputs[i]
# pyrefly: ignore # unsupported-operation
self.new_input_nodes[i] = placeholder
def _create_placeholder_mapping(self) -> None:
@@ -324,12 +325,18 @@ class DynamoGraphTransformer(torch.fx.Transformer):
# Copy module metadata like the original implementation
if hasattr(self.module, "meta"):
# pyrefly: ignore # unsupported-operation
if "dynamo_flat_name_to_original_fqn" in self.module.meta:
# pyrefly: ignore # index-error
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
# pyrefly: ignore # index-error
"dynamo_flat_name_to_original_fqn"
]
# pyrefly: ignore # unsupported-operation
if "dynamo_compile_id" in self.module.meta:
# pyrefly: ignore # index-error
result_gm.meta["dynamo_compile_id"] = self.module.meta[
# pyrefly: ignore # index-error
"dynamo_compile_id"
]
@@ -361,8 +368,11 @@ def _suggest_or_raise_constraint_violation(
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
)
):
# pyrefly: ignore # unbound-name
dim_constraints.solve()
# pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
# pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
inspect.signature(orig_callable), # type: ignore[attr-defined]
dynamic_shapes,
@@ -383,9 +393,11 @@ def _suggest_or_raise_constraint_violation(
)
# Error if we have any constraints on static values
# pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
# pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

View File

@@ -320,6 +320,7 @@ class GraphRegionTracker:
if len(group) > 1:
region_group = []
min_rank = math.inf
# pyrefly: ignore # bad-assignment
for node in group:
# some nodes aren't in the topo ranking?
if node in topological_ranking:

View File

@@ -640,6 +640,7 @@ class GuardManagerWrapper:
if isinstance(guard, RelationalGuard):
if guard not in self.printed_relational_guards:
self.printed_relational_guards.add(guard)
# pyrefly: ignore # bad-argument-type
body.writelines(self.get_guard_lines(guard))
else:
body.writelines(
@@ -700,6 +701,7 @@ class GuardManagerWrapper:
for guard in mgr.get_leaf_guards():
if isinstance(guard, RelationalGuard):
if guard not in relational_guards_seen:
# pyrefly: ignore # bad-argument-type
self.code_parts.extend(get_code_parts(guard))
relational_guards_seen.add(guard)
else:
@@ -716,6 +718,7 @@ def from_numpy(a: Any) -> torch.Tensor:
# Re-enable torch function since we disable it on leaf guards
# we need it to properly construct the tensor if a default device is set
with torch.overrides._enable_torch_function():
# pyrefly: ignore # missing-attribute
return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a
@@ -729,6 +732,7 @@ def uninteresting_files() -> set[str]:
from torch._dynamo.polyfills.loader import POLYFILLED_MODULES
# pyrefly: ignore # bad-argument-type
mods.extend(POLYFILLED_MODULES)
return {inspect.getfile(m) for m in mods}
@@ -2205,6 +2209,7 @@ class GuardBuilder(GuardBuilderBase):
return
# Python math library doesn't support complex nan, so we need to use numpy
# pyrefly: ignore # missing-attribute
if istype(val, complex) and np.isnan(val):
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
self._set_guard_export_info(guard, code)
@@ -2495,6 +2500,7 @@ class GuardBuilder(GuardBuilderBase):
# sources for the corresponding tensor dimension.
return [
TensorPropertySource(source, TensorProperty.SIZE, dim)
# pyrefly: ignore # missing-attribute
for source in output_graph.tracked_fakes_id_to_source[t_id]
]
@@ -2531,6 +2537,7 @@ class GuardBuilder(GuardBuilderBase):
equalities_inputs = None
def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
# pyrefly: ignore # missing-attribute
return output_graph.shape_env.produce_guards_verbose(
[a.fake for a in fs], # type: ignore[misc]
[a.source for a in fs],
@@ -2538,6 +2545,7 @@ class GuardBuilder(GuardBuilderBase):
equalities_inputs=equalities_inputs,
source_ref=self.source_ref,
# Export keeps static.
# pyrefly: ignore # missing-attribute
ignore_static=(not output_graph.export),
langs=langs,
)
@@ -2599,7 +2607,9 @@ class GuardBuilder(GuardBuilderBase):
if not python_fallback:
assert cpp_code_parts # type: ignore[possibly-undefined]
code_parts, source_to_symbol = (
# pyrefly: ignore # unbound-name
cpp_code_parts.exprs,
# pyrefly: ignore # unbound-name, missing-attribute
cpp_code_parts.source_to_symbol,
)
@@ -2630,7 +2640,9 @@ class GuardBuilder(GuardBuilderBase):
assert cpp_code_parts # type: ignore[possibly-undefined]
code_parts, source_to_symbol = (
# pyrefly: ignore # unbound-name
cpp_code_parts.exprs,
# pyrefly: ignore # unbound-name, missing-attribute
cpp_code_parts.source_to_symbol,
)
@@ -3240,6 +3252,7 @@ class GuardsStatePickler(pickle.Pickler):
assert _.__closure__ is not None
return _.__closure__[0]
# pyrefly: ignore # bad-override
def reducer_override(
self, obj: Any
) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
@@ -4072,10 +4085,13 @@ class CheckFunctionManager:
and (cache_entry := self.guard_manager.cache_entry) is not None
and (extra_state := self.guard_manager.extra_state) is not None
):
# pyrefly: ignore # unbound-name
assert isinstance(cache_entry, CacheEntry)
# pyrefly: ignore # unbound-name
assert isinstance(extra_state, ExtraState)
reason = f"Cache line invalidated because {obj_str} got deallocated"
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
# pyrefly: ignore # unbound-name
extra_state.invalidate(cache_entry, deleted_guard_manager)
self.guard_manager = deleted_guard_manager

View File

@@ -707,6 +707,7 @@ class OutputGraph(OutputGraphCommon):
self.backward_state_proxy: Optional[torch.fx.Proxy] = None
self.backward_state_var: Optional[str] = None
# pyrefly: ignore # bad-override
self.name_of_builtins_dict_key_in_fglobals: str = (
self.install_builtins_dict_in_fglobals()
)
@@ -1146,6 +1147,7 @@ class OutputGraph(OutputGraphCommon):
vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
assert "tensor_dict" not in vt.as_proxy().node.meta
# pyrefly: ignore # bad-argument-type
vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)
return vt
@@ -1157,6 +1159,7 @@ class OutputGraph(OutputGraphCommon):
install_guard(source.make_guard(GuardBuilder.NN_MODULE))
def wrap_name(module_key: str) -> VariableTracker:
# pyrefly: ignore # bad-argument-type
return NNModuleVariable(type(target), module_key, target, **options)
else:
@@ -1970,7 +1973,9 @@ class OutputGraph(OutputGraphCommon):
tx = self.root_tx
assert tx is not None
if (ds := tx.distributed_state) is not None and ds.all_states is None:
# pyrefly: ignore # unbound-name
compile_pg = ds.compile_pg
# pyrefly: ignore # unbound-name
log.info("compiler_collective %s", ds.local_state)
torch._logging.trace_structured(
"artifact",
@@ -1978,6 +1983,7 @@ class OutputGraph(OutputGraphCommon):
"name": "compiler_collective",
"encoding": "string",
},
# pyrefly: ignore # unbound-name
payload_fn=lambda: ds.local_state.render(),
)
device_types = compile_pg._device_types
@@ -1991,7 +1997,9 @@ class OutputGraph(OutputGraphCommon):
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
):
all_states: list[Any] = [None] * compile_pg.size()
# pyrefly: ignore # unbound-name
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
# pyrefly: ignore # unbound-name
ds.all_states = all_states
# Clear speculation log, because are tracing may diverge due to
# this information from the compiler collective
@@ -2321,6 +2329,7 @@ class OutputGraph(OutputGraphCommon):
},
)
# pyrefly: ignore # unbound-name
return compiled_fn
def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
@@ -2375,6 +2384,7 @@ class OutputGraph(OutputGraphCommon):
isinstance(b, torch.SymBool)
and (r := b.node.maybe_as_bool()) is not None
):
# pyrefly: ignore # unbound-name
return r
# TODO: We can also technically remove all cases when the input
# doesn't have unbacked inputs, since it's all in the ShapeEnv
@@ -2740,6 +2750,7 @@ def check_pt2_compliant_op(
hints=[],
)
# pyrefly: ignore # unbound-name
op = getattr(target, overload)
if torch.Tag.pt2_compliant_tag in op.tags:
encountered_compliant_op(op)
@@ -2747,6 +2758,7 @@ def check_pt2_compliant_op(
encountered_non_compliant_op(
op,
f"Encountered the torch.ops.OpOverloadPacket {target} "
# pyrefly: ignore # unbound-name
f"which resolves to the overload ({overload}) that is "
f"not PT2 compliant.",
)
@@ -2767,6 +2779,7 @@ class LazyProxy:
**kwargs: P.kwargs,
) -> None:
self.tracer = tracer
# pyrefly: ignore # invalid-type-var
self.fn = fn
self.args = args
self.kwargs = kwargs

View File

@@ -319,6 +319,7 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
code_source = _find_code_source(toplevel)
if code_source is None:
_raise_resolution_error(code, toplevel)
# pyrefly: ignore # missing-attribute
return toplevel.__qualname__, code_source.strip(".")
@@ -593,9 +594,11 @@ class CompilePackage:
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
)
# pyrefly: ignore # bad-assignment
self._source_info = dynamo.source_info
main, *codes = dynamo.codes
# pyrefly: ignore # bad-assignment
self._codes = {self._innermost_fn.__code__: main}
for code in codes:
self._codes[SerializedCode.to_code_object(code.python_code)] = code
@@ -603,6 +606,7 @@ class CompilePackage:
self._add_function(
self._innermost_fn.__code__, self._innermost_fn.__module__
)
# pyrefly: ignore # bad-assignment
self._initialized = True
def _add_function(
@@ -746,6 +750,7 @@ class CompilePackage:
for name in names:
module.__dict__.pop(name)
# pyrefly: ignore # bad-assignment
self._installed_globals = {}
_reset_precompile_entries(self._innermost_fn.__code__)

View File

@@ -167,6 +167,7 @@ class CodeId:
@dataclasses.dataclass
class CodeState:
automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
# pyrefly: ignore # unbound-name
default_factory=lambda: defaultdict(FrameStateSizeEntry)
)
@@ -851,6 +852,7 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
not _CODE_STATE
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
):
# pyrefly: ignore # unbound-name
extra_read_key = get_extra_cache_key(sticky_read)
if extra_read_key is not None:
get_extra_remote_code_state(extra_read_key)

View File

@@ -196,6 +196,7 @@ def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
/,
@@ -205,6 +206,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
iter2: Iterable[_T2],
@@ -213,6 +215,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
iter2: Iterable[_T2],
@@ -223,6 +226,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T],
iter2: Iterable[_T],
@@ -233,6 +237,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T],
iter2: Iterable[_T],

View File

@@ -30,10 +30,12 @@ _Us = TypeVarTuple("_Us")
@overload
# pyrefly: ignore # inconsistent-overload
def attrgetter(attr: str, /) -> Callable[[Any], _U]: ...
@overload
# pyrefly: ignore # inconsistent-overload
def attrgetter(
attr1: str, attr2: str, /, *attrs: str
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...
@@ -68,10 +70,12 @@ def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:
@overload
# pyrefly: ignore # inconsistent-overload
def itemgetter(item: _T, /) -> Callable[[Any], _U]: ...
@overload
# pyrefly: ignore # inconsistent-overload
def itemgetter(
item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...

View File

@@ -17,6 +17,7 @@ __all__ = ["fspath"]
@substitute_in_graph(os.fspath, can_constant_fold_through=True)
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
if isinstance(path, (str, bytes)):
# pyrefly: ignore # bad-return
return path
path_type = type(path)

View File

@@ -171,6 +171,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
# pyrefly: ignore # bad-return
return treespec._unflatten_func(
treespec._metadata,
children_representations,

View File

@@ -49,8 +49,11 @@ class ProfileMetrics:
if isinstance(other, int):
other = ProfileMetrics(other, other, other)
return ProfileMetrics(
# pyrefly: ignore # no-matching-overload
self.microseconds / max(1, other.microseconds),
# pyrefly: ignore # bad-argument-type
self.operators / max(1, other.operators),
# pyrefly: ignore # bad-argument-type
self.fusions / max(1, other.fusions),
)

View File

@@ -370,13 +370,16 @@ isolate_fails_code_str = None
try:
if isinstance(kernel, Autotuner):
# pyrefly: ignore # missing-attribute
if isinstance(kernel.fn, Heuristics):
model_str += "ERROR: Repro will not work as intended, "
model_str += "triton.runtime.autotuner.Heuristics is not currently supported\n"
break
config_strs = []
# pyrefly: ignore # missing-attribute
for kernel_config in kernel.configs:
# pyrefly: ignore # bad-argument-type
config_strs.append(f"""triton.Config(
{str(kernel_config.kwargs)},
num_warps={kernel_config.num_warps},
@@ -394,8 +397,10 @@ isolate_fails_code_str = None
""").strip()
model_str += "\n@triton.jit\n"
# pyrefly: ignore # missing-attribute
src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
fn_name = (
# pyrefly: ignore # missing-attribute
kernel._fn_name
if isinstance(kernel, JITFunction)
else kernel.fn._fn_name
@@ -409,7 +414,9 @@ isolate_fails_code_str = None
model_str += "ERROR: Repro will not work as intended, "
model_str += f"User defined triton kernel exception: {e}\n"
# pyrefly: ignore # unbound-name
if len(kernel_side_table.constant_args) > 0:
# pyrefly: ignore # unbound-name
model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"
model_str += NNModuleToString.convert(gm)
@@ -420,8 +427,10 @@ isolate_fails_code_str = None
# Extract from graph placeholders and their corresponding arguments
placeholder_targets = fx_placeholder_targets(gm)
for placeholder, arg in zip(placeholder_targets, args):
# pyrefly: ignore # unbound-name
if isinstance(arg, (int, torch.SymInt)):
writer.symint(placeholder, arg)
# pyrefly: ignore # unbound-name
elif isinstance(arg, torch.Tensor):
# TODO: improve these names with FQN
writer.tensor(placeholder, arg)
@@ -431,16 +440,20 @@ isolate_fails_code_str = None
writer.unsupported(placeholder, arg)
# Extract symbolic variables from the same arguments
# pyrefly: ignore # unbound-name
if isinstance(arg, torch.SymInt):
sym_name = str(arg.node)
if arg.node.hint is not None:
used_syms[sym_name] = arg.node.hint
# pyrefly: ignore # unbound-name
elif isinstance(arg, torch.Tensor):
# Extract symbolic variables from tensor shapes and strides
for dim in arg.shape:
# pyrefly: ignore # unbound-name
if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
used_syms[str(dim.node)] = dim.node.hint
for stride in arg.stride():
# pyrefly: ignore # unbound-name
if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
used_syms[str(stride.node)] = stride.node.hint
@@ -758,6 +771,7 @@ def repro_common(
# TODO: speed this up
mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args)
# pyrefly: ignore # bad-assignment
torch._inductor.config.generate_intermediate_hooks = True
return mod, args

View File

@ -301,6 +301,7 @@ def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]:
def repro_common( def repro_common(
options: Any, exported_program: ExportedProgram options: Any, exported_program: ExportedProgram
) -> tuple[torch.fx.GraphModule, Any, Any]: ) -> tuple[torch.fx.GraphModule, Any, Any]:
# pyrefly: ignore # bad-assignment
torch._inductor.config.generate_intermediate_hooks = True torch._inductor.config.generate_intermediate_hooks = True
mod = exported_program.module(check_guards=False) mod = exported_program.module(check_guards=False)
args, kwargs = exported_program.example_inputs args, kwargs = exported_program.example_inputs
@ -422,6 +423,7 @@ def repro_minify(
) -> bool: ) -> bool:
# Need to export first so the in_spec and out_spec are populated # Need to export first so the in_spec and out_spec are populated
tuple_inputs = tuple(flat_example_inputs) tuple_inputs = tuple(flat_example_inputs)
# pyrefly: ignore # bad-assignment
gm = export_for_aoti_minifier( gm = export_for_aoti_minifier(
gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error
) )
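
bad-assignment fires when the static type of the assigned value does not match the declared type of the target; with dynamically installed config objects the type the checker sees may not be the plain flag the runtime uses. A rough, hypothetical sketch of the category (not the actual inductor config machinery):

    class _Config:
        # The stub declares a narrow type for this knob.
        generate_hooks: bool = False

    cfg = _Config()
    # Tolerated at runtime (truthy int), but not `bool` statically.
    # pyrefly: ignore # bad-assignment
    cfg.generate_hooks = 1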

View File

@ -102,6 +102,7 @@ def _bytecode_from_template_with_split(
def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None: def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
# NOTE: Make sure this name matches what is generated by symbolic_convert:import_source # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
# on torch._dynamo.utils. # on torch._dynamo.utils.
# pyrefly: ignore # unknown-name
global __import_torch_dot__dynamo_dot_utils global __import_torch_dot__dynamo_dot_utils
try: try:
dummy dummy
@ -555,6 +556,7 @@ class ContinueExecutionCache:
# remap original instructions' exception table entries # remap original instructions' exception table entries
if old_hook_target_remap: if old_hook_target_remap:
# pyrefly: ignore # unbound-name
assert is_py311_plus assert is_py311_plus
for inst in instructions: for inst in instructions:
if ( if (
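
unknown-name covers references to names that only exist after runtime code generation or injection, such as the module alias created by import_source; pyrefly finds no static definition for them. A hypothetical sketch of the shape of that code:

    def reset_injected_counter() -> None:
        # The module-level name is created at runtime by generated code, so
        # pyrefly has no static definition for it when checking this file.
        # pyrefly: ignore # unknown-name
        global _injected_counter
        _injected_counter = 0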

View File

@ -696,6 +696,7 @@ class SideEffects:
cg.add_cache(var) cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
elif var.source is None: elif var.source is None:
# pyrefly: ignore # bad-assignment
var.source = LocalCellSource(var.local_name) var.source = LocalCellSource(var.local_name)
elif isinstance(var, variables.TensorVariable): elif isinstance(var, variables.TensorVariable):
# NOTE: for historical reasons we never assigned local sources # NOTE: for historical reasons we never assigned local sources
@ -732,6 +733,7 @@ class SideEffects:
if isinstance(var, variables.UserDefinedObjectVariable): if isinstance(var, variables.UserDefinedObjectVariable):
def load_new_method() -> None: def load_new_method() -> None:
# pyrefly: ignore # missing-attribute
assert var.base_cls_vt is not None assert var.base_cls_vt is not None
cg(var.base_cls_vt) # type: ignore[attr-defined] cg(var.base_cls_vt) # type: ignore[attr-defined]
cg.extend_output([cg.create_load_attr("__new__")]) cg.extend_output([cg.create_load_attr("__new__")])
@ -978,7 +980,9 @@ class SideEffects:
elif self.is_attribute_mutation(var): elif self.is_attribute_mutation(var):
if isinstance( if isinstance(
var, variables.UserDefinedDictVariable var,
variables.UserDefinedDictVariable,
# pyrefly: ignore # bad-argument-type
) and self.is_modified(var._dict_vt): ) and self.is_modified(var._dict_vt):
# Do dict related update manually here. The store_attr # Do dict related update manually here. The store_attr
# mutations will be applied later. # mutations will be applied later.
@ -1011,6 +1015,7 @@ class SideEffects:
] ]
) )
# pyrefly: ignore # bad-argument-type
cg(var._dict_vt, allow_cache=False) # Don't codegen via source cg(var._dict_vt, allow_cache=False) # Don't codegen via source
cg.extend_output( cg.extend_output(
[ [
@ -1031,7 +1036,9 @@ class SideEffects:
] ]
) )
elif isinstance( elif isinstance(
var, variables.UserDefinedListVariable var,
variables.UserDefinedListVariable,
# pyrefly: ignore # bad-argument-type
) and self.is_modified(var._list_vt): ) and self.is_modified(var._list_vt):
# Update the list to the updated items. Be careful in # Update the list to the updated items. Be careful in
# calling the list methods and not the overridden methods. # calling the list methods and not the overridden methods.
@ -1048,6 +1055,7 @@ class SideEffects:
] ]
) )
# pyrefly: ignore # bad-argument-type
cg(var._list_vt, allow_cache=False) # Don't codegen via source cg(var._list_vt, allow_cache=False) # Don't codegen via source
cg.extend_output( cg.extend_output(
[ [
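
bad-argument-type is used above where a value is statically typed as a broad base (or plain object) but the call site is only reached when it is the narrower type the callee expects. Hypothetical sketch, not the VariableTracker classes themselves:

    def scale(factor: int, x: int) -> int:
        return factor * x

    def apply_setting(settings: dict[str, object], x: int) -> int:
        factor = settings["factor"]  # statically `object`
        # Callers only store ints under this key, but pyrefly cannot see that.
        # pyrefly: ignore # bad-argument-type
        return scale(factor, x)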

View File

@ -563,6 +563,7 @@ def log_graph_break(
) )
else: else:
user_stack = get_stack_above_dynamo() + user_stack # type: ignore[assignment] user_stack = get_stack_above_dynamo() + user_stack # type: ignore[assignment]
# pyrefly: ignore # bad-argument-type
user_stack = collapse_resume_frames(user_stack) user_stack = collapse_resume_frames(user_stack)
user_stack_formatted = "".join(traceback.format_list(user_stack)) user_stack_formatted = "".join(traceback.format_list(user_stack))
user_stack_trace = ( user_stack_trace = (
@ -1040,6 +1041,7 @@ class BytecodeDispatchTableMeta(type):
op: getattr(cls, opname, functools.partial(_missing, opname)) op: getattr(cls, opname, functools.partial(_missing, opname))
for opname, op in dis.opmap.items() for opname, op in dis.opmap.items()
} }
# pyrefly: ignore # missing-attribute
cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)] cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]
@ -1788,13 +1790,17 @@ class InstructionTranslatorBase(
source = self.import_source(module_name) source = self.import_source(module_name)
if self.exec_recorder: if self.exec_recorder:
# pyrefly: ignore # unbound-name
self.exec_recorder.add_local_mod(recorded_name, value) self.exec_recorder.add_local_mod(recorded_name, value)
# pyrefly: ignore # unbound-name
if istype(value, (types.ModuleType, DummyModule)): if istype(value, (types.ModuleType, DummyModule)):
# pyrefly: ignore # unbound-name
self.push(PythonModuleVariable(value, source=source)) self.push(PythonModuleVariable(value, source=source))
else: else:
unimplemented_v2( unimplemented_v2(
gb_type="Bad import result", gb_type="Bad import result",
# pyrefly: ignore # unbound-name
context=typestr(value), context=typestr(value),
explanation="Import result is not a Python module.", explanation="Import result is not a Python module.",
hints=[], hints=[],
@ -1873,6 +1879,7 @@ class InstructionTranslatorBase(
exit, exc = self.popn(2) exit, exc = self.popn(2)
assert exc is None assert exc is None
self.push(exc) self.push(exc)
# pyrefly: ignore # bad-argument-type
self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {})) self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))
def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None: def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
@ -2294,7 +2301,9 @@ class InstructionTranslatorBase(
): ):
return True return True
elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass( elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
exc_instance.fn, expected_type.fn exc_instance.fn,
# pyrefly: ignore # missing-attribute
expected_type.fn,
): ):
return True return True
@ -2354,26 +2363,37 @@ class InstructionTranslatorBase(
assert isinstance(null, NullVariable) assert isinstance(null, NullVariable)
if not isinstance( if not isinstance(
argsvars, BaseListVariable # pyrefly: ignore # unbound-name
argsvars,
BaseListVariable,
# pyrefly: ignore # unbound-name
) and argsvars.has_force_unpack_var_sequence(self): ) and argsvars.has_force_unpack_var_sequence(self):
# pyrefly: ignore # unbound-name
argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
# Unpack for cases like fn(**obj) where obj is a map # Unpack for cases like fn(**obj) where obj is a map
# pyrefly: ignore # unbound-name
if isinstance(kwargsvars, UserDefinedObjectVariable): if isinstance(kwargsvars, UserDefinedObjectVariable):
kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type] kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type]
# pyrefly: ignore # unbound-name
if not isinstance(argsvars, BaseListVariable) or not isinstance( if not isinstance(argsvars, BaseListVariable) or not isinstance(
kwargsvars, ConstDictVariable # pyrefly: ignore # unbound-name
kwargsvars,
ConstDictVariable,
): ):
unimplemented_v2( unimplemented_v2(
gb_type="Variadic function call with bad args/kwargs type", gb_type="Variadic function call with bad args/kwargs type",
# pyrefly: ignore # unbound-name
context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}", context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
explanation="Expected args to be a list and kwargs to be a dict", explanation="Expected args to be a list and kwargs to be a dict",
hints=[*graph_break_hints.USER_ERROR], hints=[*graph_break_hints.USER_ERROR],
) )
# Map to a dictionary of str -> VariableTracker # Map to a dictionary of str -> VariableTracker
# pyrefly: ignore # unbound-name, missing-attribute
kwargsvars = kwargsvars.keys_as_python_constant() kwargsvars = kwargsvars.keys_as_python_constant()
# pyrefly: ignore # unbound-name, missing-attribute
self.call_function(fn, argsvars.items, kwargsvars) self.call_function(fn, argsvars.items, kwargsvars)
@break_graph_if_unsupported(push=1) @break_graph_if_unsupported(push=1)
@ -2437,6 +2457,7 @@ class InstructionTranslatorBase(
def LOAD_ATTR(self, inst: Instruction) -> None: def LOAD_ATTR(self, inst: Instruction) -> None:
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
# pyrefly: ignore # unsupported-operation
if inst.arg % 2: if inst.arg % 2:
self.LOAD_METHOD(inst) self.LOAD_METHOD(inst)
return return
@ -3029,14 +3050,17 @@ class InstructionTranslatorBase(
"(i.e. `a, b, c = d`).", "(i.e. `a, b, c = d`).",
hints=[*graph_break_hints.USER_ERROR], hints=[*graph_break_hints.USER_ERROR],
) )
# pyrefly: ignore # unbound-name
if len(val) != inst.argval: if len(val) != inst.argval:
unimplemented_v2( unimplemented_v2(
gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE", gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
# pyrefly: ignore # unbound-name
context=f"expected length: {inst.argval}, actual: {len(val)}", context=f"expected length: {inst.argval}, actual: {len(val)}",
explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode " explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
"(i.e. `a, b, c = d`) with unexpected length.", "(i.e. `a, b, c = d`) with unexpected length.",
hints=[*graph_break_hints.DYNAMO_BUG], hints=[*graph_break_hints.DYNAMO_BUG],
) )
# pyrefly: ignore # unbound-name
for i in reversed(val): for i in reversed(val):
self.push(i) self.push(i)
@ -3409,9 +3433,13 @@ class InstructionTranslatorBase(
args = [contents[1]] args = [contents[1]]
if kw_names: if kw_names:
# pyrefly: ignore # bad-argument-type
args = args + contents[2 : -len(kw_names)] args = args + contents[2 : -len(kw_names)]
# pyrefly: ignore # bad-argument-type
kwargs_list = contents[-len(kw_names) :] kwargs_list = contents[-len(kw_names) :]
# pyrefly: ignore # no-matching-overload
kwargs = dict(zip(kw_names, kwargs_list)) kwargs = dict(zip(kw_names, kwargs_list))
# pyrefly: ignore # bad-argument-type
assert len(kwargs) == len(kw_names) assert len(kwargs) == len(kw_names)
else: else:
args = args + contents[2:] args = args + contents[2:]
@ -4118,6 +4146,7 @@ class InstructionTranslator(InstructionTranslatorBase):
and isinstance(tos, LocalGeneratorObjectVariable) and isinstance(tos, LocalGeneratorObjectVariable)
): ):
self.stack[-1] = ListIteratorVariable( self.stack[-1] = ListIteratorVariable(
# pyrefly: ignore # unbound-name
tos.force_unpack_var_sequence(self), tos.force_unpack_var_sequence(self),
mutation_type=ValueMutationNew(), mutation_type=ValueMutationNew(),
) )
@ -4188,6 +4217,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
"""Trace and inline a called method""" """Trace and inline a called method"""
symbolic_result: Optional[VariableTracker] symbolic_result: Optional[VariableTracker]
# pyrefly: ignore # bad-override
parent: InstructionTranslatorBase parent: InstructionTranslatorBase
@classmethod @classmethod
@ -4231,6 +4261,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
# trace through. # trace through.
if ( if (
hasattr(getattr(func, "fn", None), "_origin") hasattr(getattr(func, "fn", None), "_origin")
# pyrefly: ignore # missing-attribute
and func.fn._origin is produce_trampoline_autograd_apply and func.fn._origin is produce_trampoline_autograd_apply
): ):
# Known sound # Known sound
@ -4305,12 +4336,14 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
tracing_ctx.previously_inlined_functions[code] = result tracing_ctx.previously_inlined_functions[code] = result
try: try:
# pyrefly: ignore # missing-attribute
sub_locals = func.bind_args(parent, args, kwargs) sub_locals = func.bind_args(parent, args, kwargs)
except TypeError as e: except TypeError as e:
# Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
raise ArgsMismatchError( # noqa: B904 raise ArgsMismatchError( # noqa: B904
"{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format( "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
reason=str(e), reason=str(e),
# pyrefly: ignore # missing-attribute
func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}", func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
args=[arg.python_type() for arg in args], args=[arg.python_type() for arg in args],
kwargs=kwargs, kwargs=kwargs,
@ -4394,6 +4427,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
sub_locals, sub_locals,
parent.symbolic_globals, parent.symbolic_globals,
parent.symbolic_torch_function_state, parent.symbolic_torch_function_state,
# pyrefly: ignore # bad-argument-type
func, func,
) )
return tracer return tracer
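
unsupported-operation flags operators applied to a value whose declared type does not support them, for example `inst.arg % 2` where `arg` is Optional[int] but is always populated for the opcode being handled. A short sketch with illustrative names:

    from typing import Optional

    def is_method_load(arg: Optional[int]) -> bool:
        # For this opcode `arg` is always set, but its declared type is
        # Optional[int], so the modulo is flagged.
        # pyrefly: ignore # unsupported-operation
        return bool(arg % 2)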

View File

@ -153,7 +153,9 @@ class CPythonTestCase(TestCase):
assertTupleEqual = unittest.TestCase.assertTupleEqual assertTupleEqual = unittest.TestCase.assertTupleEqual
assertSetEqual = unittest.TestCase.assertSetEqual assertSetEqual = unittest.TestCase.assertSetEqual
assertDictEqual = polyfills.assert_dict_equal assertDictEqual = polyfills.assert_dict_equal
# pyrefly: ignore # bad-override
assertRaises = unittest.TestCase.assertRaises assertRaises = unittest.TestCase.assertRaises
# pyrefly: ignore # bad-override
assertRaisesRegex = unittest.TestCase.assertRaisesRegex assertRaisesRegex = unittest.TestCase.assertRaisesRegex
assertWarns = unittest.TestCase.assertWarns assertWarns = unittest.TestCase.assertWarns
assertWarnsRegex = unittest.TestCase.assertWarnsRegex assertWarnsRegex = unittest.TestCase.assertWarnsRegex
@ -169,8 +171,10 @@ class CPythonTestCase(TestCase):
) -> Callable[..., Any]: ) -> Callable[..., Any]:
# We want to compile only the test function, excluding any setup code # We want to compile only the test function, excluding any setup code
# from unittest # from unittest
method = getattr(self, self._testMethodName) method = getattr(self, self._testMethodName)
method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method) method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
setattr(self, self._testMethodName, method) setattr(self, self._testMethodName, method)
return fn return fn

View File

@ -207,6 +207,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
with open(launch_file) as f: with open(launch_file) as f:
launch_code = f.read() launch_code = f.read()
self.assertTrue(os.path.exists(launch_file)) self.assertTrue(os.path.exists(launch_file))
args = ["python3", launch_file, "minify", *minifier_args] args = ["python3", launch_file, "minify", *minifier_args]
@ -218,6 +219,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
print("minifier stdout:", launch_proc.stdout.decode("utf-8")) print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
stderr = launch_proc.stderr.decode("utf-8") stderr = launch_proc.stderr.decode("utf-8")
print("minifier stderr:", stderr) print("minifier stderr:", stderr)
self.assertNotIn("Input graph did not fail the tester", stderr) self.assertNotIn("Input graph did not fail the tester", stderr)
return launch_proc, launch_code return launch_proc, launch_code
@ -230,6 +232,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
with open(repro_file) as f: with open(repro_file) as f:
repro_code = f.read() repro_code = f.read()
self.assertTrue(os.path.exists(repro_file)) self.assertTrue(os.path.exists(repro_file))
repro_proc = self._maybe_subprocess_run( repro_proc = self._maybe_subprocess_run(
@ -296,11 +299,14 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
if expected_error is None: if expected_error is None:
# Just check that there was no error # Just check that there was no error
self.assertEqual(test_proc.returncode, 0) self.assertEqual(test_proc.returncode, 0)
self.assertIsNone(repro_dir) self.assertIsNone(repro_dir)
return None return None
# NB: Intentionally do not test return code; we only care about # NB: Intentionally do not test return code; we only care about
# actually generating the repro, we don't have to crash # actually generating the repro, we don't have to crash
self.assertIn(expected_error, test_proc.stderr.decode("utf-8")) self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
self.assertIsNotNone(repro_dir) self.assertIsNotNone(repro_dir)
print("running minifier", file=sys.stderr) print("running minifier", file=sys.stderr)
_minifier_proc, minifier_code = self._run_minifier_launcher( _minifier_proc, minifier_code = self._run_minifier_launcher(
@ -311,6 +317,7 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
) )
print("running repro", file=sys.stderr) print("running repro", file=sys.stderr)
repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate) repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
self.assertIn(expected_error, repro_proc.stderr.decode("utf-8")) self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
self.assertNotEqual(repro_proc.returncode, 0) self.assertNotEqual(repro_proc.returncode, 0)
return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code) return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)

View File

@ -496,6 +496,7 @@ def make_test_cls_with_patches(
def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]: def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
return fn return fn
# pyrefly: ignore # bad-return, bad-argument-type
return unittest.skip(fn) return unittest.skip(fn)

View File

@ -3005,6 +3005,7 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
obj = torch_dir + k[len("torch/") :] obj = torch_dir + k[len("torch/") :]
if obj is not None: if obj is not None:
if is_annotate_wrapped_function(obj): if is_annotate_wrapped_function(obj):
# pyrefly: ignore # missing-attribute
obj = obj.__wrapped__ obj = obj.__wrapped__
if is_lru_cache_wrapped_function(obj): if is_lru_cache_wrapped_function(obj):
obj = obj.__wrapped__ obj = obj.__wrapped__
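
missing-attribute marks attribute access that is valid at runtime but absent from the static type, such as the `__wrapped__` slot that functools.wraps attaches to a wrapper. Hypothetical sketch:

    from typing import Callable

    def unwrap(fn: Callable[..., int]) -> Callable[..., int]:
        if hasattr(fn, "__wrapped__"):
            # `Callable` has no `__wrapped__` in its static type, but
            # functools.wraps sets it on the wrapper at runtime.
            # pyrefly: ignore # missing-attribute
            return fn.__wrapped__
        return fn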

View File

@ -295,11 +295,13 @@ def increment_op_count(cnt: int) -> None:
def calculate_time_spent() -> dict[str, float]: def calculate_time_spent() -> dict[str, float]:
total_by_key = {} total_by_key = {}
for phase, timing in cumulative_time_spent_ns.items(): for phase, timing in cumulative_time_spent_ns.items():
# pyrefly: ignore # unsupported-operation
total_by_key[phase] = timing / 1e9 total_by_key[phase] = timing / 1e9
total_by_key["total_wall_time"] = total_by_key.get( total_by_key["total_wall_time"] = total_by_key.get(
"entire_frame_compile", 0 "entire_frame_compile", 0
) + total_by_key.get("entire_backward_compile", 0) ) + total_by_key.get("entire_backward_compile", 0)
# pyrefly: ignore # bad-return
return total_by_key return total_by_key
@ -798,6 +800,7 @@ def compile_times(repr: Literal["str"], aggregate: bool = False) -> str: ...
@overload @overload
# pyrefly: ignore # inconsistent-overload
def compile_times( def compile_times(
repr: Literal["csv"], aggregate: bool = False repr: Literal["csv"], aggregate: bool = False
) -> tuple[list[str], list[object]]: ... ) -> tuple[list[str], list[object]]: ...
@ -1463,6 +1466,7 @@ class CompilationMetrics:
compile_id = all_metrics.get("compile_id") compile_id = all_metrics.get("compile_id")
all_metrics["compile_id"] = str(compile_id) if compile_id else None all_metrics["compile_id"] = str(compile_id) if compile_id else None
# pyrefly: ignore # bad-argument-type
return cls(**all_metrics) return cls(**all_metrics)
@ -2253,6 +2257,7 @@ def is_jit_model(
Union[ Union[
torch.jit._trace.TopLevelTracedModule, torch.jit._trace.TopLevelTracedModule,
torch.jit._script.RecursiveScriptModule, torch.jit._script.RecursiveScriptModule,
# pyrefly: ignore # invalid-param-spec
torch.jit.ScriptFunction[Any, Any], torch.jit.ScriptFunction[Any, Any],
torch.jit.ScriptModule, torch.jit.ScriptModule,
] ]
@ -2361,6 +2366,7 @@ def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]:
cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
saved_state = [ saved_state = [
(param, param._version, torch.clone(param)) (param, param._version, torch.clone(param))
# pyrefly: ignore # bad-argument-type
for param in itertools.chain(gm.parameters(), gm.buffers()) for param in itertools.chain(gm.parameters(), gm.buffers())
] ]
@ -2626,13 +2632,16 @@ def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]:
if istype(obj, (dict, OrderedDict)): if istype(obj, (dict, OrderedDict)):
return obj.items() return obj.items()
elif isinstance(obj, OrderedDict): elif isinstance(obj, OrderedDict):
# pyrefly: ignore # bad-argument-type
return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)] return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)]
else: else:
# pyrefly: ignore # bad-argument-type
return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)] return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)]
def nn_module_new(cls: Any) -> Any: def nn_module_new(cls: Any) -> Any:
obj = object_new(cls) obj = object_new(cls)
# pyrefly: ignore # bad-argument-type
torch.nn.Module.__init__(obj) torch.nn.Module.__init__(obj)
return obj return obj
@ -2679,6 +2688,7 @@ def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any:
dict_class = dict dict_class = dict
if isinstance(d, OrderedDict): if isinstance(d, OrderedDict):
dict_class = OrderedDict dict_class = OrderedDict
# pyrefly: ignore # bad-argument-type
return next(itertools.islice(dict_class.keys(d), n, n + 1)) return next(itertools.islice(dict_class.keys(d), n, n + 1))
@ -3222,8 +3232,10 @@ def format_func_info(code: CodeType) -> str:
@contextlib.contextmanager @contextlib.contextmanager
def disable_cache_limit() -> Generator[None, None, None]: def disable_cache_limit() -> Generator[None, None, None]:
prior = config.recompile_limit prior = config.recompile_limit
# pyrefly: ignore # bad-assignment
config.recompile_limit = sys.maxsize config.recompile_limit = sys.maxsize
prior_acc_limit = config.accumulated_recompile_limit prior_acc_limit = config.accumulated_recompile_limit
# pyrefly: ignore # bad-assignment
config.accumulated_recompile_limit = sys.maxsize config.accumulated_recompile_limit = sys.maxsize
try: try:
@ -3958,6 +3970,7 @@ class numpy_operator_wrapper(Generic[_P, R]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any: def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
assert not kwargs assert not kwargs
# pyrefly: ignore # bad-assignment
args = ( args = (
tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
) )
@ -4157,6 +4170,7 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
# (x) + (y) # (x) + (y)
# ~~^~~~~~~ # ~~^~~~~~~
while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#": while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
# pyrefly: ignore # unbound-name
if ch in "\\#": if ch in "\\#":
cur_lineno, cur_col = nextline(cur_lineno, cur_col) cur_lineno, cur_col = nextline(cur_lineno, cur_col)
else: else:
@ -4507,6 +4521,7 @@ class GmWrapper(torch.nn.Module):
self.unflatten_fn = unflatten_fn self.unflatten_fn = unflatten_fn
def forward(self, *args: Any) -> Any: def forward(self, *args: Any) -> Any:
# pyrefly: ignore # annotation-mismatch
args: list[Any] = list(args) args: list[Any] = list(args)
return self.gm(*self.unflatten_fn(args)) return self.gm(*self.unflatten_fn(args))
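
annotation-mismatch appears where a parameter is re-annotated in the body with a type that differs from its declared parameter type, the pattern used above to turn `*args` from a tuple into a list in place. Illustrative sketch:

    from typing import Any

    def collect(*args: Any) -> list[Any]:
        # Re-declaring the parameter as a list conflicts with its implicit
        # tuple type, even though the runtime conversion is intentional.
        # pyrefly: ignore # annotation-mismatch
        args: list[Any] = list(args)
        args.append(None)
        return args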

View File

@ -1028,6 +1028,7 @@ class BuiltinVariable(VariableTracker):
def call_self_handler(tx: "InstructionTranslator", args, kwargs): def call_self_handler(tx: "InstructionTranslator", args, kwargs):
try: try:
# pyrefly: ignore # not-callable
result = self_handler(tx, *args, **kwargs) result = self_handler(tx, *args, **kwargs)
if result is not None: if result is not None:
return result return result
@ -1035,6 +1036,7 @@ class BuiltinVariable(VariableTracker):
# Check if binding is bad. inspect signature bind is expensive. # Check if binding is bad. inspect signature bind is expensive.
# So check only when handler call fails. # So check only when handler call fails.
try: try:
# pyrefly: ignore # bad-argument-type
inspect.signature(self_handler).bind(tx, *args, **kwargs) inspect.signature(self_handler).bind(tx, *args, **kwargs)
except TypeError as e: except TypeError as e:
has_constant_handler = obj.has_constant_handler(args, kwargs) has_constant_handler = obj.has_constant_handler(args, kwargs)
@ -1087,6 +1089,7 @@ class BuiltinVariable(VariableTracker):
hints=[*graph_break_hints.DYNAMO_BUG], hints=[*graph_break_hints.DYNAMO_BUG],
from_exc=exc, from_exc=exc,
) )
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, res) return VariableTracker.build(tx, res)
else: else:
@ -1115,6 +1118,7 @@ class BuiltinVariable(VariableTracker):
tx, tx,
args=list(map(ConstantVariable.create, exc.args)), args=list(map(ConstantVariable.create, exc.args)),
) )
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, res) return VariableTracker.build(tx, res)
handlers.append(constant_fold_handler) handlers.append(constant_fold_handler)
@ -1437,6 +1441,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name) resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods: if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable): if isinstance(args[0], variables.UserDefinedDictVariable):
# pyrefly: ignore # missing-attribute
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs) return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable): elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs) return args[0].call_method(tx, name, args[1:], kwargs)
@ -1445,6 +1450,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name) resolved_fn = getattr(self.fn, name)
if resolved_fn in set_methods: if resolved_fn in set_methods:
if isinstance(args[0], variables.UserDefinedSetVariable): if isinstance(args[0], variables.UserDefinedSetVariable):
# pyrefly: ignore # missing-attribute
return args[0]._set_vt.call_method(tx, name, args[1:], kwargs) return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.SetVariable): elif isinstance(args[0], variables.SetVariable):
return args[0].call_method(tx, name, args[1:], kwargs) return args[0].call_method(tx, name, args[1:], kwargs)
@ -1533,10 +1539,12 @@ class BuiltinVariable(VariableTracker):
if type(arg.value).__str__ is object.__str__: if type(arg.value).__str__ is object.__str__:
# Rely on the object str method # Rely on the object str method
try: try:
# pyrefly: ignore # unbound-name
return variables.ConstantVariable.create(value=str_method()) return variables.ConstantVariable.create(value=str_method())
except AttributeError: except AttributeError:
# Graph break # Graph break
return return
# pyrefly: ignore # unbound-name
elif is_wrapper_or_member_descriptor(str_method): elif is_wrapper_or_member_descriptor(str_method):
unimplemented_v2( unimplemented_v2(
gb_type="Attempted to a str() method implemented in C/C++", gb_type="Attempted to a str() method implemented in C/C++",
@ -1653,8 +1661,10 @@ class BuiltinVariable(VariableTracker):
else: else:
raw_b = b.raw_value raw_b = b.raw_value
if self.fn is max: if self.fn is max:
# pyrefly: ignore # missing-attribute
raw_res = max(a.raw_value, raw_b) raw_res = max(a.raw_value, raw_b)
else: else:
# pyrefly: ignore # missing-attribute
raw_res = min(a.raw_value, raw_b) raw_res = min(a.raw_value, raw_b)
need_unwrap = any( need_unwrap = any(
@ -2106,6 +2116,7 @@ class BuiltinVariable(VariableTracker):
) )
if isinstance(arg, variables.UserDefinedExceptionClassVariable): if isinstance(arg, variables.UserDefinedExceptionClassVariable):
# pyrefly: ignore # unbound-name
return ConstantVariable.create(isinstance(arg_type, isinstance_type)) return ConstantVariable.create(isinstance(arg_type, isinstance_type))
isinstance_type_tuple: tuple[type, ...] isinstance_type_tuple: tuple[type, ...]
@ -2138,8 +2149,10 @@ class BuiltinVariable(VariableTracker):
# through it. This is a limitation of the current implementation. # through it. This is a limitation of the current implementation.
# Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
# might not be a big issue and we trade off it for performance. # might not be a big issue and we trade off it for performance.
# pyrefly: ignore # unbound-name
val = issubclass(arg_type, isinstance_type_tuple) val = issubclass(arg_type, isinstance_type_tuple)
except TypeError: except TypeError:
# pyrefly: ignore # unbound-name
val = arg_type in isinstance_type_tuple val = arg_type in isinstance_type_tuple
return variables.ConstantVariable.create(val) return variables.ConstantVariable.create(val)
@ -2161,6 +2174,7 @@ class BuiltinVariable(VariableTracker):
# WARNING: This might run arbitrary user code `__subclasscheck__`. # WARNING: This might run arbitrary user code `__subclasscheck__`.
# See the comment in call_isinstance above. # See the comment in call_isinstance above.
# pyrefly: ignore # unbound-name
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))
def call_super(self, tx: "InstructionTranslator", a, b): def call_super(self, tx: "InstructionTranslator", a, b):
@ -2206,7 +2220,9 @@ class BuiltinVariable(VariableTracker):
value = getattr(self.fn, name) value = getattr(self.fn, name)
except AttributeError: except AttributeError:
raise_observed_exception(AttributeError, tx) raise_observed_exception(AttributeError, tx)
# pyrefly: ignore # unbound-name
if not callable(value): if not callable(value):
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, value, source) return VariableTracker.build(tx, value, source)
return variables.GetAttrVariable(self, name, source=source) return variables.GetAttrVariable(self, name, source=source)
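
not-callable covers calls through a value whose static type the checker cannot confirm is callable, such as a handler looked up from a heterogeneous registry. Hypothetical sketch (not the builtin handler tables themselves):

    def dispatch(table: dict[str, object], name: str, x: int) -> int:
        handler = table[name]  # statically `object`
        # Registration guarantees the entry is a callable taking an int.
        # pyrefly: ignore # not-callable
        return handler(x)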

View File

@ -651,6 +651,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
def handle_use_deterministic_algorithms( def handle_use_deterministic_algorithms(
self, tx: "InstructionTranslator", mode, warn_only=False self, tx: "InstructionTranslator", mode, warn_only=False
): ):
# pyrefly: ignore # missing-attribute
if warn_only and warn_only.as_python_constant(): if warn_only and warn_only.as_python_constant():
unimplemented_v2( unimplemented_v2(
gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)", gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)",
@ -1035,6 +1036,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else: else:
raise torch._dynamo.exc.Unsupported("branch not supported") raise torch._dynamo.exc.Unsupported("branch not supported")
return variables.ConstantVariable.create( return variables.ConstantVariable.create(
# pyrefly: ignore # bad-argument-type
torch.fx.experimental.symbolic_shapes.guard_scalar(val) torch.fx.experimental.symbolic_shapes.guard_scalar(val)
) )
@ -1081,6 +1083,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
return return
return variables.ConstantVariable.create( return variables.ConstantVariable.create(
# pyrefly: ignore # bad-argument-type
torch.fx.experimental.symbolic_shapes.has_static_value(val) torch.fx.experimental.symbolic_shapes.has_static_value(val)
) )
@ -1212,13 +1215,17 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
) )
# need to guard only on no-arg get_device_module # need to guard only on no-arg get_device_module
# pyrefly: ignore # unbound-name
if device is None: if device is None:
source = CallFunctionNoArgsSource(self.source) source = CallFunctionNoArgsSource(self.source)
install_guard(source.make_guard(GuardBuilder.ID_MATCH)) install_guard(source.make_guard(GuardBuilder.ID_MATCH))
# assumes `module` is in the form `torch.xyz` # assumes `module` is in the form `torch.xyz`
new_source = AttrSource( new_source = AttrSource(
TorchSource(), module.__name__.rsplit(".", maxsplit=1)[-1] TorchSource(),
# pyrefly: ignore # unbound-name
module.__name__.rsplit(".", maxsplit=1)[-1],
) )
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, module, new_source) return VariableTracker.build(tx, module, new_source)
@register(torch.set_default_device) @register(torch.set_default_device)
@ -1373,9 +1380,12 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
f"{fn.__name__}_spec", f_spec f"{fn.__name__}_spec", f_spec
) )
input_spec_proxy = tx.output.register_static_attr_and_return_proxy( input_spec_proxy = tx.output.register_static_attr_and_return_proxy(
fn.__name__ + "_input_spec", input_spec fn.__name__ + "_input_spec",
# pyrefly: ignore # unbound-name
input_spec,
) )
f_spec_proxy.node.type = type(f_spec) f_spec_proxy.node.type = type(f_spec)
# pyrefly: ignore # unbound-name
input_spec_proxy.node.type = type(input_spec) input_spec_proxy.node.type = type(input_spec)
all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args) all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args)
@ -1716,6 +1726,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
) )
# this results in cleaner graphs, but only works for inputs # this results in cleaner graphs, but only works for inputs
# pyrefly: ignore # missing-attribute
if data.source: if data.source:
return cls._nn_param_via_prefix_insert(tx, data, requires_grad) return cls._nn_param_via_prefix_insert(tx, data, requires_grad)
@ -1734,7 +1745,9 @@ For now, dynamo will explicitly graph break when it encounters user code with th
# TODO[@lucaskabela]: Remove the behavior below since it is deprecated # TODO[@lucaskabela]: Remove the behavior below since it is deprecated
if isinstance( if isinstance(
data, TensorWithTFOverrideVariable data,
TensorWithTFOverrideVariable,
# pyrefly: ignore # missing-attribute
) or is_traceable_wrapper_subclass_type(data.class_type): ) or is_traceable_wrapper_subclass_type(data.class_type):
unimplemented_v2( unimplemented_v2(
gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass", gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass",
@ -1757,8 +1770,11 @@ For now, dynamo will explicitly graph break when it encounters user code with th
) )
try: try:
# pyrefly: ignore # missing-attribute
shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
# pyrefly: ignore # missing-attribute
dtype = data.var_getattr(tx, "dtype").as_python_constant() dtype = data.var_getattr(tx, "dtype").as_python_constant()
# pyrefly: ignore # missing-attribute
device = data.var_getattr(tx, "device").as_python_constant() device = data.var_getattr(tx, "device").as_python_constant()
except NotImplementedError as e: except NotImplementedError as e:
unimplemented_v2( unimplemented_v2(
@ -1773,9 +1789,13 @@ For now, dynamo will explicitly graph break when it encounters user code with th
) )
placeholder = tx.output.synthetic_graph_input( placeholder = tx.output.synthetic_graph_input(
new_parameter_placeholder, [shape, dtype, device, requires_grad] new_parameter_placeholder,
# pyrefly: ignore # unbound-name
[shape, dtype, device, requires_grad],
) )
# pyrefly: ignore # missing-attribute
if data.requires_grad: if data.requires_grad:
# pyrefly: ignore # missing-attribute
data = data.call_method(tx, "detach", [], {}) data = data.call_method(tx, "detach", [], {})
from .builder import wrap_fx_proxy from .builder import wrap_fx_proxy
@ -1785,6 +1805,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
tx.output.create_proxy( tx.output.create_proxy(
"call_function", "call_function",
tracable_create_parameter, tracable_create_parameter,
# pyrefly: ignore # missing-attribute
(data.as_proxy(), placeholder.as_proxy()), (data.as_proxy(), placeholder.as_proxy()),
{}, {},
), ),

View File

@ -646,7 +646,6 @@ def update_schema():
assert thrift_content[1].startswith("// checksum<<") assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:])) thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
# pyrefly: ignore # import-error
from yaml import load, Loader from yaml import load, Loader
dst = load(content, Loader=Loader) dst = load(content, Loader=Loader)

View File

@ -34,7 +34,7 @@ from torch.fx._pytree import (
_deregister_pytree_flatten_spec, _deregister_pytree_flatten_spec,
register_pytree_flatten_spec, register_pytree_flatten_spec,
) )
from torch.utils._pytree import ( # pyrefly: ignore # deprecated from torch.utils._pytree import (
_deregister_pytree_node, _deregister_pytree_node,
_register_pytree_node, _register_pytree_node,
Context, Context,

View File

@ -198,7 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
to a file. The yaml string can be loaded back into an operator profile to a file. The yaml string can be loaded back into an operator profile
structure using `read_profiles_from_yaml`. structure using `read_profiles_from_yaml`.
""" """
# pyrefly: ignore # import-error
import yaml import yaml
from torch._export.serde.serialize import ( from torch._export.serde.serialize import (
@ -263,7 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
""" """
Reads the yaml saved by `save_op_profiles` and returns the operator profiles. Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
""" """
# pyrefly: ignore # import-error
import yaml import yaml
from torch._export.serde.serialize import ( from torch._export.serde.serialize import (

View File

@ -1,5 +1,5 @@
from .core import dispatch from .core import dispatch
from .dispatcher import ( # pyrefly: ignore # deprecated from .dispatcher import (
Dispatcher, Dispatcher,
halt_ordering, halt_ordering,
MDNotImplementedError, MDNotImplementedError,

View File

@ -153,6 +153,7 @@ class CausalBias(torch.Tensor):
diagonal=diagonal_offset, diagonal=diagonal_offset,
) )
# pyrefly: ignore # bad-return
def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor: def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
""" """
Materializes the causal bias into a tensor form. Materializes the causal bias into a tensor form.
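
bad-return is attached to a def whose annotated return type is narrower than what pyrefly infers from the body, for instance when a helper is typed as returning Optional even though this path always yields a value. A rough sketch with made-up names; the commit places the suppression on the line before the def, mirrored here:

    from typing import Optional

    _CACHE: dict[str, str] = {}

    def lookup(key: str) -> Optional[str]:
        return _CACHE.get(key)

    # pyrefly: ignore # bad-return
    def require(key: str) -> str:
        _CACHE.setdefault(key, key.upper())
        return lookup(key)  # inferred Optional[str], guaranteed non-None here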

View File

@ -84,6 +84,7 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
# pyrefly: ignore # invalid-inheritance
class FlexKernelOptions(TypedDict, total=False): class FlexKernelOptions(TypedDict, total=False):
"""Options for controlling the behavior of FlexAttention kernels. """Options for controlling the behavior of FlexAttention kernels.
@ -127,76 +128,93 @@ class FlexKernelOptions(TypedDict, total=False):
""" """
# Performance tuning options # Performance tuning options
# pyrefly: ignore # invalid-annotation
num_warps: NotRequired[int] num_warps: NotRequired[int]
"""Number of warps to use in the CUDA kernel. Higher values may improve performance """Number of warps to use in the CUDA kernel. Higher values may improve performance
but increase register pressure. Default is determined by autotuning.""" but increase register pressure. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
num_stages: NotRequired[int] num_stages: NotRequired[int]
"""Number of pipeline stages in the CUDA kernel. Higher values may improve performance """Number of pipeline stages in the CUDA kernel. Higher values may improve performance
but increase shared memory usage. Default is determined by autotuning.""" but increase shared memory usage. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_M: NotRequired[int] BLOCK_M: NotRequired[int]
"""Thread block size for the sequence length dimension of Q in forward pass. """Thread block size for the sequence length dimension of Q in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_N: NotRequired[int] BLOCK_N: NotRequired[int]
"""Thread block size for the sequence length dimension of K/V in forward pass. """Thread block size for the sequence length dimension of K/V in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
# Backward-specific block sizes (when prefixed with 'bwd_') # Backward-specific block sizes (when prefixed with 'bwd_')
# pyrefly: ignore # invalid-annotation
BLOCK_M1: NotRequired[int] BLOCK_M1: NotRequired[int]
"""Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'. """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
Default is determined by autotuning.""" Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_N1: NotRequired[int] BLOCK_N1: NotRequired[int]
"""Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'. """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
Default is determined by autotuning.""" Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_M2: NotRequired[int] BLOCK_M2: NotRequired[int]
"""Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'. """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
Default is determined by autotuning.""" Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
BLOCK_N2: NotRequired[int] BLOCK_N2: NotRequired[int]
"""Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'. """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
Default is determined by autotuning.""" Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
PRESCALE_QK: NotRequired[bool] PRESCALE_QK: NotRequired[bool]
"""Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but """Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
may have more numerical error. Default: False.""" may have more numerical error. Default: False."""
# pyrefly: ignore # invalid-annotation
ROWS_GUARANTEED_SAFE: NotRequired[bool] ROWS_GUARANTEED_SAFE: NotRequired[bool]
"""If True, guarantees that at least one value in each row is not masked out. """If True, guarantees that at least one value in each row is not masked out.
Allows skipping safety checks for better performance. Only set this if you are certain Allows skipping safety checks for better performance. Only set this if you are certain
your mask guarantees this property. For example, causal attention is guaranteed safe your mask guarantees this property. For example, causal attention is guaranteed safe
because each query has at least 1 key-value to attend to. Default: False.""" because each query has at least 1 key-value to attend to. Default: False."""
# pyrefly: ignore # invalid-annotation
BLOCKS_ARE_CONTIGUOUS: NotRequired[bool] BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
"""If True, guarantees that all blocks in the mask are contiguous. """If True, guarantees that all blocks in the mask are contiguous.
Allows optimizing block traversal. For example, causal masks would satisfy this, Allows optimizing block traversal. For example, causal masks would satisfy this,
but prefix_lm + sliding window would not. Default: False.""" but prefix_lm + sliding window would not. Default: False."""
# pyrefly: ignore # invalid-annotation
WRITE_DQ: NotRequired[bool] WRITE_DQ: NotRequired[bool]
"""Controls whether gradient scatters are done in the DQ iteration loop of the backward pass. """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
Setting this to False will force this to happen in the DK loop which depending on your Setting this to False will force this to happen in the DK loop which depending on your
specific score_mod and mask_mod might be faster. Default: True.""" specific score_mod and mask_mod might be faster. Default: True."""
# pyrefly: ignore # invalid-annotation
FORCE_USE_FLEX_ATTENTION: NotRequired[bool] FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
"""If True, forces the use of the flex attention kernel instead of potentially using """If True, forces the use of the flex attention kernel instead of potentially using
the more optimized flex-decoding kernel for short sequences. This can be a helpful the more optimized flex-decoding kernel for short sequences. This can be a helpful
option for debugging. Default: False.""" option for debugging. Default: False."""
# pyrefly: ignore # invalid-annotation
USE_TMA: NotRequired[bool] USE_TMA: NotRequired[bool]
"""Whether to use Tensor Memory Accelerator (TMA) on supported hardware. """Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
This is experimental and may not work on all hardware, currently specific This is experimental and may not work on all hardware, currently specific
to NVIDIA GPUs Hopper+. Default: False.""" to NVIDIA GPUs Hopper+. Default: False."""
# ROCm-specific options # ROCm-specific options
# pyrefly: ignore # invalid-annotation
kpack: NotRequired[int] kpack: NotRequired[int]
"""ROCm-specific kernel packing parameter.""" """ROCm-specific kernel packing parameter."""
# pyrefly: ignore # invalid-annotation
matrix_instr_nonkdim: NotRequired[int] matrix_instr_nonkdim: NotRequired[int]
"""ROCm-specific matrix instruction non-K dimension.""" """ROCm-specific matrix instruction non-K dimension."""
# pyrefly: ignore # invalid-annotation
waves_per_eu: NotRequired[int] waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit.""" """ROCm-specific waves per execution unit."""
@ -581,6 +599,7 @@ class BlockMask:
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment] block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
seq_lengths = (self.seq_lengths,) # type: ignore[assignment] seq_lengths = (self.seq_lengths,) # type: ignore[assignment]
# pyrefly: ignore # not-iterable
return ( return (
*seq_lengths, *seq_lengths,
self.kv_num_blocks, self.kv_num_blocks,
@ -753,6 +772,7 @@ class BlockMask:
partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices) partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
if self.full_kv_num_blocks is not None: if self.full_kv_num_blocks is not None:
assert self.full_kv_indices is not None assert self.full_kv_indices is not None
# pyrefly: ignore # bad-return
return partial_dense | _ordered_to_dense( return partial_dense | _ordered_to_dense(
self.full_kv_num_blocks, self.full_kv_indices self.full_kv_num_blocks, self.full_kv_indices
) )
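
The FlexKernelOptions hunk above suppresses invalid-inheritance on the TypedDict class statement and invalid-annotation on each NotRequired field; the constructs are valid at runtime, pyrefly simply does not accept them as written. A minimal sketch of that shape (hypothetical field names):

    from typing_extensions import NotRequired, TypedDict  # `typing` on Python 3.11+

    # pyrefly: ignore # invalid-inheritance
    class KernelOptions(TypedDict, total=False):
        # pyrefly: ignore # invalid-annotation
        num_warps: NotRequired[int]
        # pyrefly: ignore # invalid-annotation
        block_m: NotRequired[int]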

View File

@ -78,6 +78,7 @@ class ModuleWrapper(nn.Module):
# nn.Module defines training as a boolean # nn.Module defines training as a boolean
@property # type: ignore[override] @property # type: ignore[override]
# pyrefly: ignore # bad-override
def training(self): def training(self):
return self.cpp_module.training return self.cpp_module.training
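
bad-override is suppressed where a subclass deliberately replaces an inherited attribute with a property (or swaps in a differently typed callable), which the checker treats as an incompatible override. Hypothetical sketch of the property case above:

    class Base:
        training: bool = True

    class CppWrapper(Base):
        # Replacing the plain attribute with a read-only property is
        # intentional, but statically it is an incompatible override.
        @property  # type: ignore[override]
        # pyrefly: ignore # bad-override
        def training(self) -> bool:
            return False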

View File

@ -1266,6 +1266,7 @@ def adaptive_max_pool2d_with_indices(
output_size, output_size,
return_indices=return_indices, return_indices=return_indices,
) )
# pyrefly: ignore # bad-argument-type
output_size = _list_with_default(output_size, input.size()) output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool2d(input, output_size) return torch._C._nn.adaptive_max_pool2d(input, output_size)
@ -1323,6 +1324,7 @@ def adaptive_max_pool3d_with_indices(
output_size, output_size,
return_indices=return_indices, return_indices=return_indices,
) )
# pyrefly: ignore # bad-argument-type
output_size = _list_with_default(output_size, input.size()) output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool3d(input, output_size) return torch._C._nn.adaptive_max_pool3d(input, output_size)
@ -1381,6 +1383,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T
""" """
if has_torch_function_unary(input): if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
_output_size = _list_with_default(output_size, input.size()) _output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool2d(input, _output_size) return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
@ -1396,6 +1399,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T
""" """
if has_torch_function_unary(input): if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
_output_size = _list_with_default(output_size, input.size()) _output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool3d(input, _output_size) return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
@ -2431,6 +2435,7 @@ def _no_grad_embedding_renorm_(
input: Tensor, input: Tensor,
max_norm: float, max_norm: float,
norm_type: float, norm_type: float,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]: ) -> tuple[Tensor, Tensor]:
torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
@ -2684,6 +2689,7 @@ def embedding_bag(
if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested: if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
include_last_offset = True include_last_offset = True
# pyrefly: ignore # missing-attribute
offsets = input.offsets() offsets = input.offsets()
input = input.values().reshape(-1) input = input.values().reshape(-1)
if per_sample_weights is not None: if per_sample_weights is not None:
@ -2818,6 +2824,7 @@ def batch_norm(
eps=eps, eps=eps,
) )
if training: if training:
# pyrefly: ignore # bad-argument-type
_verify_batch_size(input.size()) _verify_batch_size(input.size())
return torch.batch_norm( return torch.batch_norm(
@ -2873,6 +2880,7 @@ def instance_norm(
eps=eps, eps=eps,
) )
if use_input_stats: if use_input_stats:
# pyrefly: ignore # bad-argument-type
_verify_spatial_size(input.size()) _verify_spatial_size(input.size())
return torch.instance_norm( return torch.instance_norm(
input, input,
@ -2998,11 +3006,13 @@ def local_response_norm(
div = input.mul(input) div = input.mul(input)
if dim == 3: if dim == 3:
div = div.unsqueeze(1) div = div.unsqueeze(1)
# pyrefly: ignore # bad-argument-type
div = pad(div, (0, 0, size // 2, (size - 1) // 2)) div = pad(div, (0, 0, size // 2, (size - 1) // 2))
div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
else: else:
sizes = input.size() sizes = input.size()
div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
# pyrefly: ignore # bad-argument-type
div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
div = div.view(sizes) div = div.view(sizes)
@ -3151,7 +3161,12 @@ def nll_loss(
if size_average is not None or reduce is not None: if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce) reduction = _Reduction.legacy_get_string(size_average, reduce)
return torch._C._nn.nll_loss_nd( return torch._C._nn.nll_loss_nd(
input, target, weight, _Reduction.get_enum(reduction), ignore_index input,
target,
weight,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
ignore_index,
) )
@ -3296,6 +3311,7 @@ def gaussian_nll_loss(
var.clamp_(min=eps) var.clamp_(min=eps)
# Calculate the loss # Calculate the loss
# pyrefly: ignore # unsupported-operation
loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var) loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
if full: if full:
loss += 0.5 * math.log(2 * math.pi) loss += 0.5 * math.log(2 * math.pi)
@ -3471,6 +3487,7 @@ def cross_entropy(
input, input,
target, target,
weight, weight,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction), _Reduction.get_enum(reduction),
ignore_index, ignore_index,
label_smoothing, label_smoothing,
@ -3535,6 +3552,7 @@ def binary_cross_entropy(
new_size = _infer_size(target.size(), weight.size()) new_size = _infer_size(target.size(), weight.size())
weight = weight.expand(new_size) weight = weight.expand(new_size)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
@ -3663,11 +3681,18 @@ def smooth_l1_loss(
if beta == 0.0: if beta == 0.0:
return torch._C._nn.l1_loss( return torch._C._nn.l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction) expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
) )
else: else:
return torch._C._nn.smooth_l1_loss( return torch._C._nn.smooth_l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction), beta expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
beta,
) )
@ -3725,7 +3750,11 @@ def huber_loss(
if weight is None: if weight is None:
# Use the optimized C++ backend for standard Huber loss # Use the optimized C++ backend for standard Huber loss
return torch._C._nn.huber_loss( return torch._C._nn.huber_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction), delta expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
delta,
) )
else: else:
if weight.size() != input.size(): if weight.size() != input.size():
@ -3733,7 +3762,11 @@ def huber_loss(
# Calculate the unweighted loss first # Calculate the unweighted loss first
unweighted_loss = torch._C._nn.huber_loss( unweighted_loss = torch._C._nn.huber_loss(
expanded_input, expanded_target, _Reduction.get_enum("none"), delta expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum("none"),
delta,
) )
# Apply weight to the unweighted loss # Apply weight to the unweighted loss
@ -3820,7 +3853,10 @@ def l1_loss(
) )
else: else:
return torch._C._nn.l1_loss( return torch._C._nn.l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction) expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
) )
@ -3895,7 +3931,10 @@ def mse_loss(
) )
else: else:
return torch._C._nn.mse_loss( return torch._C._nn.mse_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction) expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
) )
@ -4032,6 +4071,7 @@ def multilabel_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else: else:
reduction_enum = _Reduction.get_enum(reduction) reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
@ -4073,6 +4113,7 @@ def soft_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else: else:
reduction_enum = _Reduction.get_enum(reduction) reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.soft_margin_loss(input, target, reduction_enum) return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
@ -4237,7 +4278,13 @@ def multi_margin_loss(
raise ValueError("weight must be one-dimensional") raise ValueError("weight must be one-dimensional")
return torch._C._nn.multi_margin_loss( return torch._C._nn.multi_margin_loss(
input, target, p, margin, weight, reduction_enum input,
target,
p,
margin,
weight,
# pyrefly: ignore # bad-argument-type
reduction_enum,
) )
@ -4383,6 +4430,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None, scale_factor: Optional[float] = None,
mode: str = "nearest", mode: str = "nearest",
align_corners: Optional[bool] = None, align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950 ) -> Tensor: # noqa: B950
pass pass
@ -4394,6 +4442,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None, scale_factor: Optional[float] = None,
mode: str = "nearest", mode: str = "nearest",
align_corners: Optional[bool] = None, align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950 ) -> Tensor: # noqa: B950
pass pass
@ -4496,6 +4545,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None, align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None,
antialias: bool = False, antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950 ) -> Tensor: # noqa: B950
pass pass
@ -4509,6 +4559,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None, align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None,
antialias: bool = False, antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950 ) -> Tensor: # noqa: B950
pass pass
@ -4522,6 +4573,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None, align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None,
antialias: bool = False, antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor: # noqa: B950 ) -> Tensor: # noqa: B950
pass pass
@ -4535,6 +4587,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None, align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None,
antialias: bool = False, antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor: ) -> Tensor:
pass pass
@ -4709,6 +4762,7 @@ def interpolate( # noqa: F811
( (
torch.floor( torch.floor(
( (
# pyrefly: ignore # missing-attribute
input.size(i + 2).float() input.size(i + 2).float()
* torch.tensor(scale_factors[i], dtype=torch.float32) * torch.tensor(scale_factors[i], dtype=torch.float32)
).float() ).float()
@ -4733,21 +4787,28 @@ def interpolate( # noqa: F811
) )
if input.dim() == 3 and mode == "nearest": if input.dim() == 3 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest": if input.dim() == 4 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest": if input.dim() == 5 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
if input.dim() == 3 and mode == "nearest-exact": if input.dim() == 3 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest-exact": if input.dim() == 4 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest-exact": if input.dim() == 5 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
if input.dim() == 3 and mode == "area": if input.dim() == 3 and mode == "area":
assert output_size is not None assert output_size is not None
# pyrefly: ignore # bad-argument-type
return adaptive_avg_pool1d(input, output_size) return adaptive_avg_pool1d(input, output_size)
if input.dim() == 4 and mode == "area": if input.dim() == 4 and mode == "area":
assert output_size is not None assert output_size is not None
@ -4759,13 +4820,21 @@ def interpolate( # noqa: F811
if input.dim() == 3 and mode == "linear": if input.dim() == 3 and mode == "linear":
assert align_corners is not None assert align_corners is not None
return torch._C._nn.upsample_linear1d( return torch._C._nn.upsample_linear1d(
input, output_size, align_corners, scale_factors input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
) )
if input.dim() == 4 and mode == "bilinear": if input.dim() == 4 and mode == "bilinear":
assert align_corners is not None assert align_corners is not None
if antialias: if antialias:
return torch._C._nn._upsample_bilinear2d_aa( return torch._C._nn._upsample_bilinear2d_aa(
input, output_size, align_corners, scale_factors input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
) )
# Two levels are necessary to prevent TorchScript from touching # Two levels are necessary to prevent TorchScript from touching
# are_deterministic_algorithms_enabled. # are_deterministic_algorithms_enabled.
@ -4778,7 +4847,11 @@ def interpolate( # noqa: F811
"torch._decomp.decompositions" "torch._decomp.decompositions"
)._upsample_linear_vec(input, output_size, align_corners, scale_factors) )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_bilinear2d( return torch._C._nn.upsample_bilinear2d(
input, output_size, align_corners, scale_factors input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
) )
if input.dim() == 5 and mode == "trilinear": if input.dim() == 5 and mode == "trilinear":
assert align_corners is not None assert align_corners is not None
@ -4793,16 +4866,28 @@ def interpolate( # noqa: F811
"torch._decomp.decompositions" "torch._decomp.decompositions"
)._upsample_linear_vec(input, output_size, align_corners, scale_factors) )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_trilinear3d( return torch._C._nn.upsample_trilinear3d(
input, output_size, align_corners, scale_factors input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
) )
if input.dim() == 4 and mode == "bicubic": if input.dim() == 4 and mode == "bicubic":
assert align_corners is not None assert align_corners is not None
if antialias: if antialias:
return torch._C._nn._upsample_bicubic2d_aa( return torch._C._nn._upsample_bicubic2d_aa(
input, output_size, align_corners, scale_factors input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
) )
return torch._C._nn.upsample_bicubic2d( return torch._C._nn.upsample_bicubic2d(
input, output_size, align_corners, scale_factors input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
) )
if input.dim() == 3 and mode == "bilinear": if input.dim() == 3 and mode == "bilinear":
@ -4834,6 +4919,7 @@ def upsample_nearest( # noqa: F811
input: Tensor, input: Tensor,
size: Optional[int] = None, size: Optional[int] = None,
scale_factor: Optional[float] = None, scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor: ) -> Tensor:
pass pass
@ -4843,6 +4929,7 @@ def upsample_nearest( # noqa: F811
input: Tensor, input: Tensor,
size: Optional[list[int]] = None, size: Optional[list[int]] = None,
scale_factor: Optional[float] = None, scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor: ) -> Tensor:
pass pass
@ -4884,6 +4971,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor, input: Tensor,
size: Optional[int] = None, size: Optional[int] = None,
scale_factor: Optional[float] = None, scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor: ) -> Tensor:
pass pass
@ -4893,6 +4981,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor, input: Tensor,
size: Optional[list[int]] = None, size: Optional[list[int]] = None,
scale_factor: Optional[float] = None, scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor: ) -> Tensor:
pass pass
@ -4902,6 +4991,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor, input: Tensor,
size: Optional[int] = None, size: Optional[int] = None,
scale_factor: Optional[list[float]] = None, scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
) -> Tensor: ) -> Tensor:
pass pass
@ -4911,6 +5001,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor, input: Tensor,
size: Optional[list[int]] = None, size: Optional[list[int]] = None,
scale_factor: Optional[list[float]] = None, scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
) -> Tensor: ) -> Tensor:
pass pass
@ -5717,6 +5808,7 @@ def _in_projection_packed(
.squeeze(-2) .squeeze(-2)
.contiguous() .contiguous()
) )
# pyrefly: ignore # bad-return
return proj[0], proj[1], proj[2] return proj[0], proj[1], proj[2]
else: else:
# encoder-decoder attention # encoder-decoder attention
@ -5735,6 +5827,7 @@ def _in_projection_packed(
.squeeze(-2) .squeeze(-2)
.contiguous() .contiguous()
) )
# pyrefly: ignore # bad-return
return (q_proj, kv_proj[0], kv_proj[1]) return (q_proj, kv_proj[0], kv_proj[1])
else: else:
w_q, w_k, w_v = w.chunk(3) w_q, w_k, w_v = w.chunk(3)
@ -5742,6 +5835,7 @@ def _in_projection_packed(
b_q = b_k = b_v = None b_q = b_k = b_v = None
else: else:
b_q, b_k, b_v = b.chunk(3) b_q, b_k, b_v = b.chunk(3)
# pyrefly: ignore # bad-return
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
@ -6372,8 +6466,10 @@ def multi_head_attention_forward(
k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None: if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
attn_mask = pad(attn_mask, (0, 1)) attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None: if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
key_padding_mask = pad(key_padding_mask, (0, 1)) key_padding_mask = pad(key_padding_mask, (0, 1))
else: else:
assert bias_k is None assert bias_k is None
@ -6382,8 +6478,10 @@ def multi_head_attention_forward(
# #
# reshape q, k, v for multihead attention and make them batch first # reshape q, k, v for multihead attention and make them batch first
# #
# pyrefly: ignore # no-matching-overload
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None: if static_k is None:
# pyrefly: ignore # no-matching-overload
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
@ -6395,6 +6493,7 @@ def multi_head_attention_forward(
) )
k = static_k k = static_k
if static_v is None: if static_v is None:
# pyrefly: ignore # no-matching-overload
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
@ -6410,14 +6509,20 @@ def multi_head_attention_forward(
if add_zero_attn: if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim) zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat( k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 # pyrefly: ignore # no-matching-overload
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
dim=1,
) )
v = torch.cat( v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 # pyrefly: ignore # no-matching-overload
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
dim=1,
) )
if attn_mask is not None: if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
attn_mask = pad(attn_mask, (0, 1)) attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None: if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
key_padding_mask = pad(key_padding_mask, (0, 1)) key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments # update source sequence length after adjustments
@ -6467,6 +6572,7 @@ def multi_head_attention_forward(
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
attn_output = ( attn_output = (
# pyrefly: ignore # no-matching-overload
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
) )
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
@ -6493,13 +6599,16 @@ def multi_head_attention_forward(
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim) q = q.view(bsz, num_heads, tgt_len, head_dim)
# pyrefly: ignore # no-matching-overload
k = k.view(bsz, num_heads, src_len, head_dim) k = k.view(bsz, num_heads, src_len, head_dim)
# pyrefly: ignore # no-matching-overload
v = v.view(bsz, num_heads, src_len, head_dim) v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention( attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal q, k, v, attn_mask, dropout_p, is_causal
) )
attn_output = ( attn_output = (
# pyrefly: ignore # no-matching-overload
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
) )


@ -500,6 +500,7 @@ def xavier_normal_(
def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int: def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
# pyrefly: ignore # bad-assignment
mode = mode.lower() mode = mode.lower()
valid_modes = ["fan_in", "fan_out"] valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes: if mode not in valid_modes:


@ -6,6 +6,7 @@ from torch.autograd.function import Function
class SyncBatchNorm(Function): class SyncBatchNorm(Function):
@staticmethod @staticmethod
# pyrefly: ignore # bad-override
def forward( def forward(
self, self,
input, input,
@ -210,6 +211,7 @@ class SyncBatchNorm(Function):
class CrossMapLRN2d(Function): class CrossMapLRN2d(Function):
@staticmethod @staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
ctx.size = size ctx.size = size
ctx.alpha = alpha ctx.alpha = alpha
@ -265,6 +267,7 @@ class CrossMapLRN2d(Function):
return output return output
@staticmethod @staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, output = ctx.saved_tensors input, output = ctx.saved_tensors
grad_input = grad_output.new() grad_input = grad_output.new()
@ -306,6 +309,7 @@ class CrossMapLRN2d(Function):
class BackwardHookFunction(torch.autograd.Function): class BackwardHookFunction(torch.autograd.Function):
@staticmethod @staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, *args): def forward(ctx, *args):
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad]) ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
return args return args


@ -72,6 +72,7 @@ class _NormBase(Module):
torch.tensor( torch.tensor(
0, 0,
dtype=torch.long, dtype=torch.long,
# pyrefly: ignore # bad-argument-type
**{k: v for k, v in factory_kwargs.items() if k != "dtype"}, **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
), ),
) )
@ -221,6 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__( super().__init__(
# affine and track_running_stats are hardcoded to False to # affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten. # avoid creating tensors that will soon be overwritten.
@ -234,22 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
self.affine = affine self.affine = affine
self.track_running_stats = track_running_stats self.track_running_stats = track_running_stats
if self.affine: if self.affine:
# pyrefly: ignore # bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
if self.track_running_stats: if self.track_running_stats:
# pyrefly: ignore # bad-argument-type
self.running_mean = UninitializedBuffer(**factory_kwargs) self.running_mean = UninitializedBuffer(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.running_var = UninitializedBuffer(**factory_kwargs) self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor( self.num_batches_tracked = torch.tensor(
0, 0,
dtype=torch.long, dtype=torch.long,
# pyrefly: ignore # bad-argument-type
**{k: v for k, v in factory_kwargs.items() if k != "dtype"}, **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
) )
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
# pyrefly: ignore # bad-argument-type
if not self.has_uninitialized_params() and self.num_features != 0: if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters() super().reset_parameters()
def initialize_parameters(self, input) -> None: # type: ignore[override] def initialize_parameters(self, input) -> None: # type: ignore[override]
# pyrefly: ignore # bad-argument-type
if self.has_uninitialized_params(): if self.has_uninitialized_params():
self.num_features = input.shape[1] self.num_features = input.shape[1]
if self.affine: if self.affine:


@ -109,6 +109,7 @@ class Sequential(Module):
def __init__(self, *args: Module) -> None: ... def __init__(self, *args: Module) -> None: ...
@overload @overload
# pyrefly: ignore # inconsistent-overload
def __init__(self, arg: OrderedDict[str, Module]) -> None: ... def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
def __init__(self, *args): def __init__(self, *args):
@ -472,6 +473,7 @@ class ModuleList(Module):
return self return self
def pop(self, key: Union[int, slice]) -> Module: def pop(self, key: Union[int, slice]) -> Module:
# pyrefly: ignore # index-error
v = self[key] v = self[key]
del self[key] del self[key]
return v return v
@ -623,9 +625,11 @@ class ModuleDict(Module):
"ModuleDict update sequence element " "ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(m).__name__ "#" + str(j) + " should be Iterable; is" + type(m).__name__
) )
# pyrefly: ignore # bad-argument-type
if not len(m) == 2: if not len(m) == 2:
raise ValueError( raise ValueError(
"ModuleDict update sequence element " "ModuleDict update sequence element "
# pyrefly: ignore # bad-argument-type
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required" "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
) )
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
@ -684,6 +688,7 @@ class ParameterList(Module):
def __getitem__(self, idx: int) -> Any: ... def __getitem__(self, idx: int) -> Any: ...
@overload @overload
# pyrefly: ignore # inconsistent-overload
def __getitem__(self: T, idx: slice) -> T: ... def __getitem__(self: T, idx: slice) -> T: ...
def __getitem__(self, idx): def __getitem__(self, idx):
@ -769,9 +774,11 @@ class ParameterList(Module):
size_str, size_str,
device_str, device_str,
) )
# pyrefly: ignore # bad-argument-type
child_lines.append(" (" + str(k) + "): " + parastr) child_lines.append(" (" + str(k) + "): " + parastr)
else: else:
child_lines.append( child_lines.append(
# pyrefly: ignore # bad-argument-type
" (" + str(k) + "): Object of type: " + type(p).__name__ " (" + str(k) + "): Object of type: " + type(p).__name__
) )
@ -979,9 +986,11 @@ class ParameterDict(Module):
"ParameterDict update sequence element " "ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(p).__name__ "#" + str(j) + " should be Iterable; is" + type(p).__name__
) )
# pyrefly: ignore # bad-argument-type
if not len(p) == 2: if not len(p) == 2:
raise ValueError( raise ValueError(
"ParameterDict update sequence element " "ParameterDict update sequence element "
# pyrefly: ignore # bad-argument-type
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required" "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
) )
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
@ -1002,9 +1011,11 @@ class ParameterDict(Module):
size_str, size_str,
device_str, device_str,
) )
# pyrefly: ignore # bad-argument-type
child_lines.append(" (" + str(k) + "): " + parastr) child_lines.append(" (" + str(k) + "): " + parastr)
else: else:
child_lines.append( child_lines.append(
# pyrefly: ignore # bad-argument-type
" (" + str(k) + "): Object of type: " + type(p).__name__ " (" + str(k) + "): Object of type: " + type(p).__name__
) )
tmpstr = "\n".join(child_lines) tmpstr = "\n".join(child_lines)


@ -363,6 +363,7 @@ class Conv1d(_ConvNd):
self.dilation, self.dilation,
self.groups, self.groups,
) )
# pyrefly: ignore # no-matching-overload
return F.conv1d( return F.conv1d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups input, weight, bias, self.stride, self.padding, self.dilation, self.groups
) )
@ -540,6 +541,7 @@ class Conv2d(_ConvNd):
self.dilation, self.dilation,
self.groups, self.groups,
) )
# pyrefly: ignore # no-matching-overload
return F.conv2d( return F.conv2d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups input, weight, bias, self.stride, self.padding, self.dilation, self.groups
) )
@ -709,6 +711,7 @@ class Conv3d(_ConvNd):
self.dilation, self.dilation,
self.groups, self.groups,
) )
# pyrefly: ignore # no-matching-overload
return F.conv3d( return F.conv3d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups input, weight, bias, self.stride, self.padding, self.dilation, self.groups
) )
@ -1494,6 +1497,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__( super().__init__(
0, 0,
0, 0,
@ -1508,9 +1512,11 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
padding_mode, padding_mode,
**factory_kwargs, **factory_kwargs,
) )
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels self.out_channels = out_channels
if bias: if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int: def _get_num_spatial_dims(self) -> int:
@ -1563,6 +1569,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__( super().__init__(
0, 0,
0, 0,
@ -1577,9 +1584,11 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
padding_mode, padding_mode,
**factory_kwargs, **factory_kwargs,
) )
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels self.out_channels = out_channels
if bias: if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int: def _get_num_spatial_dims(self) -> int:
@ -1633,6 +1642,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__( super().__init__(
0, 0,
0, 0,
@ -1647,9 +1657,11 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
padding_mode, padding_mode,
**factory_kwargs, **factory_kwargs,
) )
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels self.out_channels = out_channels
if bias: if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int: def _get_num_spatial_dims(self) -> int:
@ -1701,6 +1713,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__( super().__init__(
0, 0,
0, 0,
@ -1716,9 +1729,11 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
padding_mode, padding_mode,
**factory_kwargs, **factory_kwargs,
) )
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels self.out_channels = out_channels
if bias: if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int: def _get_num_spatial_dims(self) -> int:
@ -1770,6 +1785,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__( super().__init__(
0, 0,
0, 0,
@ -1785,9 +1801,11 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
padding_mode, padding_mode,
**factory_kwargs, **factory_kwargs,
) )
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels self.out_channels = out_channels
if bias: if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int: def _get_num_spatial_dims(self) -> int:
@ -1839,6 +1857,7 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__( super().__init__(
0, 0,
0, 0,
@ -1854,9 +1873,11 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
padding_mode, padding_mode,
**factory_kwargs, **factory_kwargs,
) )
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels self.out_channels = out_channels
if bias: if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int: def _get_num_spatial_dims(self) -> int:


@ -172,7 +172,9 @@ class LazyModuleMixin:
def __init__(self: _LazyProtocol, *args, **kwargs): def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin # Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs) # type: ignore[misc] super().__init__(*args, **kwargs) # type: ignore[misc]
# pyrefly: ignore # read-only
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
# pyrefly: ignore # read-only
self._initialize_hook = self.register_forward_pre_hook( self._initialize_hook = self.register_forward_pre_hook(
self._infer_parameters, with_kwargs=True self._infer_parameters, with_kwargs=True
) )


@ -286,6 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear):
""" """
cls_to_become = Linear # type: ignore[assignment] cls_to_become = Linear # type: ignore[assignment]
# pyrefly: ignore # bad-override
weight: UninitializedParameter weight: UninitializedParameter
bias: UninitializedParameter # type: ignore[assignment] bias: UninitializedParameter # type: ignore[assignment]
@ -295,16 +296,20 @@ class LazyLinear(LazyModuleMixin, Linear):
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# bias is hardcoded to False to avoid creating tensor # bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten. # that will soon be overwritten.
# pyrefly: ignore # bad-argument-type
super().__init__(0, 0, False) super().__init__(0, 0, False)
# pyrefly: ignore # bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
self.out_features = out_features self.out_features = out_features
if bias: if bias:
# pyrefly: ignore # bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
""" """
Resets parameters based on their initialization used in ``__init__``. Resets parameters based on their initialization used in ``__init__``.
""" """
# pyrefly: ignore # bad-argument-type
if not self.has_uninitialized_params() and self.in_features != 0: if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters() super().reset_parameters()
@ -312,6 +317,7 @@ class LazyLinear(LazyModuleMixin, Linear):
""" """
Infers ``in_features`` based on ``input`` and initializes parameters. Infers ``in_features`` based on ``input`` and initializes parameters.
""" """
# pyrefly: ignore # bad-argument-type
if self.has_uninitialized_params(): if self.has_uninitialized_params():
with torch.no_grad(): with torch.no_grad():
self.in_features = input.shape[-1] self.in_features = input.shape[-1]


@ -38,11 +38,13 @@ T = TypeVar("T", bound="Module")
class _IncompatibleKeys( class _IncompatibleKeys(
# pyrefly: ignore # invalid-inheritance
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
): ):
__slots__ = () __slots__ = ()
def __repr__(self) -> str: def __repr__(self) -> str:
# pyrefly: ignore # missing-attribute
if not self.missing_keys and not self.unexpected_keys: if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>" return "<All keys matched successfully>"
return super().__repr__() return super().__repr__()
@ -91,6 +93,7 @@ class _WrappedHook:
def __getstate__(self) -> dict: def __getstate__(self) -> dict:
result = {"hook": self.hook, "with_module": self.with_module} result = {"hook": self.hook, "with_module": self.with_module}
if self.with_module: if self.with_module:
# pyrefly: ignore # unsupported-operation
result["module"] = self.module() result["module"] = self.module()
return result return result
@ -976,7 +979,9 @@ class Module:
# Decrement use count of the gradient by setting to None # Decrement use count of the gradient by setting to None
param.grad = None param.grad = None
param_applied = torch.nn.Parameter( param_applied = torch.nn.Parameter(
param_applied, requires_grad=param.requires_grad # pyrefly: ignore # bad-argument-type
param_applied,
requires_grad=param.requires_grad,
) )
torch.utils.swap_tensors(param, param_applied) torch.utils.swap_tensors(param, param_applied)
except Exception as e: except Exception as e:
@ -987,11 +992,13 @@ class Module:
) from e ) from e
out_param = param out_param = param
elif p_should_use_set_data: elif p_should_use_set_data:
# pyrefly: ignore # bad-assignment
param.data = param_applied param.data = param_applied
out_param = param out_param = param
else: else:
assert isinstance(param, Parameter) assert isinstance(param, Parameter)
assert param.is_leaf assert param.is_leaf
# pyrefly: ignore # bad-argument-type
out_param = Parameter(param_applied, param.requires_grad) out_param = Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param self._parameters[key] = out_param
@ -2253,6 +2260,7 @@ class Module:
if destination is None: if destination is None:
destination = OrderedDict() destination = OrderedDict()
# pyrefly: ignore # missing-attribute
destination._metadata = OrderedDict() destination._metadata = OrderedDict()
local_metadata = dict(version=self._version) local_metadata = dict(version=self._version)
@ -2402,7 +2410,9 @@ class Module:
if k not in self._non_persistent_buffers_set if k not in self._non_persistent_buffers_set
} }
local_name_params = itertools.chain( local_name_params = itertools.chain(
self._parameters.items(), persistent_buffers.items() self._parameters.items(),
# pyrefly: ignore # bad-argument-type
persistent_buffers.items(),
) )
local_state = {k: v for k, v in local_name_params if v is not None} local_state = {k: v for k, v in local_name_params if v is not None}
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)

View File

@ -84,6 +84,7 @@ class CircularPad1d(_CircularPadNd):
[5., 6., 7., 4., 5., 6., 7., 4.]]]) [5., 6., 7., 4., 5., 6., 7., 4.]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int] padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None: def __init__(self, padding: _size_2_t) -> None:
@ -144,6 +145,7 @@ class CircularPad2d(_CircularPadNd):
[8., 6., 7., 8., 6.]]]]) [8., 6., 7., 8., 6.]]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int] padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None: def __init__(self, padding: _size_4_t) -> None:
@ -194,6 +196,7 @@ class CircularPad3d(_CircularPadNd):
>>> output = m(input) >>> output = m(input)
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int] padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None: def __init__(self, padding: _size_6_t) -> None:
@ -265,6 +268,7 @@ class ConstantPad1d(_ConstantPadNd):
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]]) [ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int] padding: tuple[int, int]
def __init__(self, padding: _size_2_t, value: float) -> None: def __init__(self, padding: _size_2_t, value: float) -> None:
@ -316,6 +320,7 @@ class ConstantPad2d(_ConstantPadNd):
""" """
__constants__ = ["padding", "value"] __constants__ = ["padding", "value"]
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int] padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t, value: float) -> None: def __init__(self, padding: _size_4_t, value: float) -> None:
@ -356,6 +361,7 @@ class ConstantPad3d(_ConstantPadNd):
>>> output = m(input) >>> output = m(input)
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int] padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t, value: float) -> None: def __init__(self, padding: _size_6_t, value: float) -> None:
@ -409,6 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd):
[7., 6., 5., 4., 5., 6., 7., 6.]]]) [7., 6., 5., 4., 5., 6., 7., 6.]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int] padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None: def __init__(self, padding: _size_2_t) -> None:
@ -462,6 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd):
[7., 6., 7., 8., 7.]]]]) [7., 6., 7., 8., 7.]]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int] padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None: def __init__(self, padding: _size_4_t) -> None:
@ -517,6 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd):
[1., 0., 1., 0.]]]]]) [1., 0., 1., 0.]]]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int] padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None: def __init__(self, padding: _size_6_t) -> None:
@ -570,6 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd):
[4., 4., 4., 4., 5., 6., 7., 7.]]]) [4., 4., 4., 4., 5., 6., 7., 7.]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int] padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None: def __init__(self, padding: _size_2_t) -> None:
@ -623,6 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd):
[6., 6., 7., 8., 8.]]]]) [6., 6., 7., 8., 8.]]]])
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int] padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None: def __init__(self, padding: _size_4_t) -> None:
@ -665,6 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd):
>>> output = m(input) >>> output = m(input)
""" """
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int] padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None: def __init__(self, padding: _size_6_t) -> None:


@ -111,6 +111,7 @@ class RNNBase(Module):
if ( if (
not isinstance(dropout, numbers.Number) not isinstance(dropout, numbers.Number)
# pyrefly: ignore # unsupported-operation
or not 0 <= dropout <= 1 or not 0 <= dropout <= 1
or isinstance(dropout, bool) or isinstance(dropout, bool)
): ):
@ -119,6 +120,7 @@ class RNNBase(Module):
"representing the probability of an element being " "representing the probability of an element being "
"zeroed" "zeroed"
) )
# pyrefly: ignore # unsupported-operation
if dropout > 0 and num_layers == 1: if dropout > 0 and num_layers == 1:
warnings.warn( warnings.warn(
"dropout option adds dropout after all but last " "dropout option adds dropout after all but last "
@ -639,15 +641,22 @@ class RNN(RNNBase):
@overload @overload
@torch._jit_internal._overload_method # noqa: F811 @torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward( def forward(
self, input: Tensor, hx: Optional[Tensor] = None self,
input: Tensor,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]: ) -> tuple[Tensor, Tensor]:
pass pass
@overload @overload
@torch._jit_internal._overload_method # noqa: F811 @torch._jit_internal._overload_method # noqa: F811
def forward( def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None self,
input: PackedSequence,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, Tensor]: ) -> tuple[PackedSequence, Tensor]:
pass pass
@ -772,7 +781,11 @@ class RNN(RNNBase):
if isinstance(orig_input, PackedSequence): if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence( output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
) )
return output_packed, self.permute_hidden(hidden, unsorted_indices) return output_packed, self.permute_hidden(hidden, unsorted_indices)
@ -996,6 +1009,7 @@ class LSTM(RNNBase):
# In the future, we should prevent mypy from applying contravariance rules here. # In the future, we should prevent mypy from applying contravariance rules here.
# See torch/nn/modules/module.py::_forward_unimplemented # See torch/nn/modules/module.py::_forward_unimplemented
# pyrefly: ignore # bad-override
def check_forward_args( def check_forward_args(
self, self,
input: Tensor, input: Tensor,
@ -1029,8 +1043,12 @@ class LSTM(RNNBase):
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
@overload # type: ignore[override] @overload # type: ignore[override]
@torch._jit_internal._overload_method # noqa: F811 @torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward( def forward(
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None self,
input: Tensor,
hx: Optional[tuple[Tensor, Tensor]] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
pass pass
@ -1038,7 +1056,10 @@ class LSTM(RNNBase):
@overload @overload
@torch._jit_internal._overload_method # noqa: F811 @torch._jit_internal._overload_method # noqa: F811
def forward( def forward(
self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None self,
input: PackedSequence,
hx: Optional[tuple[Tensor, Tensor]] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
pass pass
@ -1152,7 +1173,11 @@ class LSTM(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile # xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence): if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence( output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
) )
return output_packed, self.permute_hidden(hidden, unsorted_indices) return output_packed, self.permute_hidden(hidden, unsorted_indices)
else: else:
@ -1318,15 +1343,22 @@ class GRU(RNNBase):
@overload # type: ignore[override] @overload # type: ignore[override]
@torch._jit_internal._overload_method # noqa: F811 @torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward( def forward(
self, input: Tensor, hx: Optional[Tensor] = None self,
input: Tensor,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]: # noqa: F811 ) -> tuple[Tensor, Tensor]: # noqa: F811
pass pass
@overload @overload
@torch._jit_internal._overload_method # noqa: F811 @torch._jit_internal._overload_method # noqa: F811
def forward( def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None self,
input: PackedSequence,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, Tensor]: # noqa: F811 ) -> tuple[PackedSequence, Tensor]: # noqa: F811
pass pass
@ -1420,7 +1452,11 @@ class GRU(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile # xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence): if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence( output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
) )
return output_packed, self.permute_hidden(hidden, unsorted_indices) return output_packed, self.permute_hidden(hidden, unsorted_indices)
else: else:


@ -135,7 +135,11 @@ class Transformer(Module):
**factory_kwargs, **factory_kwargs,
) )
encoder_norm = LayerNorm( encoder_norm = LayerNorm(
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
) )
self.encoder = TransformerEncoder( self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm encoder_layer, num_encoder_layers, encoder_norm
@ -157,7 +161,11 @@ class Transformer(Module):
**factory_kwargs, **factory_kwargs,
) )
decoder_norm = LayerNorm( decoder_norm = LayerNorm(
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
) )
self.decoder = TransformerDecoder( self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, decoder_norm decoder_layer, num_decoder_layers, decoder_norm
@ -760,7 +768,9 @@ class TransformerEncoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout) self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout) self.dropout2 = Dropout(dropout)
@ -1052,8 +1062,11 @@ class TransformerDecoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout) self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout) self.dropout2 = Dropout(dropout)


@ -36,6 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
import torch import torch
if isinstance(out_size, (int, torch.SymInt)): if isinstance(out_size, (int, torch.SymInt)):
# pyrefly: ignore # bad-return
return out_size return out_size
if len(defaults) <= len(out_size): if len(defaults) <= len(out_size):
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}") raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")


@ -43,6 +43,7 @@ def broadcast(tensor, devices=None, *, out=None):
devices = [_get_device_index(d) for d in devices] devices = [_get_device_index(d) for d in devices]
return torch._C._broadcast(tensor, devices) return torch._C._broadcast(tensor, devices)
else: else:
# pyrefly: ignore # bad-argument-type
return torch._C._broadcast_out(tensor, out) return torch._C._broadcast_out(tensor, out)
@ -200,6 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=
""" """
tensor = _handle_complex(tensor) tensor = _handle_complex(tensor)
if out is None: if out is None:
# pyrefly: ignore # not-iterable
devices = [_get_device_index(d) for d in devices] devices = [_get_device_index(d) for d in devices]
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
else: else:


@ -160,6 +160,7 @@ class DataParallel(Module, Generic[T]):
self.module = module self.module = module
self.device_ids = [_get_device_index(x, True) for x in device_ids] self.device_ids = [_get_device_index(x, True) for x in device_ids]
self.output_device = _get_device_index(output_device, True) self.output_device = _get_device_index(output_device, True)
# pyrefly: ignore # read-only
self.src_device_obj = torch.device(device_type, self.device_ids[0]) self.src_device_obj = torch.device(device_type, self.device_ids[0])
if device_type == "cuda": if device_type == "cuda":
@ -173,6 +174,7 @@ class DataParallel(Module, Generic[T]):
if not self.device_ids: if not self.device_ids:
return self.module(*inputs, **kwargs) return self.module(*inputs, **kwargs)
# pyrefly: ignore # bad-argument-type
for t in chain(self.module.parameters(), self.module.buffers()): for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj: if t.device != self.src_device_obj:
raise RuntimeError( raise RuntimeError(
@ -259,8 +261,10 @@ def data_parallel(
device_ids = [_get_device_index(x, True) for x in device_ids] device_ids = [_get_device_index(x, True) for x in device_ids]
output_device = _get_device_index(output_device, True) output_device = _get_device_index(output_device, True)
# pyrefly: ignore # no-matching-overload
src_device_obj = torch.device(device_type, device_ids[0]) src_device_obj = torch.device(device_type, device_ids[0])
# pyrefly: ignore # bad-argument-type
for t in chain(module.parameters(), module.buffers()): for t in chain(module.parameters(), module.buffers()):
if t.device != src_device_obj: if t.device != src_device_obj:
raise RuntimeError( raise RuntimeError(


@ -241,6 +241,7 @@ class _BufferCommHook:
# is completed. # is completed.
class _DDPSink(Function): class _DDPSink(Function):
@staticmethod @staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, ddp_weakref, *inputs): def forward(ctx, ddp_weakref, *inputs):
# set_materialize_grads(False) will ensure that None gradients stay as # set_materialize_grads(False) will ensure that None gradients stay as
# None and are not filled with zeros. # None and are not filled with zeros.
@ -691,6 +692,7 @@ class DistributedDataParallel(Module, Joinable):
elif process_group is None and device_mesh is None: elif process_group is None and device_mesh is None:
self.process_group = _get_default_group() self.process_group = _get_default_group()
elif device_mesh is None: elif device_mesh is None:
# pyrefly: ignore # bad-assignment
self.process_group = process_group self.process_group = process_group
else: else:
if device_mesh.ndim != 1: if device_mesh.ndim != 1:
@ -779,11 +781,13 @@ class DistributedDataParallel(Module, Joinable):
self.device_ids = None self.device_ids = None
self.output_device = None self.output_device = None
else: else:
# pyrefly: ignore # bad-assignment
self.device_ids = [_get_device_index(x, True) for x in device_ids] self.device_ids = [_get_device_index(x, True) for x in device_ids]
if output_device is None: if output_device is None:
output_device = device_ids[0] output_device = device_ids[0]
# pyrefly: ignore # bad-assignment
self.output_device = _get_device_index(output_device, True) self.output_device = _get_device_index(output_device, True)
self.static_graph = False self.static_graph = False
@ -933,6 +937,7 @@ class DistributedDataParallel(Module, Joinable):
# enabled. # enabled.
self._accum_grad_hooks: list[RemovableHandle] = [] self._accum_grad_hooks: list[RemovableHandle] = []
if self._use_python_reducer: if self._use_python_reducer:
# pyrefly: ignore # bad-assignment
torch._inductor.config._fuse_ddp_communication = True torch._inductor.config._fuse_ddp_communication = True
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
# Directly adding this to the trace rule will disturb the users # Directly adding this to the trace rule will disturb the users


@@ -56,12 +56,16 @@ def scatter(inputs, target_gpus, dim=0):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
# pyrefly: ignore # no-matching-overload
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [list(i) for i in zip(*map(scatter_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [obj for _ in target_gpus]
@@ -123,9 +127,12 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
if isinstance(out, dict):
if not all(len(out) == len(d) for d in outputs):
raise ValueError("All dicts must have the same number of keys")
# pyrefly: ignore # not-callable
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
if _is_namedtuple(out):
# pyrefly: ignore # no-matching-overload
return type(out)._make(map(gather_map, zip(*outputs)))
# pyrefly: ignore # no-matching-overload
return type(out)(map(gather_map, zip(*outputs)))
# Recursive function calls like this create reference cycles.
View File
@@ -81,6 +81,7 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
memo[id(self)] = result
return result
# pyrefly: ignore # bad-override
def __repr__(self):
return "Parameter containing:\n" + super().__repr__()
@@ -143,6 +144,7 @@ class UninitializedTensorMixin:
if dtype is None:
dtype = self.data.dtype
self.data = torch.empty(shape, device=device, dtype=dtype)
# pyrefly: ignore # bad-override, missing-attribute
self.__class__ = self.cls_to_become
@property
@@ -166,6 +168,7 @@ class UninitializedTensorMixin:
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
# pyrefly: ignore # missing-attribute
return (self.__class__, (self.requires_grad,))
@classmethod
@@ -175,6 +178,7 @@ class UninitializedTensorMixin:
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
if kwargs is None:
kwargs = {}
# pyrefly: ignore # missing-attribute
return super().__torch_function__(func, types, args, kwargs)
raise ValueError(
f"Attempted to use an uninitialized parameter in {func}. "
@@ -216,6 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
# pyrefly: ignore # bad-return
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __deepcopy__(self, memo):
@@ -261,7 +266,9 @@ class Buffer(torch.Tensor, metaclass=_BufferMeta):
data = torch.empty(0)
t = data.detach().requires_grad_(data.requires_grad)
# pyrefly: ignore # missing-attribute
t.persistent = persistent
# pyrefly: ignore # missing-attribute
t._is_buffer = True
return t
@@ -292,6 +299,7 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
# pyrefly: ignore # missing-attribute
ret.persistent = persistent
# pyrefly: ignore # missing-attribute
ret._is_buffer = True
# pyrefly: ignore # bad-return
return ret
View File
@@ -1,5 +1,5 @@
from . import parametrizations, parametrize, rnn, stateless
from .clip_grad import ( # pyrefly: ignore # deprecated
_clip_grads_with_norm_ as clip_grads_with_norm_,
_get_total_norm as get_total_norm,
clip_grad_norm,
View File
@@ -24,6 +24,7 @@ from .expanded_weights_utils import forward_helper
@implements_per_sample_grads(F.conv3d)
class ConvPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any,
kwarg_names: list[str],
@@ -56,6 +57,7 @@ class ConvPerSampleGrad(torch.autograd.Function):
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
)
# pyrefly: ignore # invalid-type-var
ctx.conv_fn = conv_fn
ctx.batch_size = orig_input.shape[0]
View File
@@ -237,6 +237,7 @@ def conv_unfold_weight_grad_sample(
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
# rearrange the above tensor and extract diagonals.
# pyrefly: ignore # no-matching-overload
weight_grad_sample = weight_grad_sample.view(
n,
groups,
View File
@@ -14,6 +14,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
) -> torch.Tensor:
@@ -34,6 +35,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
View File
@@ -131,7 +131,9 @@ class ExpandedWeight(torch.Tensor):
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = isinstance(
# pyrefly: ignore # index-error
args[2],
list,
) # data variant uses a list here
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
View File
@@ -8,6 +8,7 @@ from .expanded_weights_impl import ExpandedWeight
def is_batch_first(expanded_args_and_kwargs):
batch_first = None
# pyrefly: ignore # bad-assignment
for arg in expanded_args_and_kwargs:
if not isinstance(arg, ExpandedWeight):
continue
View File
@@ -18,6 +18,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@@ -46,6 +47,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, num_groups = ctx.input, ctx.num_groups
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
@@ -94,7 +96,9 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
set_grad_sample_if_exists(
weight,
lambda _: torch.einsum(
"ni...->ni",
# pyrefly: ignore # unsupported-operation
F.group_norm(input, num_groups, eps=eps) * grad_output,
),
)
if hasattr(ctx, "bias"):
View File
@@ -17,6 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.instance_norm)
class InstanceNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
expanded_args, expanded_kwargs = standard_kwargs(
@@ -36,6 +37,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
View File
@@ -17,6 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.layer_norm)
class LayerNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@@ -42,6 +43,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
def weight_per_sample_grad(weight):
return sum_over_all_but_batch_and_last_n(
View File
@@ -16,6 +16,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, _, __, *expanded_args_and_kwargs):
if len(expanded_args_and_kwargs[0].shape) <= 1:
raise RuntimeError(
@@ -35,6 +36,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, weight = ctx.args
bias = ctx.kwargs["bias"]
View File
@@ -77,6 +77,7 @@ def swap_tensor(
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
# pyrefly: ignore # bad-return
return orig_tensor
View File
@@ -41,9 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
def _no_grad_wrapper(*args, **kwargs):
with torch.no_grad():
# pyrefly: ignore # invalid-param-spec
return func(*args, **kwargs)
functools.update_wrapper(_no_grad_wrapper, func)
# pyrefly: ignore # bad-return
return _no_grad_wrapper
View File
@@ -84,6 +84,7 @@ def convert_conv2d_weight_memory_format(
)
for child in module.children():
convert_conv2d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
return module
@@ -163,6 +164,7 @@ def convert_conv3d_weight_memory_format(
)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
return module
View File
@@ -98,6 +98,7 @@ class _Orthogonal(Module):
)
# Q is now orthogonal (or unitary) of size (..., n, n)
if n != k:
# pyrefly: ignore # unbound-name
Q = Q[..., :k]
# Q is now the size of the X (albeit perhaps transposed)
else:
View File
@@ -179,23 +179,28 @@ class ParametrizationList(ModuleList):
# Register the tensor(s)
if self.is_tensor:
# pyrefly: ignore # missing-attribute
if original.dtype != new.dtype:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
f"original.dtype: {original.dtype}\n"
# pyrefly: ignore # missing-attribute
f"right_inverse(original).dtype: {new.dtype}"
)
# pyrefly: ignore # missing-attribute
if original.device != new.device:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the device.\n"
f"original.device: {original.device}\n"
# pyrefly: ignore # missing-attribute
f"right_inverse(original).device: {new.device}"
)
# Set the original to original so that the user does not need to re-register the parameter
# manually in the optimiser
with torch.no_grad():
# pyrefly: ignore # bad-argument-type
_maybe_set(original, new)
_register_parameter_or_buffer(self, "original", original)
else:
@@ -396,6 +401,7 @@ def _inject_property(module: Module, tensor_name: str) -> None:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
parametrization = self.parametrizations[tensor_name]
# pyrefly: ignore # redundant-condition
if _cache_enabled:
if torch.jit.is_scripting():
# Scripting
@@ -695,6 +701,7 @@ def remove_parametrizations(
# Fetch the original tensor
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
parametrizations = module.parametrizations[tensor_name]
# pyrefly: ignore # invalid-argument
if parametrizations.is_tensor:
original = parametrizations.original
assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
View File
@@ -274,8 +274,11 @@ class PruningContainer(BasePruningMethod):
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)
# pyrefly: ignore # bad-argument-type
elif len(args) == 1: # only 1 item in a tuple
# pyrefly: ignore # index-error
self._tensor_name = args[0]._tensor_name
# pyrefly: ignore # index-error
self.add_pruning_method(args[0])
else: # manual construction from list or other iterable (or no args)
for method in args:
@@ -1097,6 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
# flatten importance scores to consider them all at once in global pruning
relevant_importance_scores = torch.nn.utils.parameters_to_vector(
# pyrefly: ignore # bad-argument-type
[
importance_scores.get((module, name), getattr(module, name))
for (module, name) in parameters
View File
@@ -332,6 +332,7 @@ def spectral_norm(
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
# pyrefly: ignore # bad-return
return module
View File
@@ -41,10 +41,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
)
# pyrefly: ignore # import-error
import optree
# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
View File
@@ -679,7 +679,7 @@ class FlopCounterMode:
import tabulate
# pyrefly: ignore # bad-assignment
tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOP", "% Total"]
values = []
View File
@@ -9,7 +9,7 @@ from typing import Any, Optional
import torch
import numpy as np
# pyrefly: ignore # import-error
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import (
View File
@@ -956,7 +956,7 @@ class SummaryWriter:
)
self._projector_config.embeddings.extend([embedding_info])
# pyrefly: ignore # import-error
from google.protobuf import text_format
config_pbtxt = text_format.MessageToString(self._projector_config)
config_pbtxt = text_format.MessageToString(self._projector_config) config_pbtxt = text_format.MessageToString(self._projector_config)