[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
Xuehai Pan
2025-06-14 00:48:12 +08:00
committed by PyTorch MergeBot
parent 3e38feb05f
commit 596b418391
65 changed files with 640 additions and 475 deletions
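
Most of the hunks below repeat one pattern: ruff format keeps an `assert` condition on a single line and wraps only the message in parentheses, where the previous black-style output parenthesized the condition and let the message trail the closing paren. A minimal, self-contained sketch of that before/after, reusing the condition and message from the first adafactor hunk (the wrapper function and the call at the end are only for illustration):

def _check_scaling(grad_scale, found_inf):
    # Old (black-era) wrapping parenthesized the condition:
    #   assert (
    #       grad_scale is None and found_inf is None
    #   ), "Grad scaling should occur outside of optimizer.step()"
    # ruff format keeps the condition on one line and parenthesizes the message:
    assert grad_scale is None and found_inf is None, (
        "Grad scaling should occur outside of optimizer.step()"
    )


_check_scaling(None, None)  # passes; a non-None argument would raise AssertionError

Both forms are equivalent at runtime; only the line wrapping of the assert statement changes.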

View File

@ -347,9 +347,9 @@ def _single_tensor_adafactor(
maximize: bool,
has_complex: bool,
):
assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
@ -381,9 +381,9 @@ def _single_tensor_adafactor(
param.mul_(1 - lr * weight_decay)
if grad.dim() > 1:
assert (
row_var is not None and col_var is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert row_var is not None and col_var is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_mean = (
torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
@ -397,9 +397,9 @@ def _single_tensor_adafactor(
var_estimate = row_var @ col_var
var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
else:
assert (
variance is not None
), "variance should be defined when grad is a vector"
assert variance is not None, (
"variance should be defined when grad is a vector"
)
grad_squared = grad * grad
variance.lerp_(grad_squared, one_minus_beta2_t)
# avoid writing into variance during update
@ -472,9 +472,9 @@ def _multi_tensor_adafactor(
if len(params) == 0:
return
assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)
lr = _to_scalar(lr)
@ -495,9 +495,9 @@ def _multi_tensor_adafactor(
device_grads = cast(list[Tensor], device_grads_)
device_state_steps = cast(list[Tensor], device_state_steps_)
if eps1 is None:
assert (
dtype is not None
), "dtype is needed to compute eps1 when eps1 is unset"
assert dtype is not None, (
"dtype is needed to compute eps1 when eps1 is unset"
)
eps1 = torch.finfo(dtype).eps
if TYPE_CHECKING:
@ -537,9 +537,9 @@ def _multi_tensor_adafactor(
if is_multidim:
device_row_vars = cast(list[Tensor], device_row_vars_)
device_col_vars = cast(list[Tensor], device_col_vars_)
assert (
device_row_vars[0] is not None and device_col_vars[0] is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert device_row_vars[0] is not None and device_col_vars[0] is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_means = [
torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
@ -570,9 +570,9 @@ def _multi_tensor_adafactor(
del row_var_means
else:
device_variances = cast(list[Tensor], device_variances_)
assert (
device_variances[0] is not None
), "variance should be defined when grad is a vector"
assert device_variances[0] is not None, (
"variance should be defined when grad is a vector"
)
grads_squared = torch._foreach_mul(device_grads, device_grads)
torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Functional interface."""
import math
from torch import Tensor

View File

@ -5,6 +5,7 @@ Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
"""
from functools import partialmethod
from torch import optim

View File

@ -267,7 +267,9 @@ def _single_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
if not torch.jit.is_scripting():
lr = _to_scalar(lr)
@ -326,7 +328,9 @@ def _multi_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
if len(params) == 0:
return

View File

@ -398,7 +398,9 @@ def _single_tensor_adam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
# update step
step_t += 1
@ -433,7 +435,9 @@ def _single_tensor_adam(
# cast to workaround https://github.com/pytorch/pytorch/issues/140601
key = (device, dtype)
if key not in beta1_dict:
beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True) # type: ignore[union-attr]
beta1_dict[key] = beta1.to( # type: ignore[union-attr]
device=device, dtype=dtype, non_blocking=True
)
device_beta1: Union[float, Tensor] = beta1_dict[key]
else:
@ -593,7 +597,9 @@ def _multi_tensor_adam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
assert grad_scale is None and found_inf is None
@ -769,7 +775,10 @@ def _multi_tensor_adam(
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(
device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type]
device_params,
device_exp_avgs,
exp_avg_sq_sqrt,
step_size, # type: ignore[arg-type]
)
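
The beta1 and addcdiv hunks above also show what happens to trailing `# type: ignore[...]` comments when a one-line call is expanded: the ignore is kept on the line the type checker actually reports, the receiver line for `union-attr` and the offending argument for `arg-type`, since mypy matches ignore comments per line. A small sketch with hypothetical names (`scaled_add` and `apply_step` are not from the diff):

import torch


def scaled_add(dst: torch.Tensor, src: torch.Tensor, scale: float) -> None:
    # Hypothetical stand-in for the _foreach_addcdiv_-style calls in the hunk above.
    dst.add_(src, alpha=scale)


def apply_step(dst: torch.Tensor, src: torch.Tensor, scale: object) -> None:
    # One-line form: a single trailing ignore covered the whole call.
    #   scaled_add(dst, src, scale)  # type: ignore[arg-type]
    # Expanded form: the ignore stays on the argument it silences, so mypy
    # still matches it to the reported line.
    scaled_add(
        dst,
        src,
        scale,  # type: ignore[arg-type]
    )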

View File

@ -256,7 +256,9 @@ def _single_tensor_adamax(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
# update step
step_t += 1
@ -331,7 +333,9 @@ def _multi_tensor_adamax(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -305,7 +305,9 @@ def _multi_tensor_asgd(
p.device.type == mu.device.type == eta.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mu, eta, step in zip(params, mus, etas, state_steps)
), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Learning Rate Scheduler."""
from __future__ import annotations
import math
@ -280,7 +281,7 @@ class LambdaLR(LRScheduler):
>>> # Assuming optimizer has two groups.
>>> num_epochs = 100
>>> lambda1 = lambda epoch: epoch // 30
>>> lambda2 = lambda epoch: 0.95 ** epoch
>>> lambda2 = lambda epoch: 0.95**epoch
>>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
>>> for epoch in range(num_epochs):
>>> train(...)
@ -548,7 +549,7 @@ class MultiStepLR(LRScheduler):
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80
>>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
>>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@ -827,7 +828,11 @@ class SequentialLR(LRScheduler):
>>> # lr = 0.0405 if epoch == 22
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> scheduler = SequentialLR(
... optimizer,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@ -1271,11 +1276,11 @@ class ReduceLROnPlateau(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, 'min')
>>> scheduler = ReduceLROnPlateau(optimizer, "min")
>>> for epoch in range(10):
>>> train(...)
>>> val_loss = validate(...)
>>> # Note that step should be called after validate()
>>> # Note that step should be called after validate()
>>> scheduler.step(val_loss)
.. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png
@ -1502,7 +1507,12 @@ class CyclicLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=10)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(
... optimizer,
... base_lr=0.01,
... max_lr=0.1,
... step_size_up=10,
... )
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
@ -1729,7 +1739,9 @@ class CosineAnnealingWarmRestarts(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
... optimizer, T_0=20
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@ -1800,7 +1812,7 @@ class CosineAnnealingWarmRestarts(LRScheduler):
>>> for epoch in range(20):
>>> scheduler.step()
>>> scheduler.step(26)
>>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
>>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
"""
if epoch is None and self.last_epoch < 0:
epoch = 0
@ -1936,7 +1948,9 @@ class OneCycleLR(LRScheduler):
>>> # xdoctest: +SKIP
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
... )
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
@ -2141,8 +2155,6 @@ class OneCycleLR(LRScheduler):
if self.use_beta1:
group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined]
else:
group[
"momentum"
] = computed_momentum # type: ignore[possibly-undefined]
group["momentum"] = computed_momentum # type: ignore[possibly-undefined]
return lrs
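
The lr_scheduler changes above also reach into docstrings: doctest lines have their quotes normalized to double quotes, `0.95 ** epoch` becomes `0.95**epoch`, and long constructor calls are split across `...` continuation lines. That only happens when ruff is asked to format code embedded in docstrings (presumably via its `docstring-code-format` setting; the exact configuration used in this repo is an assumption here). A toy example of the resulting doctest layout, with an invented `make_scheduler` wrapper:

import torch


def make_scheduler(optimizer, base_lr: float = 0.01, max_lr: float = 0.1):
    """Build a cyclic scheduler for ``optimizer`` (toy wrapper, not PyTorch API).

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = make_scheduler(
        ...     optimizer,
        ...     base_lr=0.01,
        ...     max_lr=0.1,
        ... )
        >>> decay = 0.95**10  # the power operator is written without spaces
    """
    return torch.optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=base_lr, max_lr=max_lr
    )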

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""
from typing import cast, Optional, Union
import torch
@ -408,7 +409,11 @@ def _multi_tensor_nadam(
p.device.type == mp.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mp, step in zip(params, mu_products, state_steps)
), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
"If capturable=True, "
"params, mu_products, and state_steps must be on supported devices: "
f"{capturable_supported_devices}."
)
lr = _to_scalar(lr)
@ -576,10 +581,16 @@ def _multi_tensor_nadam(
)
torch._foreach_addcdiv_(
grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type]
grouped_params,
grouped_grads,
exp_avg_sq_sqrt,
step_size_grads, # type: ignore[arg-type]
)
torch._foreach_addcdiv_(
grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type]
grouped_params,
grouped_exp_avgs,
exp_avg_sq_sqrt,
step_size_expavg, # type: ignore[arg-type]
)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Base optimizer."""
import functools
import warnings
from collections import defaultdict, OrderedDict
@ -103,7 +104,7 @@ def _stack_if_compiling(x):
def _disable_dynamo_if_unsupported(
single_tensor_fn: Optional[Callable[..., object]] = None
single_tensor_fn: Optional[Callable[..., object]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# workaround for torchscript BC
# it requires all called functions to be in the
@ -349,15 +350,24 @@ class Optimizer:
options (used when a parameter group doesn't specify them).
"""
OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc]
OptimizerPreHook: TypeAlias = Callable[
[Self, Args, Kwargs], # type: ignore[misc]
Optional[tuple[Args, Kwargs]],
]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]
_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_pre_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer"], None]]'
)
def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107
torch._C._log_api_usage_once("python.optimizer")
@ -847,7 +857,9 @@ class Optimizer:
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
self._optimizer_load_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._optimizer_load_state_dict_post_hooks.move_to_end(
handle.id, last=False
) # type: ignore[attr-defined]
return handle
@torch._disable_dynamo
@ -877,12 +889,25 @@ class Optimizer:
>>> # xdoctest: +SKIP
>>> model = torch.nn.Linear(10, 10)
>>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(optim, start_factor=0.1, end_factor=1, total_iters=20)
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5)
>>> lr = torch.optim.lr_scheduler.SequentialLR(optim, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> lr.load_state_dict(torch.load('./save_seq.pt'))
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(
... optim,
... start_factor=0.1,
... end_factor=1,
... total_iters=20,
... )
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
... optim,
... T_max=80,
... eta_min=3e-5,
... )
>>> lr = torch.optim.lr_scheduler.SequentialLR(
... optim,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> lr.load_state_dict(torch.load("./save_seq.pt"))
>>> # now load the optimizer checkpoint after loading the LRScheduler
>>> optim.load_state_dict(torch.load('./save_optim.pt'))
>>> optim.load_state_dict(torch.load("./save_optim.pt"))
"""
# shallow copy, to be consistent with module API
@ -933,7 +958,10 @@ class Optimizer:
for k, v in value.items()
}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
return type(value)(
_cast(param, v, param_id=param_id, param_groups=param_groups)
for v in value
) # type: ignore[call-arg]
else:
return value
@ -1021,12 +1049,10 @@ class Optimizer:
torch._foreach_zero_(grads)
@overload
def step(self, closure: None = None) -> None:
...
def step(self, closure: None = None) -> None: ...
@overload
def step(self, closure: Callable[[], float]) -> float:
...
def step(self, closure: Callable[[], float]) -> float: ...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
r"""Perform a single optimization step to update parameter.

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""
from typing import cast, Optional, Union
import torch
@ -285,7 +286,9 @@ def _single_tensor_radam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
if torch.is_complex(param):
param = torch.view_as_real(param)
@ -386,7 +389,9 @@ def _multi_tensor_radam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import cast, Optional, Union
import torch
@ -292,7 +293,9 @@ def _single_tensor_rmsprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
grad = grads[i]
grad = grad if not maximize else -grad
@ -366,7 +369,9 @@ def _multi_tensor_rmsprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""
from typing import cast, Optional, Union
import torch
@ -248,7 +249,9 @@ def _single_tensor_rprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
step += 1
@ -315,7 +318,9 @@ def _multi_tensor_rprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item]

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, Optional, Union
import torch
@ -397,7 +398,8 @@ def _multi_tensor_sgd(
lr = _to_scalar(lr)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=True,
)
for (
device_params_,
@ -502,7 +504,8 @@ def _fused_sgd(
for i, g in enumerate(grads):
momentum_buffer_list[i] = torch.empty_like(g)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=False,
)
for (device, _), (
(device_params_, device_grads_, device_momentum_buffer_list),

View File

@ -37,9 +37,9 @@ class SparseAdam(Optimizer):
sparse_params = []
complex_params = []
for index, param_group in enumerate(self.param_groups):
assert isinstance(
param_group, dict
), f"param_groups must be a list of dicts, but got {type(param_group)}"
assert isinstance(param_group, dict), (
f"param_groups must be a list of dicts, but got {type(param_group)}"
)
# given param group, convert given params to a list first before iterating
for d_index, d_param in enumerate(param_group["params"]):
if d_param.is_sparse:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Weight Averaging implementation."""
import itertools
import math
import warnings
@ -225,9 +226,9 @@ class AveragedModel(Module):
use_buffers=False,
): # noqa: D107
super().__init__()
assert (
avg_fn is None or multi_avg_fn is None
), "Only one of avg_fn and multi_avg_fn should be provided"
assert avg_fn is None or multi_avg_fn is None, (
"Only one of avg_fn and multi_avg_fn should be provided"
)
self.module = deepcopy(model)
if device is not None:
self.module = self.module.to(device)
@ -274,7 +275,9 @@ class AveragedModel(Module):
) in grouped_tensors.items():
if self.multi_avg_fn:
self.multi_avg_fn(
self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type]
self_params, # type: ignore[arg-type]
model_params, # type: ignore[arg-type]
self.n_averaged.to(device),
)
elif (
device is not None