Files
pytorch/caffe2/python/numa_benchmark.py
Yudong Guang 265b55d028 Revert D13205604: Move numa.{h, cc} to c10/util
Differential Revision:
D13205604

Original commit changeset: 54166492d318

fbshipit-source-id: 89b6833518c0b554668c88ae38d97fbc47e2de17
2018-12-07 10:01:25 -08:00

70 lines
2.3 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2
import time
SHAPE_LEN = 4096
NUM_ITER = 1000
GB = 1024 * 1024 * 1024
NUM_REPLICAS = 48
def build_net(net_name, cross_socket):
init_net = core.Net(net_name + "_init")
init_net.Proto().type = "async_scheduling"
numa_device_option = caffe2_pb2.DeviceOption()
numa_device_option.device_type = caffe2_pb2.CPU
numa_device_option.numa_node_id = 0
for replica_id in range(NUM_REPLICAS):
init_net.XavierFill([], net_name + "/input_blob_" + str(replica_id),
shape=[SHAPE_LEN, SHAPE_LEN], device_option=numa_device_option)
net = core.Net(net_name)
net.Proto().type = "async_scheduling"
if cross_socket:
numa_device_option.numa_node_id = 1
for replica_id in range(NUM_REPLICAS):
net.Copy(net_name + "/input_blob_" + str(replica_id),
net_name + "/output_blob_" + str(replica_id),
device_option=numa_device_option)
return init_net, net
def main():
assert workspace.IsNUMAEnabled() and workspace.GetNumNUMANodes() >= 2
single_init, single_net = build_net("single_net", False)
cross_init, cross_net = build_net("cross_net", True)
workspace.CreateNet(single_init)
workspace.RunNet(single_init.Name())
workspace.CreateNet(cross_init)
workspace.RunNet(cross_init.Name())
workspace.CreateNet(single_net)
workspace.CreateNet(cross_net)
for _ in range(4):
t = time.time()
workspace.RunNet(single_net.Name(), NUM_ITER)
dt = time.time() - t
print("Single socket time:", dt)
single_bw = 4 * SHAPE_LEN * SHAPE_LEN * NUM_REPLICAS * NUM_ITER / dt / GB
print("Single socket BW: {} GB/s".format(single_bw))
t = time.time()
workspace.RunNet(cross_net.Name(), NUM_ITER)
dt = time.time() - t
print("Cross socket time:", dt)
cross_bw = 4 * SHAPE_LEN * SHAPE_LEN * NUM_REPLICAS * NUM_ITER / dt / GB
print("Cross socket BW: {} GB/s".format(cross_bw))
print("Single BW / Cross BW: {}".format(single_bw / cross_bw))
if __name__ == '__main__':
core.GlobalInit(["caffe2", "--caffe2_cpu_numa_enabled=1"])
main()