import time

import numpy as np
from agent import AgentBase
from observer import ObserverBase

import torch
import torch.distributed.rpc as rpc


COORDINATOR_NAME = "coordinator"
AGENT_NAME = "agent"
OBSERVER_NAME = "observer{}"

EPISODE_STEPS = 100

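# Minimal usage sketch (illustrative only, not part of this module): the
# coordinator is assumed to run on the RPC worker named COORDINATOR_NAME after
# torch.distributed.rpc has been initialized on every process, e.g.:
#
#     rpc.init_rpc(COORDINATOR_NAME, rank=0, world_size=batch_size + 2)
#     coordinator = CoordinatorBase(batch_size, batch, state_size, nlayers, out_features)
#     coordinator.run_coordinator(episodes, EPISODE_STEPS, queue)
#     rpc.shutdown()
#
# The world size of batch_size + 2 assumes one coordinator, one agent, and
# batch_size observers; the actual launcher may differ.
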
class CoordinatorBase:
    def __init__(self, batch_size, batch, state_size, nlayers, out_features):
        r"""
        Coordinator object to run on worker. Only one coordinator exists. Responsible
        for facilitating communication between agent and observers and recording benchmark
        throughput and latency data.
        Args:
            batch_size (int): Number of observer requests to process in a batch
            batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of out features in the model
        """
        self.batch_size = batch_size
        self.batch = batch

        self.agent_rref = None  # Agent RRef
        self.ob_rrefs = []  # Observer RRefs

        agent_info = rpc.get_worker_info(AGENT_NAME)
        self.agent_rref = rpc.remote(agent_info, AgentBase)
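        # Observer workers are named "observer2", "observer3", ... (rank + 2),
        # which assumes the first two ranks in the RPC world are taken by the
        # coordinator and the agent.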
        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))
            ob_ref = rpc.remote(ob_info, ObserverBase)
            self.ob_rrefs.append(ob_ref)

            ob_ref.rpc_sync().set_state(state_size, batch)

        self.agent_rref.rpc_sync().set_world(
            batch_size, state_size, nlayers, out_features, self.batch
        )
    def run_coordinator(self, episodes, episode_steps, queue):
        r"""
        Runs n benchmark episodes. Each episode is started by the coordinator telling each
        observer to contact the agent. Each episode is concluded by the coordinator telling
        the agent to finish the episode, after which the coordinator records benchmark data.
        Args:
            episodes (int): Number of episodes to run
            episode_steps (int): Number of steps to be run in each episode by each observer
            queue (SimpleQueue): SimpleQueue from torch.multiprocessing.get_context() for
                saving benchmark run results to
        """
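        # Accumulators for per-episode metrics, combined across all episodes
        # before the percentile summary at the end of the run.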
        agent_latency_final = []
        agent_throughput_final = []

        observer_latency_final = []
        observer_throughput_final = []

        for ep in range(episodes):
            ep_start_time = time.time()

            print(f"Episode {ep} - ", end="")

            n_steps = episode_steps
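            # Kick off one asynchronous episode per observer, then block until
            # all of them complete; each returned tuple is indexed below, with
            # observer latency at position 2 and throughput at position 3.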
            futs = []
            for ob_rref in self.ob_rrefs:
                futs.append(
                    ob_rref.rpc_async().run_ob_episode(self.agent_rref, n_steps)
                )

            rets = torch.futures.wait_all(futs)
            agent_latency, agent_throughput = self.agent_rref.rpc_sync().finish_episode(
                rets
            )

            self.agent_rref.rpc_sync().reset_metrics()

            agent_latency_final += agent_latency
            agent_throughput_final += agent_throughput

            observer_latency_final += [ret[2] for ret in rets]
            observer_throughput_final += [ret[3] for ret in rets]

            ep_end_time = time.time()
            episode_time = ep_end_time - ep_start_time
            print(round(episode_time, 3))
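        # Each observer reports a list of per-step measurements; flatten the
        # per-observer lists into single lists before computing percentiles.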
        observer_latency_final = [t for s in observer_latency_final for t in s]
        observer_throughput_final = [t for s in observer_throughput_final for t in s]
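        # Summarize each metric at the p50/p75/p90/p95 percentiles; the values
        # are printed and also collected in benchmark_metrics for the caller.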
        benchmark_metrics = {
            "agent latency (seconds)": {},
            "agent throughput": {},
            "observer latency (seconds)": {},
            "observer throughput": {},
        }

        print(f"For batch size {self.batch_size}")
        print("\nAgent Latency - ", len(agent_latency_final))
        agent_latency_final = sorted(agent_latency_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(agent_latency_final, p)
            print("p" + str(p) + ":", round(v, 3))
            p = f"p{p}"
            benchmark_metrics["agent latency (seconds)"][p] = round(v, 3)

        print("\nAgent Throughput - ", len(agent_throughput_final))
        agent_throughput_final = sorted(agent_throughput_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(agent_throughput_final, p)
            print("p" + str(p) + ":", int(v))
            p = f"p{p}"
            benchmark_metrics["agent throughput"][p] = int(v)

        print("\nObserver Latency - ", len(observer_latency_final))
        observer_latency_final = sorted(observer_latency_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(observer_latency_final, p)
            print("p" + str(p) + ":", round(v, 3))
            p = f"p{p}"
            benchmark_metrics["observer latency (seconds)"][p] = round(v, 3)

        print("\nObserver Throughput - ", len(observer_throughput_final))
        observer_throughput_final = sorted(observer_throughput_final)
        for p in [50, 75, 90, 95]:
            v = np.percentile(observer_throughput_final, p)
            print("p" + str(p) + ":", int(v))
            p = f"p{p}"
            benchmark_metrics["observer throughput"][p] = int(v)

        if queue:
            queue.put(benchmark_metrics)
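
# A hedged sketch of how the launching process might collect the results that
# run_coordinator puts on the queue (names here are illustrative, not defined
# in this module):
#
#     ctx = torch.multiprocessing.get_context("spawn")
#     queue = ctx.SimpleQueue()
#     # ... spawn the coordinator process and call run_coordinator(..., queue) ...
#     benchmark_metrics = queue.get()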