mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
move misc implementation out of jit/__init__.py
(#41154)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41154 Test Plan: Imported from OSS Reviewed By: ailzhang Differential Revision: D22445213 Pulled By: suo fbshipit-source-id: 200545715c5ef13beb1437f49e01efb21498ddb7
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6392713584
commit
ca1b8ebbcb
70
torch/jit/_fuser.py
Normal file
70
torch/jit/_fuser.py
Normal file
@ -0,0 +1,70 @@
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
@contextlib.contextmanager
|
||||
def optimized_execution(should_optimize):
|
||||
"""
|
||||
A context manager that controls whether the JIT's executor will run
|
||||
optimizations before executing a function.
|
||||
"""
|
||||
stored_flag = torch._C._get_graph_executor_optimize()
|
||||
torch._C._set_graph_executor_optimize(should_optimize)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._C._set_graph_executor_optimize(stored_flag)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fuser(name):
|
||||
"""
|
||||
A context manager that facilitates switching between
|
||||
backend fusers.
|
||||
|
||||
Valid names:
|
||||
* ``fuser0`` - enables only legacy fuser
|
||||
* ``fuser1`` - enables only NNC
|
||||
* ``fuser2`` - enables only nvFuser
|
||||
"""
|
||||
old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
|
||||
old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
|
||||
old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
|
||||
old_nvfuser_state = torch._C._jit_nvfuser_enabled()
|
||||
if name == 'fuser0': # legacy fuser
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
elif name == 'fuser1': # NNC
|
||||
old_profiling_executor = torch._C._jit_set_profiling_executor(True)
|
||||
old_profiling_mode = torch._C._jit_set_profiling_mode(True)
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
elif name == 'fuser2': # nvFuser
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(True)
|
||||
else:
|
||||
raise Exception("unrecognized fuser option")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if name == 'fuser1': # NNC
|
||||
torch._C._jit_set_profiling_executor(old_profiling_executor)
|
||||
torch._C._jit_set_profiling_mode(old_profiling_mode)
|
||||
# recover the previous values
|
||||
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
|
||||
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
|
||||
torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
|
||||
torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
|
||||
|
||||
|
||||
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
|
||||
|
||||
|
||||
def _graph_for(self, *args, **kwargs):
|
||||
self(*args, **kwargs)
|
||||
return last_executed_optimized_graph()
|
Reference in New Issue
Block a user