mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 16:14:55 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15024 Pull Request resolved: https://github.com/pytorch/pytorch/pull/14393 att Reviewed By: dzhulgakov Differential Revision: D13380559 fbshipit-source-id: abc3fc7321cf37323f756dfd614c7b41978734e4
		
			
				
	
	
		
			54 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			54 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import absolute_import
 | |
| from __future__ import division
 | |
| from __future__ import print_function
 | |
| 
 | |
| from caffe2.python import core, workspace
 | |
| from caffe2.proto import caffe2_pb2
 | |
| from caffe2.python.test_util import TestCase
 | |
| import unittest
 | |
| 
 | |
# Initialize the caffe2 runtime with the CPU NUMA flag turned on. This runs at
# import time, before the NUMATest skip decorators below are evaluated, since
# workspace.IsNUMAEnabled() presumably reflects this flag (confirm in numa.cc).
core.GlobalInit([
    "caffe2",
    "--caffe2_cpu_numa_enabled=1",
])
 | |
| 
 | |
def build_test_net(net_name):
    """Build an async_scheduling net that places two CPU blobs on different NUMA nodes.

    The net fills ``output_blob_0`` on NUMA node 0 and ``output_blob_1`` on
    NUMA node 1 (shape ``[1]``, value 3.14). It then copies each blob to GPU
    device 0 as ``output_blob_0_gpu`` and ``output_blob_1_gpu``. All fills are
    added before all copies, matching the original op order.

    Args:
        net_name: Name given to the created ``core.Net``.

    Returns:
        The constructed, not-yet-run ``core.Net``.
    """
    numa_node_ids = (0, 1)

    net = core.Net(net_name)
    net.Proto().type = "async_scheduling"

    # Build a fresh DeviceOption per NUMA node. Mutating one shared proto
    # after it has been attached to an op only works if op creation copies it.
    for numa_node_id in numa_node_ids:
        numa_device_option = caffe2_pb2.DeviceOption()
        numa_device_option.device_type = caffe2_pb2.CPU
        numa_device_option.numa_node_id = numa_node_id
        net.ConstantFill(
            [],
            "output_blob_{}".format(numa_node_id),
            shape=[1],
            value=3.14,
            device_option=numa_device_option,
        )

    # The test is gated on workspace.has_gpu_support, which is also true on
    # ROCm/HIP builds. Hard-coding CUDA would target the wrong device type
    # there, so prefer the workspace-reported GPU type when it exists.
    gpu_device_option = caffe2_pb2.DeviceOption()
    gpu_device_option.device_type = getattr(
        workspace, "GpuDeviceType", caffe2_pb2.CUDA)
    gpu_device_option.device_id = 0

    for numa_node_id in numa_node_ids:
        net.CopyCPUToGPU(
            "output_blob_{}".format(numa_node_id),
            "output_blob_{}_gpu".format(numa_node_id),
            device_option=gpu_device_option,
        )

    return net
 | |
| 
 | |
| 
 | |
@unittest.skipIf(not workspace.IsNUMAEnabled(), "NUMA is not enabled")
@unittest.skipIf(workspace.GetNumNUMANodes() < 2, "Not enough NUMA nodes")
@unittest.skipIf(not workspace.has_gpu_support, "No GPU support")
class NUMATest(TestCase):
    """Verify that blobs created with a numa_node_id land on that NUMA node.

    The test is skipped unless:
    - NUMA is enabled;
    - at least two NUMA nodes are reported;
    - the build has GPU support (the test net copies its outputs to a GPU).
    """

    def test_numa(self):
        # Execute the net built by build_test_net once, then check where each
        # CPU-side output blob was allocated, in blob order.
        workspace.RunNetOnce(build_test_net("test_numa"))

        expected_nodes = (
            ("output_blob_0", 0),
            ("output_blob_1", 1),
        )
        for blob_name, numa_node in expected_nodes:
            self.assertEqual(workspace.GetBlobNUMANode(blob_name), numa_node)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main(module="__main__")
 |