mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60784
This PR refactors the ps (parameter server) benchmark for modular trainers.
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D29697291
Pulled By: gcramer23
fbshipit-source-id: 64579a1f5326d3cd9f32936dcf53bc243d54b71d
13 lines
346 B
Python
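def preprocess_dummy_data(rank, data):
    r"""
    Moves the training examples of the DummyData class from CPU to the
    GPU whose index equals the worker rank.

    Args:
        rank (int): worker rank, also used as the CUDA device index
        data (list): training examples; each example is a mutable
            sequence whose first two elements are tensors

    Returns:
        list: the same list, modified in place, with those tensors on GPU
    """
    # Replace the first two tensors of every example with GPU copies.
    for example in data:
        example[0] = example[0].cuda(rank)
        example[1] = example[1].cuda(rank)
    return data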
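`preprocess_dummy_data` assumes every rank has its own GPU, so calling it on a machine without CUDA raises an error. A minimal sketch of a device-agnostic variant follows, assuming a CPU fallback is acceptable when CUDA is unavailable. `preprocess_data_to_device` is a hypothetical name, not part of the benchmark, and the example data layout (`[input, target]` lists) is an illustrative assumption.

```python
import torch


def preprocess_data_to_device(rank, data):
    # Hypothetical helper (not from the benchmark): use the GPU indexed by
    # this rank when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    else:
        device = torch.device("cpu")
    # Move the first two tensors of each example in place, mirroring
    # preprocess_dummy_data; examples must be mutable (e.g. lists).
    for example in data:
        example[0] = example[0].to(device)
        example[1] = example[1].to(device)
    return data


# Usage: two examples, each an assumed [input, target] list of CPU tensors.
data = [[torch.randn(4, 3), torch.randint(0, 2, (4,))] for _ in range(2)]
data = preprocess_data_to_device(0, data)
print(data[0][0].device)
```

Using `Tensor.to(device)` instead of `.cuda(rank)` keeps a single code path for both CPU and GPU; the in-place mutation means tuples would not work as examples, just as with the original function.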