[ca] make torch.compile API respect ambient disable contexts (#155473)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155473
Approved by: https://github.com/jansel

Committed by: PyTorch MergeBot
Parent: be124a61a4
Commit: 87b002b6fb
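At a high level: a function wrapped with torch.compile (or torch._dynamo.optimize) while torch._dynamo.config.compiled_autograd = True previously re-enabled compiled autograd for its backward passes even when the caller had explicitly disabled it. With this change, the wrapper respects an ambient torch._dynamo.compiled_autograd._disable() context. A minimal usage sketch, modeled on the test added below (the model, shapes, and backend are illustrative, not part of the commit):

import torch
from torch._dynamo import config, compiled_autograd

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sigmoid())

def step(x):
    loss = model(x).sum()
    loss.backward()  # backward pass that compiled autograd would otherwise capture
    grad = model[0].weight.grad
    model.zero_grad()
    return grad

# Build the compiled wrapper while compiled autograd is configured on.
with config.patch(compiled_autograd=True):
    compiled_step = torch.compile(step, backend="eager")

# After this change, an ambient disable context wins: the backward pass below
# runs on the regular autograd engine and no compiled_autograd graphs are captured.
with compiled_autograd._disable():
    compiled_step(torch.randn(2, 4))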
@@ -99,6 +99,7 @@ def reset():
     torch._logging.set_logs(compiled_autograd_verbose=False)
     config.compiled_autograd = False
     compiled_autograd.reset()
+    torch._dynamo.utils.counters.clear()
 
 
 class TestCompiledAutograd(TestCase):
@@ -706,6 +707,44 @@ main()
         self.assertEqual(expected, actual)
         self.assertEqual(counters["compiled_autograd"]["captures"], 2)
 
+    @parametrize("api", ("compile", "optimize"))
+    @parametrize("backend", ("eager", "aot_eager", "inductor"))
+    def test_compile_api_disable(self, api, backend):
+        def wrap(fn, backend):
+            if api == "compile":
+                return torch.compile(fn, backend=backend)
+            elif api == "optimize":
+                return torch._dynamo.optimize(backend)(fn)
+
+        def fn(model, inputs):
+            res = []
+            for inp in inputs:
+                result = model(inp).sum()
+                result.backward()
+                res.append(model[0].weight.grad)
+                res.append(model[0].bias.grad)
+                model.zero_grad()
+            return res
+
+        torch.manual_seed(123)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.Sigmoid(),
+        )
+        inputs = [
+            torch.randn([1, 4]),
+            torch.randn([2, 4]),
+            torch.randn([3, 4]),
+        ]
+
+        expected = fn(model, inputs)
+        with config.patch(compiled_autograd=True):
+            compiled_fn = wrap(fn, backend)
+        with torch._dynamo.compiled_autograd._disable():
+            actual = compiled_fn(model, inputs)
+        self.assertEqual(expected, actual)
+        self.assertTrue("compiled_autograd" not in counters)
+
     @parametrize("backend", ("eager", "aot_eager", "inductor"))
     def test_optimize_assert(self, backend):
         # can be merged into the test above once we support
@@ -1373,9 +1373,11 @@ compiled_autograd_enabled_force_eager = False
 # global flag to check if we are processing graphs produced from a compiled autograd graph
 in_compiled_autograd_region = False
 
+active_disable_ctx = False
+
 
 @contextlib.contextmanager
-def _enable(compiler_fn, dynamic: bool = True):
+def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True):
     # The entrypoint to enable CA.
     # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather
     # than using this context manager directly. If you are torch.compiling the corresponding
@@ -1396,6 +1398,9 @@ def _enable(compiler_fn, dynamic: bool = True):
     # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic.
     # This doesn't affect the dynamic configuration of the compilation wrapper.
 
+    if not ignore_active_disable_ctx and active_disable_ctx:
+        yield
+    else:
         if dynamic:
             assert type(dynamic) is bool
 
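The new ignore_active_disable_ctx argument defaults to True, so direct users of the _enable() context manager keep the old behavior; only callers that explicitly pass False (the torch.compile wrapper further below) short-circuit to a plain yield while an ambient disable context is active. A minimal sketch of the distinction (passing torch.compile as the compiler_fn is an illustrative assumption, not part of the commit):

import torch
from torch._dynamo import compiled_autograd

with compiled_autograd._disable():
    # Default (ignore_active_disable_ctx=True): compiled autograd is enabled
    # inside this block even though an ambient _disable() is active.
    with compiled_autograd._enable(torch.compile):
        pass
    # Opt-in (ignore_active_disable_ctx=False): the context manager just yields,
    # leaving compiled autograd disabled. This is the path torch.compile takes.
    with compiled_autograd._enable(torch.compile, ignore_active_disable_ctx=False):
        pass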
@@ -1444,11 +1449,15 @@ def _disable():
     ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
     global compiled_autograd_enabled
     compiled_autograd_enabled = False
+    global active_disable_ctx
+    if not active_disable_ctx:
+        active_disable_ctx = True
     try:
         yield
     finally:
         if prior_compiler:
             compiled_autograd_enabled = True
+        active_disable_ctx = False
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(
             prior_compiler, prior_dynamic
         )
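The module-level active_disable_ctx flag is what lets the compile wrapper observe an ambient disable context across function boundaries. A small sketch of the flag's behavior, assuming the reconstruction above (illustrative only):

from torch._dynamo import compiled_autograd

with compiled_autograd._disable():
    # The flag is a module-level global, so any _enable(..., ignore_active_disable_ctx=False)
    # entered while this block is active sees it and yields without installing a compiler.
    assert compiled_autograd.active_disable_ctx is True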
@@ -816,7 +816,7 @@ class OptimizeContext(_TorchDynamoContext):
                 assert rebuild_ctx is not None
                 compiler_fn = rebuild_ctx()
                 ctx = torch._dynamo.compiled_autograd._enable(
-                    compiler_fn, dynamic=_dynamic
+                    compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
                 )
                 ctx.__enter__()
                 return functools.partial(ctx.__exit__, None, None, None)