# @package parallel_workers
# Module caffe2.python.parallel_workers
'''
This module provides a python-land multithreaded mechanism for executing work.

Basic usage is as follows:
   coordinator = parallel_workers.init_workers(
      my_worker_fun,
      worker_name="train"
   )
   ...
   coordinator.start()

The first argument is the function to run in a loop on potentially multiple
threads. It has the call signature
   worker_fun(worker_id)

The argument 'worker_name' is used to distinguish different workers,
such as workers processing train data or workers processing test data.

Optionally, one can define an "init function" that is called once before
the worker threads start. It has the call signature
   my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
'''
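# A minimal usage sketch of the API described above. The names
# `my_worker_fun`, `my_init_fun`, and `work_queue` are hypothetical and exist
# only for illustration:
#
#     import queue
#     from caffe2.python import parallel_workers
#
#     work_queue = queue.Queue()
#
#     def my_init_fun(worker_coordinator, global_coordinator):
#         # Runs once, serially, before any worker thread is started.
#         for i in range(100):
#             work_queue.put(i)
#
#     def my_worker_fun(worker_id):
#         # Called repeatedly by each worker thread while the coordinator
#         # is active; one call per loop iteration.
#         try:
#             item = work_queue.get(timeout=1.0)
#         except queue.Empty:
#             return
#         print("worker {} processed {}".format(worker_id, item))
#
#     coordinator = parallel_workers.init_workers(
#         my_worker_fun,
#         num_worker_threads=2,
#         worker_name="train",
#         init_fun=my_init_fun,
#     )
#     coordinator.start()
#     ...
#     coordinator.stop()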

import logging
import threading
import atexit
import time
import collections
import traceback

from abc import ABCMeta, abstractmethod

log = logging.getLogger("parallel_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def init_workers(
    worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=None,
    external_loggers=None,
    shutdown_fun=None,
):
    global global_coordinator

    metrics = Metrics(external_loggers)

    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]

    # Create coordinator object
    coordinator = WorkerCoordinator(
        worker_name, worker_ids, init_fun, shutdown_fun=shutdown_fun)

    # Launch fetch worker threads
    workers = [
        threading.Thread(
            target=run_worker,
            name="parallel_workers worker id {}".format(worker_id),
            args=[coordinator,
                  Worker(coordinator, worker_id, worker_fun, metrics)],
        ) for worker_id in worker_ids
    ]

    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class Metrics(object):
    def __init__(self, external_loggers):
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers

    def reset_metrics(self):
        self._metrics = collections.defaultdict(lambda: 0)

    def log_metrics(self):
        if not self._external_loggers:
            return
        for logger in self._external_loggers:
            try:
                logger.log(self._metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1
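# The external loggers passed to Metrics are duck-typed: anything exposing a
# log(metrics_dict) method works. A hypothetical sketch (the class name and
# behavior are illustrative assumptions, not an API shipped with this module):
#
#     class StdoutMetricsLogger(object):
#         def log(self, metrics):
#             # `metrics` is the defaultdict maintained by Metrics, e.g.
#             # {'worker_time': 1.23, 'worker_time_count': 7}
#             print("metrics: {}".format(dict(metrics)))
#
#     coordinator = init_workers(
#         my_worker_fun, external_loggers=[StdoutMetricsLogger()])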


class State(metaclass=ABCMeta):
    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass
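# State is an optional lifecycle hook: a WorkerCoordinator constructed with a
# State instance calls start() when the workers launch, stop() on shutdown,
# and cleanup() once all worker threads have joined. A hypothetical subclass
# (name and body are illustrative assumptions):
#
#     class ScratchBlobState(State):
#         def start(self):
#             pass  # e.g. allocate per-worker scratch buffers
#
#         def stop(self):
#             pass  # e.g. signal producers to stop
#
#         def cleanup(self):
#             pass  # e.g. release the scratch buffers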


class WorkerCoordinator(object):
    def __init__(
        self, worker_name, worker_ids, init_fun,
        state=None, shutdown_fun=None
    ):
        self._active = True
        self._started = False
        self._workers = []
        self._worker_name = worker_name
        self._worker_ids = worker_ids
        self._init_fun = init_fun
        self._state = state
        self._shutdown_fun = shutdown_fun

    def is_active(self):
        return self._active

    def init(self, global_coordinator):
        if self._init_fun and not self._started:
            data_coordinator = self
            self._init_fun(data_coordinator, global_coordinator)

    def _start(self):
        if self._started:
            return
        self._active = True
        self._started = True
        if self._state:
            self._state.start()

        for w in self._workers:
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        self._active = False
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        if self._shutdown_fun and self._started:
            self._shutdown_fun()
        if self._state:
            self._state.stop()

        self._started = False

    def _wait_finish(self, cleanup=None):
        print("Wait for workers to die: {}".format(self._worker_name))
        for w in self._workers:
            if w != threading.current_thread():
                w.join(5.0)  # don't wait forever, thread may be blocked in i/o
        success = True
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Release memory for the scratch blobs
        if success and self._state:
            self._state.cleanup()

        print("All workers terminated: {}".format(success))
        return success

    def get_worker_ids(self):
        return self._worker_ids


class GlobalWorkerCoordinator(object):
    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []
        self.register_shutdown_handler()

    def add(self, coordinator):
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        return self._worker_ids

    def start(self):
        # Run init and start in separate for-loops to ensure every init
        # function completes serially before any worker threads are spawned.
        for c in self._coordinators:
            c.init(self)
        for c in self._coordinators:
            c._start()

    def stop(self):
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def stop_coordinator(self, worker_name):
        '''
        Stop a specific coordinator
        '''
        for c in self._coordinators:
            if c._worker_name == worker_name:
                c._stop()
                c._wait_finish()
        self._coordinators = [
            c for c in self._coordinators
            if c._worker_name != worker_name
        ]

    def register_shutdown_handler(self):
        def cleanup():
            self.stop()

        atexit.register(cleanup)


class Worker(object):
    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun=None,
        metrics=None
    ):
        self._coordinator = coordinator
        self._worker_id = worker_id
        self._worker_fun = worker_fun
        self._metrics = metrics

    def start(self):
        self._start_time = time.time()

    def run(self):
        self._worker_fun(self._worker_id)

    def handle_exception(self, e):
        traceback.print_exc()
        logging.exception("Exception in worker: %s", e)
        self._coordinator._stop("Exception in worker {}: {}".format(
            self._worker_id, e
        ))

    def finish(self):
        self._metrics.put_metric(
            'worker_time', time.time() - self._start_time)
        self._metrics.log_metrics()


global_coordinator = GlobalWorkerCoordinator()


def run_worker(coordinator, worker):
    while coordinator.is_active():
        worker.start()
        try:
            worker.run()
        except Exception as e:
            worker.handle_exception(e)
        finally:
            worker.finish()