Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Pyrefly suppressions 4/n (#164615)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions

before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1
after: 0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
Committed by: PyTorch MergeBot
Parent: 4bd1505f84
Commit: 4ab847bbc7
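Every hunk below follows the same pattern: a `# pyrefly: ignore # <error-code>` comment is placed on the line directly above the expression pyrefly flags, naming the diagnostic being suppressed. A minimal illustrative sketch of the pattern (lerp_running_avg is a hypothetical stand-in modelled on the adam.py change below; it is not code from this PR and assumes torch is installed):

import torch


def lerp_running_avg(avg: torch.Tensor, grad: torch.Tensor, beta: float) -> torch.Tensor:
    # Decay the running average toward the new gradient, as adam.py does.
    # The suppression comment sits on the line directly above the flagged call.
    # pyrefly: ignore # no-matching-overload
    avg.lerp_(grad, 1 - beta)
    return avg


# Example usage:
# avg = torch.zeros(3)
# lerp_running_avg(avg, torch.ones(3), beta=0.9)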
@@ -78,6 +78,7 @@ def _adjust_lr(
    A, B = param_shape[:2]

    if adjust_lr_fn is None or adjust_lr_fn == "original":
+        # pyrefly: ignore # no-matching-overload
        adjusted_ratio = math.sqrt(max(1, A / B))
    elif adjust_lr_fn == "match_rms_adamw":
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
@@ -415,6 +415,7 @@ def _single_tensor_adam(
                if weight_decay.requires_grad:
                    grad = grad.addcmul_(param.clone(), weight_decay)
                else:
+                    # pyrefly: ignore # bad-argument-type
                    grad = grad.add(param, alpha=weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)
@@ -444,6 +445,7 @@ def _single_tensor_adam(
            device_beta1 = beta1

        # Decay the first and second moment running average coefficient
+        # pyrefly: ignore # no-matching-overload
        exp_avg.lerp_(grad, 1 - device_beta1)

        # Nested if is necessary to bypass jitscript rules
@@ -692,6 +694,7 @@ def _multi_tensor_adam(
            device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
        )

+        # pyrefly: ignore # no-matching-overload
        torch._foreach_mul_(device_exp_avg_sqs, beta2)

        # Due to the strictness of the _foreach_addcmul API, we can't have a single
@@ -263,6 +263,7 @@ def _single_tensor_asgd(
            ax.copy_(param)

        if capturable:
+            # pyrefly: ignore # unsupported-operation
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
@@ -113,9 +113,11 @@ def _strong_wolfe(

        # compute new trial value
        t = _cubic_interpolate(
+            # pyrefly: ignore # index-error
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
+            # pyrefly: ignore # index-error
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
@@ -151,6 +153,7 @@ def _strong_wolfe(

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
+            # pyrefly: ignore # unsupported-operation
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
@@ -160,14 +163,17 @@ def _strong_wolfe(
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
+            # pyrefly: ignore # index-error
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
+                # pyrefly: ignore # unsupported-operation
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
+            # pyrefly: ignore # unsupported-operation
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
@@ -252,6 +258,7 @@ class LBFGS(Optimizer):

    def _numel(self):
        if self._numel_cache is None:
+            # pyrefly: ignore # bad-assignment
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params
@@ -1665,6 +1665,7 @@ class ReduceLROnPlateau(LRScheduler):
            self.default_min_lr = None
            self.min_lrs = list(min_lr)
        else:
+            # pyrefly: ignore # bad-assignment
            self.default_min_lr = min_lr
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

@@ -1724,6 +1725,7 @@ class ReduceLROnPlateau(LRScheduler):
                    "of the `optimizer` param groups."
                )
        else:
+            # pyrefly: ignore # bad-assignment
            self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)

        for i, param_group in enumerate(self.optimizer.param_groups):
@@ -1903,10 +1905,13 @@ class CyclicLR(LRScheduler):

        self.max_lrs = _format_param("max_lr", optimizer, max_lr)

+        # pyrefly: ignore # bad-assignment
        step_size_up = float(step_size_up)
        step_size_down = (
+            # pyrefly: ignore # bad-assignment
            float(step_size_down) if step_size_down is not None else step_size_up
        )
+        # pyrefly: ignore # unsupported-operation
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

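The bad-assignment suppressions in the CyclicLR hunk sit above rebindings such as step_size_up = float(step_size_up). A plausible reading, offered here as an assumption rather than a statement about pyrefly internals, is that rebinding a name whose declared type is int to a float value is what trips the check. A rough sketch of that shape with a hypothetical helper (not code from this PR):

from typing import Optional


def cycle_sizes(step_size_up: int = 2000, step_size_down: Optional[int] = None) -> float:
    # Hypothetical illustration: the int-annotated parameter is rebound to a float,
    # mirroring the rebinding that the suppressions in the hunk above sit over.
    # pyrefly: ignore # bad-assignment
    step_size_up = float(step_size_up)
    step_size_down = (
        float(step_size_down) if step_size_down is not None else step_size_up
    )
    total_size = step_size_up + step_size_down
    return step_size_up / total_size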
@@ -62,6 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
    def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        import torch._dynamo

+        # pyrefly: ignore # unsupported-operation
        self = cast(Optimizer, args[0])  # assume first positional arg is `self`
        prev_grad = torch.is_grad_enabled()
        try:
@@ -135,11 +136,13 @@ def _disable_dynamo_if_unsupported(
            if torch.compiler.is_compiling() and (
                not kwargs.get("capturable", False)
                and has_state_steps
+                # pyrefly: ignore # unsupported-operation
                and (arg := args[state_steps_ind])
                and isinstance(arg, Sequence)
                and arg[0].is_cuda
                or (
                    "state_steps" in kwargs
+                    # pyrefly: ignore # unsupported-operation
                    and (kwarg := kwargs["state_steps"])
                    and isinstance(kwarg, Sequence)
                    and kwarg[0].is_cuda
@@ -359,14 +362,18 @@ class Optimizer:

    _optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
    _optimizer_step_post_hooks: dict[int, OptimizerPostHook]
+    # pyrefly: ignore # not-a-type
    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
    _optimizer_state_dict_post_hooks: (
+        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    )
    _optimizer_load_state_dict_pre_hooks: (
+        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    )
    _optimizer_load_state_dict_post_hooks: (
+        # pyrefly: ignore # not-a-type
        'OrderedDict[int, Callable[["Optimizer"], None]]'
    )

@@ -391,6 +398,7 @@ class Optimizer:
        self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: list[dict[str, Any]] = []

+        # pyrefly: ignore # no-matching-overload
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
@@ -514,6 +522,7 @@ class Optimizer:
                        f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                    )

+            # pyrefly: ignore # invalid-param-spec
            out = func(*args, **kwargs)
            self._optimizer_step_code()

@@ -949,7 +958,14 @@ class Optimizer:
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            return Optimizer._process_value_according_to_param_policy(
-                param, value, param_id, param_groups, key
+                # pyrefly: ignore # bad-argument-type
+                param,
+                value,
+                # pyrefly: ignore # bad-argument-type
+                param_id,
+                # pyrefly: ignore # bad-argument-type
+                param_groups,
+                key,
            )
        elif isinstance(value, dict):
            return {
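Note the shape of this hunk: the single argument line was split across multiple lines, presumably so that each suppression comment sits directly above the one argument it is meant to cover rather than blanketing the whole call. A minimal sketch of the same pattern; combine and its caller are hypothetical, not part of this PR:

from typing import Any


def combine(a: int, b: int, c: int) -> int:
    return a + b + c


def call_with_untyped_values(x: Any, y: Any) -> int:
    # Splitting the arguments across lines lets each ignore comment target
    # only the argument on the line immediately below it.
    return combine(
        # pyrefly: ignore # bad-argument-type
        x,
        1,
        # pyrefly: ignore # bad-argument-type
        y,
    )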
@@ -960,6 +976,7 @@
            }
        elif isinstance(value, Iterable):
            return type(value)(
+                # pyrefly: ignore # bad-argument-count
                _cast(param, v, param_id=param_id, param_groups=param_groups)
                for v in value
            )  # type: ignore[call-arg]
@@ -322,6 +322,7 @@ def _single_tensor_radam(
        rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2

        def _compute_rect():
+            # pyrefly: ignore # unsupported-operation
            return (
                (rho_t - 4)
                * (rho_t - 2)
@@ -336,6 +337,7 @@ def _single_tensor_radam(
            else:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)

+            # pyrefly: ignore # unsupported-operation
            return (bias_correction2**0.5) / exp_avg_sq_sqrt

        # Compute the variance rectification term and update parameters accordingly
@@ -337,6 +337,7 @@ def _single_tensor_sgd(
    if not torch.jit.is_scripting():
        lr = _to_scalar(lr)

+    # pyrefly: ignore # bad-assignment
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]

@@ -347,6 +348,7 @@ def _single_tensor_sgd(
                    # usually this is the differentiable path, which is why the param.clone() is needed
                    grad = grad.addcmul_(param.clone(), weight_decay)
                else:
+                    # pyrefly: ignore # bad-argument-type
                    grad = grad.add(param, alpha=weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)
@@ -370,6 +372,7 @@ def _single_tensor_sgd(
            if lr.requires_grad:
                param.addcmul_(grad, lr, value=-1)
            else:
+                # pyrefly: ignore # bad-argument-type
                param.add_(grad, alpha=-lr)
        else:
            param.add_(grad, alpha=-lr)
@@ -430,10 +433,12 @@ def _multi_tensor_sgd(

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
+                # pyrefly: ignore # index-error
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
+                    # pyrefly: ignore # index-error
                    bufs.append(cast(Tensor, device_momentum_buffer_list[i]))

            if all_states_with_momentum_buffer:
@@ -441,12 +446,15 @@ def _multi_tensor_sgd(
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
+                # pyrefly: ignore # bad-assignment
                for i in range(len(device_momentum_buffer_list)):
+                    # pyrefly: ignore # index-error
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[
                            indices[i]
                        ] = device_grads[i].detach().clone()
                    else:
+                        # pyrefly: ignore # index-error
                        buf = cast(Tensor, device_momentum_buffer_list[i])
                    buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

@@ -249,11 +249,13 @@ class AveragedModel(Module):
    def update_parameters(self, model: Module):
        """Update model parameters."""
        self_param = (
+            # pyrefly: ignore # bad-argument-type
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
+            # pyrefly: ignore # bad-argument-type
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()
@@ -300,8 +302,11 @@ class AveragedModel(Module):
        for p_averaged, p_model in zip(  # type: ignore[assignment]
            self_param_detached, model_param_detached
        ):
+            # pyrefly: ignore # missing-attribute
            n_averaged = self.n_averaged.to(p_averaged.device)
+            # pyrefly: ignore # missing-attribute
            p_averaged.detach().copy_(
+                # pyrefly: ignore # missing-attribute, bad-argument-type
                self.avg_fn(p_averaged.detach(), p_model, n_averaged)
            )

@@ -489,12 +494,14 @@ class SWALR(LRScheduler):
        step = self._step_count - 1
        if self.anneal_epochs == 0:
            step = max(1, step)
+        # pyrefly: ignore # no-matching-overload
        prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
        prev_alpha = self.anneal_func(prev_t)
        prev_lrs = [
            self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
            for group in self.optimizer.param_groups
        ]
+        # pyrefly: ignore # no-matching-overload
        t = max(0, min(1, step / max(1, self.anneal_epochs)))
        alpha = self.anneal_func(t)
        return [