Move tensor grouping to ATen (#100007)

rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
This commit is contained in:
Masaki Kozuki
2023-06-09 15:44:46 +00:00
committed by PyTorch MergeBot
parent 7108c035bc
commit 74b7a6c75e
20 changed files with 271 additions and 57 deletions

View File

@ -194,7 +194,7 @@ class AveragedModel(Module):
if self.n_averaged > 0:
if self.multi_avg_fn is not None or self.avg_fn is None:
grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
for ((device, _), [self_params, model_params]) in grouped_tensors.items():
for ((device, _), ([self_params, model_params], _)) in grouped_tensors.items():
if self.multi_avg_fn:
self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
elif device.type == 'cuda':