mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Expose MKLMemory to the Python Feed and Fetch interface, and misc changes
Summary: This is #2 of a series of changes. It did the following: (1) a few refactor of the MKL memory interface (2) an initial MKLContext to deal with MKL specific computations (3) Provide MKLMemory access in Python with the blob feeder/fetcher registration. Reviewed By: dzhulgakov Differential Revision: D4210123 fbshipit-source-id: adea1f1ffbd0b9ffdd55092676468c16bec08992
This commit is contained in:
@ -2,6 +2,7 @@ import numpy as np
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from caffe2.python import core, test_util, workspace
|
||||
|
||||
import caffe2.python.hypothesis_test_util as htu
|
||||
@ -294,7 +295,18 @@ class TestWorkspaceGPU(test_util.TestCase):
|
||||
self.assertEqual(pattern.shape[0], workspace.NumCudaDevices())
|
||||
|
||||
|
||||
class TestImmediate(test_util.TestCase):
|
||||
@unittest.skipIf(not workspace.C.has_mkldnn, "No MKLDNN support.")
|
||||
class TestWorkspaceMKLDNN(test_util.TestCase):
|
||||
|
||||
def testFeedFetchBlobMKLDNN(self):
|
||||
arr = np.random.randn(2, 3).astype(np.float32)
|
||||
workspace.FeedBlob(
|
||||
"testblob_mkldnn", arr, core.DeviceOption(caffe2_pb2.MKLDNN))
|
||||
fetched = workspace.FetchBlob("testblob_mkldnn")
|
||||
np.testing.assert_array_equal(arr, fetched)
|
||||
|
||||
|
||||
class TestImmedibate(test_util.TestCase):
|
||||
def testImmediateEnterExit(self):
|
||||
workspace.StartImmediate(i_know=True)
|
||||
self.assertTrue(workspace.IsImmediate())
|
||||
|
Reference in New Issue
Block a user