pytorch/benchmarks/distributed/rpc/parameter_server/trainer/preprocess_data.py
Garrett Cramer 304c02ee44 refactor ps benchmark (#60784)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60784

This PR refactors the parameter server (ps) benchmark to support modular trainers.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D29697291

Pulled By: gcramer23

fbshipit-source-id: 64579a1f5326d3cd9f32936dcf53bc243d54b71d
2021-07-14 13:19:13 -07:00


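```python
def preprocess_dummy_data(rank, data):
    r"""
    A function that moves the data from CPU to GPU
    for the DummyData class.
    Args:
        rank (int): worker rank, also used as the CUDA device index
        data (list): training examples, each a list of two tensors
    Returns:
        The same list, with both tensors of every example moved to GPU ``rank``
    """
    for i in range(len(data)):
        # Elements are rebound in place, so each example must be a list, not a tuple.
        data[i][0] = data[i][0].cuda(rank)
        data[i][1] = data[i][1].cuda(rank)
    return data
```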
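The function above mutates `data` in place, so it assumes every example is a mutable two-element list and that `rank` is also a valid local CUDA device index. The following is a minimal sketch of a device-agnostic variant under those assumptions. The name `preprocess_data_to_device`, the `(input, target)` naming, and the CPU fallback are hypothetical additions for illustration, not part of the benchmark.

```python
import torch


def preprocess_data_to_device(rank, data):
    """Hypothetical variant of preprocess_dummy_data (not part of the benchmark).

    Moves each (input, target) pair to cuda:<rank> when CUDA is available and
    otherwise keeps it on the CPU. Builds a new list of [input, target] lists,
    so tuple examples work too and the caller's list is left unchanged.
    """
    if torch.cuda.is_available():
        # Assumes rank is a valid local GPU index, as the original code does.
        device = torch.device("cuda", rank)
    else:
        # Fallback so the sketch also runs on CPU-only hosts.
        device = torch.device("cpu")
    return [[inp.to(device), target.to(device)] for inp, target in data]


if __name__ == "__main__":
    # Four dummy examples: a 3-feature input and a class-index target.
    samples = [(torch.randn(3), torch.tensor(i % 2)) for i in range(4)]
    moved = preprocess_data_to_device(0, samples)
    print(len(moved), moved[0][0].device)
```

Returning a new list trades a little extra allocation for not depending on the container type of each example; `Tensor.to(device)` is a no-op when the tensor already lives on that device.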