Expose enablement of TensorExpr fuser as env variable (#35341)

Summary:
This commit allows one to use an environment variable to enable the fuser in torch/csrc/jit/tensorexpr/

```
PYTORCH_TENSOREXPR=1 python benchmark.py
```

This commit also changes the registration to happen by default, removing the requirement for the Python-exposed "_jit_register_tensorexpr_fuser".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35341

Reviewed By: ZolotukhinM

Differential Revision: D20676348

Pulled By: bwasti

fbshipit-source-id: 4c997cdc310e7567c03905ebff72b3e8a4c2f464
This commit is contained in:
Bram Wasti
2020-03-26 14:28:46 -07:00
committed by Facebook GitHub Bot
parent 4d39aeec27
commit a3e10d2a17
5 changed files with 19 additions and 14 deletions

View File

@@ -90,7 +90,7 @@ Works only with Python3. A few examples:
if args.cuda_fuser == "te":
import torch
torch._C._jit_register_tensorexpr_fuser()
torch._C._jit_set_texpr_fuser_enabled(True)
def set_global_threads(num_threads):
os.environ["OMP_NUM_THREADS"] = str(num_threads)

View File

@@ -55,7 +55,6 @@ class TestFuser(JitTestCase):
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_register_tensorexpr_fuser()
torch._C._jit_set_texpr_fuser_enabled(True)
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)

View File

@@ -23,9 +23,10 @@ class BaseTestClass(unittest.TestCase):
# TODO: read the old value and restore it rather than always set to True
# on exit
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_register_tensorexpr_fuser()
torch._C._jit_set_texpr_fuser_enabled(True)
def tearDown(self):
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_override_can_fuse_on_gpu(True)
class TestTensorExprFuser(BaseTestClass):

View File

@@ -13,9 +13,20 @@
namespace torch {
namespace jit {
static bool texpr_fuser_enabled = true;
static bool texpr_fuser_enabled_ = false;
void setTensorExprFuserEnabled(bool val) {
texpr_fuser_enabled = val;
texpr_fuser_enabled_ = val;
}
static bool tensorExprFuserEnabled() {
static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR");
if (!enable_c_str) {
return texpr_fuser_enabled_;
}
if (std::string(enable_c_str) == "0") {
return false;
}
return true;
}
const Symbol& getTensorExprSymbol() {
@@ -255,7 +266,7 @@ std::pair<graph_node_list::iterator, bool> scanNode(
}
void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
if (!texpr_fuser_enabled) {
if (!tensorExprFuserEnabled()) {
return;
}
GRAPH_DUMP("Before TExprFuser: ", graph);
@@ -331,12 +342,7 @@ RegisterOperators TensorExprOps({
AliasAnalysisKind::PURE_FUNCTION),
});
void registerTensorExprFuser() {
static bool already_registered = false;
if (!already_registered) {
RegisterPass pass(fuseTensorExprs);
already_registered = true;
}
}
static RegisterPass pass(fuseTensorExprs);
} // namespace jit
} // namespace torch

View File

@@ -354,7 +354,6 @@ void initJITBindings(PyObject* module) {
.def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
.def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
.def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
.def("_jit_register_tensorexpr_fuser", &registerTensorExprFuser)
.def(
"_jit_differentiate",
[](Graph& g) {