mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/**
to ruff format
(#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
3e38feb05f
commit
596b418391
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
r"""Implementation for Stochastic Weight Averaging implementation."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import warnings
|
||||
@ -225,9 +226,9 @@ class AveragedModel(Module):
|
||||
use_buffers=False,
|
||||
): # noqa: D107
|
||||
super().__init__()
|
||||
assert (
|
||||
avg_fn is None or multi_avg_fn is None
|
||||
), "Only one of avg_fn and multi_avg_fn should be provided"
|
||||
assert avg_fn is None or multi_avg_fn is None, (
|
||||
"Only one of avg_fn and multi_avg_fn should be provided"
|
||||
)
|
||||
self.module = deepcopy(model)
|
||||
if device is not None:
|
||||
self.module = self.module.to(device)
|
||||
@ -274,7 +275,9 @@ class AveragedModel(Module):
|
||||
) in grouped_tensors.items():
|
||||
if self.multi_avg_fn:
|
||||
self.multi_avg_fn(
|
||||
self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type]
|
||||
self_params, # type: ignore[arg-type]
|
||||
model_params, # type: ignore[arg-type]
|
||||
self.n_averaged.to(device),
|
||||
)
|
||||
elif (
|
||||
device is not None
|
||||
|
Reference in New Issue
Block a user