Pyrefly suppressions 4/n (#164615)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions (format sketched below), clean up unused suppressions
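
For reference, the suppression format pyrefly reads is a comment placed on the line directly above the flagged expression, naming the error code to silence. A minimal sketch of the pattern used throughout the diff below (the function and values here are illustrative, not taken from this PR):

import math


def _ratio(a, b):
    # Suppress a single pyrefly diagnostic on the next line by error code;
    # suppressions that no longer match an error are what step 3 cleans up.
    # pyrefly: ignore # no-matching-overload
    return math.sqrt(max(1, a / b))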
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
Author: Maggie Moss
Date: 2025-10-06 16:14:36 +00:00
Committed by: PyTorch MergeBot
Parent: 4bd1505f84
Commit: 4ab847bbc7
52 changed files with 293 additions and 21 deletions

View File

@@ -78,6 +78,7 @@ def _adjust_lr(
    A, B = param_shape[:2]
    if adjust_lr_fn is None or adjust_lr_fn == "original":
        # pyrefly: ignore # no-matching-overload
        adjusted_ratio = math.sqrt(max(1, A / B))
    elif adjust_lr_fn == "match_rms_adamw":
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))

View File

@@ -415,6 +415,7 @@ def _single_tensor_adam(
                if weight_decay.requires_grad:
                    grad = grad.addcmul_(param.clone(), weight_decay)
                else:
                    # pyrefly: ignore # bad-argument-type
                    grad = grad.add(param, alpha=weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)
@@ -444,6 +445,7 @@ def _single_tensor_adam(
            device_beta1 = beta1
        # Decay the first and second moment running average coefficient
        # pyrefly: ignore # no-matching-overload
        exp_avg.lerp_(grad, 1 - device_beta1)
        # Nested if is necessary to bypass jitscript rules
@@ -692,6 +694,7 @@ def _multi_tensor_adam(
            device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
        )
        # pyrefly: ignore # no-matching-overload
        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        # Due to the strictness of the _foreach_addcmul API, we can't have a single

View File

@@ -263,6 +263,7 @@ def _single_tensor_asgd(
            ax.copy_(param)
        if capturable:
            # pyrefly: ignore # unsupported-operation
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:

View File

@@ -113,9 +113,11 @@ def _strong_wolfe(
        # compute new trial value
        t = _cubic_interpolate(
            # pyrefly: ignore # index-error
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            # pyrefly: ignore # index-error
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
@@ -151,6 +153,7 @@ def _strong_wolfe(
        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            # pyrefly: ignore # unsupported-operation
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
@@ -160,14 +163,17 @@ def _strong_wolfe(
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            # pyrefly: ignore # index-error
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                # pyrefly: ignore # unsupported-operation
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]
            # new point becomes new low
            # pyrefly: ignore # unsupported-operation
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
@@ -252,6 +258,7 @@ class LBFGS(Optimizer):
    def _numel(self):
        if self._numel_cache is None:
            # pyrefly: ignore # bad-assignment
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params

View File

@@ -1665,6 +1665,7 @@ class ReduceLROnPlateau(LRScheduler):
            self.default_min_lr = None
            self.min_lrs = list(min_lr)
        else:
            # pyrefly: ignore # bad-assignment
            self.default_min_lr = min_lr
            self.min_lrs = [min_lr] * len(optimizer.param_groups)
@@ -1724,6 +1725,7 @@ class ReduceLROnPlateau(LRScheduler):
                    "of the `optimizer` param groups."
                )
        else:
            # pyrefly: ignore # bad-assignment
            self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
        for i, param_group in enumerate(self.optimizer.param_groups):
@@ -1903,10 +1905,13 @@ class CyclicLR(LRScheduler):
        self.max_lrs = _format_param("max_lr", optimizer, max_lr)
        # pyrefly: ignore # bad-assignment
        step_size_up = float(step_size_up)
        step_size_down = (
            # pyrefly: ignore # bad-assignment
            float(step_size_down) if step_size_down is not None else step_size_up
        )
        # pyrefly: ignore # unsupported-operation
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

View File

@@ -62,6 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
    def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        import torch._dynamo
        # pyrefly: ignore # unsupported-operation
        self = cast(Optimizer, args[0])  # assume first positional arg is `self`
        prev_grad = torch.is_grad_enabled()
        try:
@@ -135,11 +136,13 @@ def _disable_dynamo_if_unsupported(
            if torch.compiler.is_compiling() and (
                not kwargs.get("capturable", False)
                and has_state_steps
                # pyrefly: ignore # unsupported-operation
                and (arg := args[state_steps_ind])
                and isinstance(arg, Sequence)
                and arg[0].is_cuda
                or (
                    "state_steps" in kwargs
                    # pyrefly: ignore # unsupported-operation
                    and (kwarg := kwargs["state_steps"])
                    and isinstance(kwarg, Sequence)
                    and kwarg[0].is_cuda
@@ -359,14 +362,18 @@ class Optimizer:
    _optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
    _optimizer_step_post_hooks: dict[int, OptimizerPostHook]
    # pyrefly: ignore # not-a-type
    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
    _optimizer_state_dict_post_hooks: (
        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    )
    _optimizer_load_state_dict_pre_hooks: (
        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    )
    _optimizer_load_state_dict_post_hooks: (
        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer"], None]]'
    )
@@ -391,6 +398,7 @@ class Optimizer:
        self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: list[dict[str, Any]] = []
        # pyrefly: ignore # no-matching-overload
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
@@ -514,6 +522,7 @@ class Optimizer:
                                f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                            )
                # pyrefly: ignore # invalid-param-spec
                out = func(*args, **kwargs)
                self._optimizer_step_code()
@@ -949,7 +958,14 @@ class Optimizer:
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                return Optimizer._process_value_according_to_param_policy(
                    param, value, param_id, param_groups, key
                    # pyrefly: ignore # bad-argument-type
                    param,
                    value,
                    # pyrefly: ignore # bad-argument-type
                    param_id,
                    # pyrefly: ignore # bad-argument-type
                    param_groups,
                    key,
                )
            elif isinstance(value, dict):
                return {
@@ -960,6 +976,7 @@ class Optimizer:
                }
            elif isinstance(value, Iterable):
                return type(value)(
                    # pyrefly: ignore # bad-argument-count
                    _cast(param, v, param_id=param_id, param_groups=param_groups)
                    for v in value
                )  # type: ignore[call-arg]

View File

@@ -322,6 +322,7 @@ def _single_tensor_radam(
        rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2
        def _compute_rect():
            # pyrefly: ignore # unsupported-operation
            return (
                (rho_t - 4)
                * (rho_t - 2)
@@ -336,6 +337,7 @@ def _single_tensor_radam(
            else:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
            # pyrefly: ignore # unsupported-operation
            return (bias_correction2**0.5) / exp_avg_sq_sqrt
        # Compute the variance rectification term and update parameters accordingly

View File

@@ -337,6 +337,7 @@ def _single_tensor_sgd(
    if not torch.jit.is_scripting():
        lr = _to_scalar(lr)
    # pyrefly: ignore # bad-assignment
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
@@ -347,6 +348,7 @@ def _single_tensor_sgd(
                    # usually this is the differentiable path, which is why the param.clone() is needed
                    grad = grad.addcmul_(param.clone(), weight_decay)
                else:
                    # pyrefly: ignore # bad-argument-type
                    grad = grad.add(param, alpha=weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)
@@ -370,6 +372,7 @@ def _single_tensor_sgd(
            if lr.requires_grad:
                param.addcmul_(grad, lr, value=-1)
            else:
                # pyrefly: ignore # bad-argument-type
                param.add_(grad, alpha=-lr)
        else:
            param.add_(grad, alpha=-lr)
@@ -430,10 +433,12 @@ def _multi_tensor_sgd(
            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                # pyrefly: ignore # index-error
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    # pyrefly: ignore # index-error
                    bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
            if all_states_with_momentum_buffer:
@@ -441,12 +446,15 @@ def _multi_tensor_sgd(
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                # pyrefly: ignore # bad-assignment
                for i in range(len(device_momentum_buffer_list)):
                    # pyrefly: ignore # index-error
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
                            indices[i]
                        ] = device_grads[i].detach().clone()
                    else:
                        # pyrefly: ignore # index-error
                        buf = cast(Tensor, device_momentum_buffer_list[i])
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

View File

@@ -249,11 +249,13 @@ class AveragedModel(Module):
    def update_parameters(self, model: Module):
        """Update model parameters."""
        self_param = (
            # pyrefly: ignore # bad-argument-type
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
            # pyrefly: ignore # bad-argument-type
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()
@@ -300,8 +302,11 @@ class AveragedModel(Module):
            for p_averaged, p_model in zip(  # type: ignore[assignment]
                self_param_detached, model_param_detached
            ):
                # pyrefly: ignore # missing-attribute
                n_averaged = self.n_averaged.to(p_averaged.device)
                # pyrefly: ignore # missing-attribute
                p_averaged.detach().copy_(
                    # pyrefly: ignore # missing-attribute, bad-argument-type
                    self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                )
@@ -489,12 +494,14 @@ class SWALR(LRScheduler):
        step = self._step_count - 1
        if self.anneal_epochs == 0:
            step = max(1, step)
        # pyrefly: ignore # no-matching-overload
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [
            self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
            for group in self.optimizer.param_groups
        ]
        # pyrefly: ignore # no-matching-overload
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [