Disable autocast cache in torch.cuda.make_graphed_callables (#84289)

There are conflicts between `torch.clear_autocast_cache()` and `cudaMallocAsync` from #82682.
Moreover, autocast caching is of little use during training, which is the primary target of `make_graphed_callables`.
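
For illustration only (not part of this commit's diff), a minimal sketch of the intended capture pattern after this change; the `Linear` model, tensor shapes, and single-GPU setup are assumptions:

```python
# Sketch: capturing a graphed callable under autocast now requires the
# autocast cache to be disabled (cache_enabled=False). Assumes a CUDA build
# with graph support (CUDA >= 11); model and shapes are illustrative.
import torch

model = torch.nn.Linear(64, 64).cuda()
x = torch.randn(8, 64, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast(cache_enabled=False):
    # Warmup and graph capture happen inside make_graphed_callables.
    graphed_model = torch.cuda.make_graphed_callables(model, (x,))
    # Calling the returned callable replays the captured forward graph.
    y = graphed_model(x)
y.sum().backward()  # replays the captured backward graph
```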

cc @eqy @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84289
Approved by: https://github.com/ngimel
Aidyn-A
2022-09-01 21:34:51 +00:00
committed by PyTorch MergeBot
parent d39490a711
commit ce1b727e77
2 changed files with 10 additions and 8 deletions

test/test_cuda.py

@@ -27,7 +27,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, t
 from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
     NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
     slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY, \
-    get_cycles_per_ms, parametrize, instantiate_parametrized_tests
+    get_cycles_per_ms, parametrize, instantiate_parametrized_tests, subtest
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 # load_tests from common_utils is used to automatically filter tests for
@@ -3752,7 +3752,8 @@ torch.cuda.synchronize()
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
-    @parametrize('with_amp,cache_enabled', [(False, False), (True, False), (True, True)],
+    @parametrize('with_amp,cache_enabled', [(False, False), (True, False), subtest((True, True),
+                                            decorators=[unittest.expectedFailure])],
                  name_fn=lambda x, y: '{}{}'.format({True: "with_amp", False: "without_amp"}[x],
                                                     {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else ''))
     def test_graph_make_graphed_callables(self, with_amp, cache_enabled):

torch/cuda/graphs.py

@@ -223,9 +223,16 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3):
     When running a graphed callable, you must pass its arguments in the same order and format
     they appeared in that callable's ``sample_args``.
+    .. warning::
+        Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with
+        caching disabled: the context manager `torch.cuda.amp.autocast()` must use `cache_enabled=False`.
     .. warning::
         All Tensor outputs of graphed callables must require grad.
     """
+    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
+        raise RuntimeError("make_graphed_callables does not support autocast caching. Please set `cache_enabled=False`.")
     just_one_callable = False
     if not isinstance(callables, tuple):
@@ -281,9 +288,6 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3):
     # the safest approach is to capture all passes in the same order they'll run:
     # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
-    # Clear AMP autocast cache before capturing the graphs
-    torch.clear_autocast_cache()
     # Capture forward graphs
     per_callable_static_outputs = []
     per_callable_output_was_tensor = []
@@ -343,9 +347,6 @@
     per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
     # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
-    # Clear AMP autocast cache after both forward and backward graphs are captured
-    torch.clear_autocast_cache()
     def make_graphed_autograd_function(fwd_graph,
                                        bwd_graph,
                                        module_params,
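
For reference, a hedged sketch (not from the diff) of what the new guard in `make_graphed_callables` rejects: attempting capture while autocast is active with its cache enabled, which is the default, is expected to raise the `RuntimeError` added above. The model and shapes are illustrative.

```python
import torch

# Illustrative model and input; CUDA >= 11 with graph support is assumed.
model = torch.nn.Linear(64, 64).cuda()
x = torch.randn(8, 64, device="cuda", requires_grad=True)

try:
    with torch.cuda.amp.autocast():  # cache_enabled defaults to True
        torch.cuda.make_graphed_callables(model, (x,))
except RuntimeError as err:
    print(err)  # asks the caller to pass cache_enabled=False
```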