Files
pytorch/benchmarks/tensorexpr/tensor_engine.py
Mikhail Zolotukhin e93e7b2795 [TensorExpr] Add tensorexpr benchmarks. (#34230)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34230

This PR adds some benchmarks that we used to assess the performance of tensor expressions.

Differential Revision: D20251830

Test Plan: Imported from OSS

Pulled By: ZolotukhinM

fbshipit-source-id: bafd66ce32f63077e3733112d854f5c750d5b1af
2020-03-16 11:49:39 -07:00

43 lines
1.1 KiB
Python

# Module-level engine instance. It stays None until set_engine_mode() installs one;
# get_engine() refuses to return it while it is still None.
tensor_engine = None
def unsupported(func):
    """Mark a method as unsupported by tagging it with ``is_supported = False``.

    The returned wrapper simply forwards the call (and any arguments) to
    ``func``; ``is_supported`` reads the mark back.

    Args:
        func: the function or method to mark.

    Returns:
        A wrapper around ``func`` that carries ``is_supported = False``.
    """
    import functools

    # functools.wraps copies func's __name__/__doc__ (and __dict__) onto the
    # wrapper, so the decorated method does not report itself as "wrapper".
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        return func(self, *args, **kwargs)

    # Set after wraps() so a copied __dict__ cannot override the mark.
    wrapper.is_supported = False
    return wrapper
def is_supported(method):
    """Return the ``is_supported`` mark carried by ``method``.

    Objects without an ``is_supported`` attribute (i.e. not decorated with
    ``unsupported``) are treated as supported and yield True.
    """
    return getattr(method, 'is_supported', True)
def set_engine_mode(mode):
    """Instantiate the tensor engine for ``mode`` and install it module-wide.

    Each backend module is imported only inside its own branch, so only the
    backend actually requested has to be importable. The created engine is
    stored in the module-level ``tensor_engine`` and tagged with ``mode``.

    Args:
        mode: one of 'tf', 'pt', 'topi', 'relay' or 'nnc'.

    Raises:
        ValueError: if ``mode`` is not one of the recognised names.
    """
    global tensor_engine
    if mode == 'tf':
        import tf_engine
        tensor_engine = tf_engine.TensorFlowEngine()
    elif mode == 'pt':
        import pt_engine
        tensor_engine = pt_engine.TorchTensorEngine()
    elif mode == 'topi':
        import topi_engine
        tensor_engine = topi_engine.TopiEngine()
    elif mode == 'relay':
        import relay_engine
        tensor_engine = relay_engine.RelayEngine()
    elif mode == 'nnc':
        import nnc_engine
        tensor_engine = nnc_engine.NncEngine()
    else:
        # Pass a 1-tuple: a bare ``% (mode)`` would unpack a tuple-valued mode
        # and raise TypeError instead of the intended ValueError.
        raise ValueError('invalid tensor engine mode: %s' % (mode,))
    # Tag the engine with the mode string it was created from.
    tensor_engine.mode = mode
def get_engine():
    """Return the engine previously installed by set_engine_mode().

    Raises:
        ValueError: if set_engine_mode() has not been called yet, i.e. the
            module-level ``tensor_engine`` is still None.
    """
    engine = tensor_engine
    if engine is not None:
        return engine
    raise ValueError('use of get_engine, before calling set_engine_mode is illegal')