mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move tensor grouping to ATen (#100007)
rel: #94344 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
7108c035bc
commit
74b7a6c75e
@ -194,7 +194,7 @@ class AveragedModel(Module):
|
||||
if self.n_averaged > 0:
|
||||
if self.multi_avg_fn is not None or self.avg_fn is None:
|
||||
grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
|
||||
for ((device, _), [self_params, model_params]) in grouped_tensors.items():
|
||||
for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
|
||||
if self.multi_avg_fn:
|
||||
self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
|
||||
elif device.type == 'cuda':
|
||||
|
Reference in New Issue
Block a user