mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
add deepspeed exec and dataloader
This commit is contained in:
1
bin/deepspeed
Symbolic link
1
bin/deepspeed
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
ds
|
6
bin/ds
Executable file
6
bin/ds
Executable file
@ -0,0 +1,6 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from deepspeed.pt.deepspeed_run import main
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
92
deepspeed/pt/deepspeed_dataloader.py
Executable file
92
deepspeed/pt/deepspeed_dataloader.py
Executable file
@ -0,0 +1,92 @@
|
|||||||
|
'''
|
||||||
|
Copyright 2019 The Microsoft DeepSpeed Team
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSpeedDataSource(object):
|
||||||
|
def __init__(self, filenames):
|
||||||
|
all_lines = []
|
||||||
|
for filename in filenames:
|
||||||
|
logging.info("Start reading file %s" % filename)
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
for i, line in enumerate(tqdm(f)):
|
||||||
|
all_lines.append(line.strip())
|
||||||
|
|
||||||
|
self.all_lines = all_lines
|
||||||
|
self.len = len(self.all_lines)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.len
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSpeedDataLoader(object):
|
||||||
|
def __init__(self,
|
||||||
|
dataset,
|
||||||
|
batch_size,
|
||||||
|
pin_memory,
|
||||||
|
local_rank,
|
||||||
|
tput_timer,
|
||||||
|
collate_fn=None,
|
||||||
|
num_local_io_workers=None,
|
||||||
|
data_sampler=None):
|
||||||
|
self.tput_timer = tput_timer
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
if local_rank >= 0:
|
||||||
|
if data_sampler is None:
|
||||||
|
data_sampler = DistributedSampler(dataset)
|
||||||
|
device_count = 1
|
||||||
|
else:
|
||||||
|
if data_sampler is None:
|
||||||
|
data_sampler = RandomSampler(dataset)
|
||||||
|
device_count = torch.cuda.device_count()
|
||||||
|
batch_size *= device_count
|
||||||
|
|
||||||
|
if num_local_io_workers is None:
|
||||||
|
num_local_io_workers = 2 * device_count
|
||||||
|
|
||||||
|
self.num_local_io_workers = num_local_io_workers
|
||||||
|
self.data_sampler = data_sampler
|
||||||
|
self.dataset = dataset
|
||||||
|
self.collate_fn = collate_fn
|
||||||
|
self.device_count = device_count
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.pin_memory = pin_memory
|
||||||
|
self.len = len(self.data_sampler)
|
||||||
|
self.data = None
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self._create_dataloader()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.len
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.tput_timer:
|
||||||
|
self.tput_timer.start()
|
||||||
|
return next(self.data)
|
||||||
|
|
||||||
|
def _create_dataloader(self):
|
||||||
|
if self.collate_fn is None:
|
||||||
|
self.dataloader = DataLoader(self.dataset,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
sampler=self.data_sampler,
|
||||||
|
num_workers=self.num_local_io_workers)
|
||||||
|
else:
|
||||||
|
self.dataloader = DataLoader(self.dataset,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
sampler=self.data_sampler,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
num_workers=self.num_local_io_workers)
|
||||||
|
self.data = (x for x in self.dataloader)
|
||||||
|
|
||||||
|
return self.dataloader
|
Reference in New Issue
Block a user