mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
add deepspeed exec and dataloader
bin/deepspeed (symbolic link)
@@ -0,0 +1 @@
ds
bin/ds (executable file)
@@ -0,0 +1,6 @@
#!/usr/bin/env python

from deepspeed.pt.deepspeed_run import main

if __name__ == '__main__':
    main()
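Both entry points resolve to the same launcher: bin/deepspeed is a symlink to ds, which imports main() from deepspeed.pt.deepspeed_run and calls it. A hypothetical invocation might look like the following; the flag names and the training script are assumptions for illustration, not part of this diff:

    deepspeed --num_nodes 1 --num_gpus 2 train.py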
deepspeed/pt/deepspeed_dataloader.py (executable file)
@@ -0,0 +1,92 @@
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import torch
import logging
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm


class DeepSpeedDataSource(object):
    def __init__(self, filenames):
        all_lines = []
        for filename in filenames:
            logging.info("Start reading file %s" % filename)
            with open(filename, "r") as f:
                for i, line in enumerate(tqdm(f)):
                    all_lines.append(line.strip())

        self.all_lines = all_lines
        self.len = len(self.all_lines)

    def __len__(self):
        return self.len


class DeepSpeedDataLoader(object):
    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None):
        self.tput_timer = tput_timer
        self.batch_size = batch_size

        if local_rank >= 0:
            if data_sampler is None:
                data_sampler = DistributedSampler(dataset)
            device_count = 1
        else:
            if data_sampler is None:
                data_sampler = RandomSampler(dataset)
            device_count = torch.cuda.device_count()
            batch_size *= device_count

        if num_local_io_workers is None:
            num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.len = len(self.data_sampler)
        self.data = None

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        return next(self.data)

    def _create_dataloader(self):
        if self.collate_fn is None:
            self.dataloader = DataLoader(self.dataset,
                                         batch_size=self.batch_size,
                                         pin_memory=self.pin_memory,
                                         sampler=self.data_sampler,
                                         num_workers=self.num_local_io_workers)
        else:
            self.dataloader = DataLoader(self.dataset,
                                         batch_size=self.batch_size,
                                         pin_memory=self.pin_memory,
                                         sampler=self.data_sampler,
                                         collate_fn=self.collate_fn,
                                         num_workers=self.num_local_io_workers)
        self.data = (x for x in self.dataloader)

        return self.dataloader
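To show how the loader is driven, here is a minimal, hypothetical usage sketch; the toy dataset and every parameter value are illustration-only assumptions, not part of the commit:

from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader

# Any map-style dataset works here; a plain list already supports
# len() and integer indexing. Note that DeepSpeedDataSource as
# committed defines __len__ but no __getitem__, so a DataLoader
# cannot index it directly; a list sidesteps that.
dataset = ['sample line %d' % i for i in range(16)]

# local_rank=-1 takes the non-distributed branch: a RandomSampler is
# created and batch_size is multiplied by torch.cuda.device_count(),
# so this sketch assumes at least one visible GPU.
loader = DeepSpeedDataLoader(dataset,
                             batch_size=4,
                             pin_memory=False,
                             local_rank=-1,
                             tput_timer=None)

for batch in loader:  # __iter__() builds the underlying DataLoader lazily
    print(batch)

Two behaviors of the committed code are worth noting: __len__ returns the sampler's length, i.e. the number of samples rather than the number of batches, and in the non-distributed branch the effective batch size becomes batch_size * device_count, presumably to feed a multi-GPU wrapper a full per-step batch.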