mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34230 This PR adds some benchmarks that we used to assess tensor expressions performance. Differential Revision: D20251830 Test Plan: Imported from OSS Pulled By: ZolotukhinM fbshipit-source-id: bafd66ce32f63077e3733112d854f5c750d5b1af
43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
# Module-wide handle to the active tensor engine instance. It is None until
# set_engine_mode() assigns it; get_engine() raises ValueError while it is None.
tensor_engine = None
|
|
|
|
def unsupported(func):
    """Mark a benchmark method as unsupported.

    The returned wrapper forwards every call (including any extra positional
    or keyword arguments) to ``func`` unchanged. Its only observable effect is
    the ``is_supported = False`` attribute, which ``is_supported()`` reads.

    Args:
        func: the method to mark; it is called with ``self`` first.

    Returns:
        A wrapper around ``func`` that keeps ``func``'s name and docstring and
        carries ``is_supported = False``.
    """
    import functools  # local import keeps this module free of top-level imports

    # wraps() preserves __name__/__doc__ so reports and tracebacks still show
    # the real method instead of "wrapper".
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        return func(self, *args, **kwargs)

    # Set after wraps(), which copies func.__dict__ onto the wrapper and would
    # otherwise be able to overwrite this flag.
    wrapper.is_supported = False
    return wrapper


def is_supported(method):
    """Return the ``is_supported`` flag attached to ``method``.

    Methods decorated with ``unsupported`` carry ``is_supported = False``.
    Anything without the attribute is treated as supported (True).
    """
    return getattr(method, 'is_supported', True)


# Maps each supported mode string to the (module name, class name) that
# implements it. A backend module is imported only when its mode is selected.
_ENGINE_CLASSES = {
    'tf': ('tf_engine', 'TensorFlowEngine'),
    'pt': ('pt_engine', 'TorchTensorEngine'),
    'topi': ('topi_engine', 'TopiEngine'),
    'relay': ('relay_engine', 'RelayEngine'),
    'nnc': ('nnc_engine', 'NncEngine'),
}


def set_engine_mode(mode):
    """Select and instantiate the module-wide tensor engine.

    Imports the backend module for ``mode``, constructs its engine class,
    stores the instance in the module-level ``tensor_engine`` and tags it
    with ``.mode = mode``.

    Args:
        mode: one of 'tf', 'pt', 'topi', 'relay' or 'nnc'.

    Raises:
        ValueError: if ``mode`` is not a supported mode. The previously
            selected engine (if any) is left in place.
    """
    global tensor_engine
    import importlib  # local import keeps this module free of top-level imports

    try:
        module_name, class_name = _ENGINE_CLASSES[mode]
    except (KeyError, TypeError):
        # TypeError covers unhashable modes (e.g. lists). Format with a
        # 1-tuple so tuple-valued modes are printed, not unpacked.
        raise ValueError('invalid tensor engine mode: %s' % (mode,)) from None

    engine_class = getattr(importlib.import_module(module_name), class_name)
    tensor_engine = engine_class()
    tensor_engine.mode = mode


def get_engine():
    """Return the engine stored in the module-level ``tensor_engine``.

    Raises:
        ValueError: if no engine has been selected yet (``tensor_engine`` is
            still None, i.e. set_engine_mode() has not succeeded).
    """
    engine = tensor_engine
    if engine is not None:
        return engine
    raise ValueError('use of get_engine, before calling set_engine_mode is illegal')