mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Expose enablement of TensorExpr fuser as env variable (#35341)
Summary:
This commit allows one to use an environment variable to enable the fuser in torch/csrc/jit/tensorexpr/:

```
PYTORCH_TENSOREXPR=1 python benchmark.py
```

This commit also changes the registration to happen by default, removing the requirement for the Python-exposed "_jit_register_tensorexpr_fuser" binding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/35341

Reviewed By: ZolotukhinM

Differential Revision: D20676348

Pulled By: bwasti

fbshipit-source-id: 4c997cdc310e7567c03905ebff72b3e8a4c2f464
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4d39aeec27
commit
a3e10d2a17
@ -90,7 +90,7 @@ Works only with Python3.\n A few examples:
|
||||
if args.cuda_fuser == "te":
|
||||
import torch
|
||||
|
||||
torch._C._jit_register_tensorexpr_fuser()
|
||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
|
||||
def set_global_threads(num_threads):
|
||||
os.environ["OMP_NUM_THREADS"] = str(num_threads)
|
||||
|
@ -55,7 +55,6 @@ class TestFuser(JitTestCase):
|
||||
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_register_tensorexpr_fuser()
|
||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
|
||||
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
|
||||
|
@ -23,9 +23,10 @@ class BaseTestClass(unittest.TestCase):
|
||||
# TODO: read the old value and restore it rather than always set to True
|
||||
# on exit
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_register_tensorexpr_fuser()
|
||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
|
||||
def tearDown(self):
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
|
||||
class TestTensorExprFuser(BaseTestClass):
|
||||
|
@ -13,9 +13,20 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
static bool texpr_fuser_enabled = true;
|
||||
static bool texpr_fuser_enabled_ = false;
|
||||
void setTensorExprFuserEnabled(bool val) {
|
||||
texpr_fuser_enabled = val;
|
||||
texpr_fuser_enabled_ = val;
|
||||
}
|
||||
|
||||
static bool tensorExprFuserEnabled() {
|
||||
static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR");
|
||||
if (!enable_c_str) {
|
||||
return texpr_fuser_enabled_;
|
||||
}
|
||||
if (std::string(enable_c_str) == "0") {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
const Symbol& getTensorExprSymbol() {
|
||||
@ -255,7 +266,7 @@ std::pair<graph_node_list::iterator, bool> scanNode(
|
||||
}
|
||||
|
||||
void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
|
||||
if (!texpr_fuser_enabled) {
|
||||
if (!tensorExprFuserEnabled()) {
|
||||
return;
|
||||
}
|
||||
GRAPH_DUMP("Before TExprFuser: ", graph);
|
||||
@ -331,12 +342,7 @@ RegisterOperators TensorExprOps({
|
||||
AliasAnalysisKind::PURE_FUNCTION),
|
||||
});
|
||||
|
||||
void registerTensorExprFuser() {
|
||||
static bool already_registered = false;
|
||||
if (!already_registered) {
|
||||
RegisterPass pass(fuseTensorExprs);
|
||||
already_registered = true;
|
||||
}
|
||||
}
|
||||
static RegisterPass pass(fuseTensorExprs);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -354,7 +354,6 @@ void initJITBindings(PyObject* module) {
|
||||
.def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
|
||||
.def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
|
||||
.def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
|
||||
.def("_jit_register_tensorexpr_fuser", &registerTensorExprFuser)
|
||||
.def(
|
||||
"_jit_differentiate",
|
||||
[](Graph& g) {
|
||||
|
Reference in New Issue
Block a user