Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-11 22:34:53 +08:00)
Disable autocast cache in torch.cuda.make_graphed_callables (#84289)
There are conflicts between `torch.clear_autocast_cache()` and `cudaMallocAsync` from #82682. Moreover, autocast caching is not reasonable during training, which is the main target of `make_graphed_callables`. cc @eqy @ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84289
Approved by: https://github.com/ngimel
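For context, a minimal sketch (hypothetical module and shapes, assuming a CUDA-capable build, not code from this PR) of the usage pattern the change enforces: when mixed precision is involved, the callable is graphed and replayed under `torch.cuda.amp.autocast(cache_enabled=False)`.

import torch

# Minimal sketch: hypothetical module and shapes, CUDA required.
model = torch.nn.Linear(64, 64).cuda()
sample = torch.randn(8, 64, device="cuda", requires_grad=True)

# Capture and replay under autocast with the cache disabled; with the default
# cache_enabled=True, make_graphed_callables now raises a RuntimeError instead.
with torch.cuda.amp.autocast(cache_enabled=False):
    graphed_model = torch.cuda.make_graphed_callables(model, (sample,))
    out = graphed_model(sample)
    out.sum().backward()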
Committed by: PyTorch MergeBot
Parent: d39490a711
Commit: ce1b727e77
test/test_cuda.py
@@ -27,7 +27,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, t
 from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
     NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
     slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY, \
-    get_cycles_per_ms, parametrize, instantiate_parametrized_tests
+    get_cycles_per_ms, parametrize, instantiate_parametrized_tests, subtest
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -3752,7 +3752,8 @@ torch.cuda.synchronize()
     @unittest.skipIf((not TEST_CUDA) or
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
-    @parametrize('with_amp,cache_enabled', [(False, False), (True, False), (True, True)],
+    @parametrize('with_amp,cache_enabled', [(False, False), (True, False), subtest((True, True),
+                                                                                   decorators=[unittest.expectedFailure])],
                  name_fn=lambda x, y: '{}{}'.format({True: "with_amp", False: "without_amp"}[x],
                                                     {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else ''))
     def test_graph_make_graphed_callables(self, with_amp, cache_enabled):
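The `subtest` helper used above lets a single parameter combination carry its own decorators. A standalone sketch of the same pattern (toy test and names invented for illustration, not part of this PR):

import unittest
from torch.testing._internal.common_utils import (
    TestCase, parametrize, subtest, instantiate_parametrized_tests, run_tests)

# Toy illustration: the `True` case is wrapped in subtest() so it can be marked as an
# expected failure while the other case still runs (and must pass) normally.
class TestSubtestPattern(TestCase):
    @parametrize('flag', [False, subtest(True, decorators=[unittest.expectedFailure])])
    def test_flag(self, flag):
        self.assertFalse(flag)

instantiate_parametrized_tests(TestSubtestPattern)

if __name__ == '__main__':
    run_tests()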
torch/cuda/graphs.py
@@ -223,9 +223,16 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3):
     When running a graphed callable, you must pass its arguments in the same order and format
     they appeared in that callable's ``sample_args``.
 
+    .. warning::
+        The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
+        caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
+
     .. warning::
         All Tensor outputs of graphed callables must require grad.
     """
+    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
+        raise RuntimeError("make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`.")
+
     just_one_callable = False
 
     if not isinstance(callables, tuple):
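A sketch of how the new guard behaves at call time (hypothetical module and shapes, CUDA required): with the autocast cache left at its default, the capture fails fast instead of producing a graph that later conflicts with `torch.clear_autocast_cache()`.

import torch

# Hypothetical example of the failure path added above; not code from this PR.
model = torch.nn.Linear(16, 16).cuda()
x = torch.randn(2, 16, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast():  # cache_enabled defaults to True
    assert torch.is_autocast_enabled() and torch.is_autocast_cache_enabled()
    try:
        torch.cuda.make_graphed_callables(model, (x,))
    except RuntimeError as err:
        print(err)  # "make_graphed_callables does not support the autocast caching. ..."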
@@ -281,9 +288,6 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3):
     # the safest approach is to capture all passes in the same order they'll run:
     # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
 
-    # Clear AMP autocast cache before capturing the graphs
-    torch.clear_autocast_cache()
-
     # Capture forward graphs
     per_callable_static_outputs = []
     per_callable_output_was_tensor = []
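The capture-order comment above matters when several callables are graphed in one call; a sketch (hypothetical modules and shapes, not from this PR) of that multi-callable form:

import torch

# Two toy callables graphed together: their forward graphs are captured in order 1..N,
# their backward graphs in reverse order N..1, matching how they run during training.
a = torch.nn.Linear(32, 32).cuda()
b = torch.nn.Linear(32, 32).cuda()
xa = torch.randn(4, 32, device="cuda", requires_grad=True)
xb = torch.randn(4, 32, device="cuda", requires_grad=True)

graphed_a, graphed_b = torch.cuda.make_graphed_callables((a, b), ((xa,), (xb,)))

out = graphed_b(graphed_a(xa))
out.sum().backward()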
@@ -343,9 +347,6 @@ def make_graphed_callables(callables, sample_args, num_warmup_iters=3):
     per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
     # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
 
-    # Clear AMP autocast cache after both forward and backward graphs are captured
-    torch.clear_autocast_cache()
-
     def make_graphed_autograd_function(fwd_graph,
                                        bwd_graph,
                                        module_params,