Disable autocast in aot autograd (#86515)

Fix for https://github.com/pytorch/torchdynamo/issues/1368

From the comment:
> When we invoke a CompositeImplicitAutograd operator that has an autocast rule, such as Einsum,
> autocast is disabled during its invocation. When we trace out the operators inside such an op,
> re-applying autocast rules to those operators might yield a divergence from what was executed at runtime.
> This pass checks for divergence; if divergence is found, we disable autocast.
> We would like to avoid disabling autocast if possible because accessing TLS is slow.

Concretely, the problem showed up when `einsum` invoked `sum`, as seen in the following divergence:
```
>>> with torch.cuda.amp.autocast(enabled=True):
...     print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
...
torch.float32
>>> print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
torch.float16
```

Edit: we've decided to accept the overhead of universally disabling autocast instead of checking for divergence.
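
For reference, the shape of that approach looks roughly like the following sketch. `disable_autocast_manager` mirrors the helper added in `aot_autograd.py` in this PR, and `torch._C._DisableAutocast` / `torch._C._is_any_autocast_enabled` are the bindings added below; `compile_with_amp_disabled` is a hypothetical wrapper used only for illustration.
```
from contextlib import contextmanager, nullcontext

import torch


@contextmanager
def disable_autocast_manager():
    # Exclude the autocast dispatch keys so the traced graph runs exactly the
    # ops that were captured, instead of having autocast rules re-applied.
    guard = torch._C._DisableAutocast()
    try:
        yield
    finally:
        del guard


def compile_with_amp_disabled(compiler, fx_module, example_inputs):
    # Only pay for the TLS access when some autocast mode is actually enabled.
    disable_amp = torch._C._is_any_autocast_enabled()
    context = disable_autocast_manager if disable_amp else nullcontext
    with context():
        return compiler(fx_module, example_inputs)
```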
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86515
Approved by: https://github.com/bdhirsh, https://github.com/Chillee
Author: Elias Ellison
Date: 2022-10-11 01:24:48 +00:00
Committed by: PyTorch MergeBot
Parent: d598290baa
Commit: c4f0b93f86

3 changed files with 86 additions and 20 deletions


@@ -243,22 +243,28 @@ def make_boxed_compiler(compiler):
return f
def call_func_with_args(f, args, steal_args=False):
def call_func_with_args(f, args, steal_args=False, disable_amp=False):
if not steal_args:
args = list(args)
assert isinstance(args, list)
if hasattr(f, "_boxed_call"):
out = normalize_as_list(f(args))
else:
# TODO: Please remove soon
# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
warnings.warn(
"Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
)
out = normalize_as_list(f(*args))
if disable_amp:
guard = torch._C._DisableAutocast()
try:
if hasattr(f, "_boxed_call"):
out = normalize_as_list(f(args))
else:
# TODO: Please remove soon
# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
warnings.warn(
"Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
)
out = normalize_as_list(f(*args))
finally:
if disable_amp:
del guard
return out
@@ -279,17 +285,31 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
if config.debug_graphs:
print("====== Forward (only) graph ======")
fw_module.print_readable()
with track_graph_compiling("inference"):
disable_amp = torch._C._is_any_autocast_enabled()
context = disable_autocast_manager if disable_amp else nullcontext
with context(), track_graph_compiling("inference"):
compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
@wraps(compiled_fw)
def new_fn(args):
fw_outs = call_func_with_args(compiled_fw, args)
fw_outs = call_func_with_args(compiled_fw, args, disable_amp=disable_amp)
return fw_outs
return new_fn
@contextmanager
def disable_autocast_manager():
guard = torch._C._DisableAutocast()
try:
yield
finally:
del guard
def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
# Deduplicate inputs. Suppose you have:
#
@@ -360,12 +380,16 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
joint_inputs = (deduped_flat_args, out)
disable_amp = torch._C._is_any_autocast_enabled()
if config.use_functionalize:
# Trace once without decompositions, into a graph of ATen ops.
# NB: tracing_mode is real, as it's assumed the calling context setup
# fake tensor mode / symbolic shapes if that is needed
fx_g = make_fx(joint_forward_backward)(*joint_inputs)
context = disable_autocast_manager if disable_amp else nullcontext
def fake_fn(primals, tangents):
with torch.fx.traceback.override_stack_trace():
return torch.fx.Interpreter(fx_g).run(primals, tangents)
@@ -375,7 +399,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
# view and inplace ops that come from primtorch.
# Eventually, functionalization should support primtorch view/inplace ops,
# which will make it ok to run decompositions before functionalization.
fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
with context():
fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
else:
@@ -414,7 +439,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
@disable_torchdynamo
def forward(ctx, *deduped_flat_tensor_args):
fw_outs = call_func_with_args(
CompiledFunction.compiled_fw, deduped_flat_tensor_args
CompiledFunction.compiled_fw, deduped_flat_tensor_args, disable_amp=disable_amp
)
num_outs = CompiledFunction.num_outs
num_symints = CompiledFunction.num_symints
@@ -433,15 +458,15 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
contiguous_args = [t.contiguous() if torch.is_tensor(t) else t for t in flat_args]
all_args = list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args)
if CompiledFunction.compiled_bw is None:
with track_graph_compiling("backward", True):
context = disable_autocast_manager if disable_amp else nullcontext
with context(), track_graph_compiling("backward", True):
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, all_args
)
ctx.maybe_clear_saved_tensors()
out = call_func_with_args(
CompiledFunction.compiled_bw, all_args, steal_args=True
CompiledFunction.compiled_bw, all_args, steal_args=True, disable_amp=disable_amp
)
return tuple(out)
@wraps(CompiledFunction.apply)


@@ -402,6 +402,31 @@ class TestAOTAutograd(AOTTestCase):
self.assertEqual(ref_out, test_out)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_autocast_disable_guard(self):
guard = torch._C._DisableAutocast()
try:
x = torch.rand([4, 4]).cuda()
y = x @ x
self.assertEqual(y.dtype, torch.float32)
finally:
del guard
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_nonidempotent_amp(self):
def f(self_s_emb, add_3):
einsum_2 = torch.functional.einsum('ah,th->t', self_s_emb, add_3)
log_softmax_2 = einsum_2.log_softmax(-1)
return (log_softmax_2,)
args = [torch.rand((1, 256), dtype=torch.float32, device='cuda'), torch.rand((30, 256), dtype=torch.float16, device='cuda')]
with torch.cuda.amp.autocast(enabled=True):
self.verify_aot_autograd(f, args)
args = [e.requires_grad_(True) for e in args]
with torch.cuda.amp.autocast(enabled=True):
self.verify_aot_autograd(f, args)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_batch_norm_amp(self):
device = "cuda"


@@ -54,6 +54,10 @@ struct MultithreadingEnabled {
bool old_;
};
struct DisableAutocast {
c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
};
struct EnableTorchFunction {
EnableTorchFunction()
: old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {
@@ -367,7 +371,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
.def(py::init<bool>());
py::class_<DisableAutocast>(_C_m, "_DisableAutocast").def(py::init<>());
py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
.def(py::init([]() -> torch::autograd::SavedVariable {
TORCH_CHECK(
@@ -413,6 +417,17 @@ static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() ||
at::autocast::is_xpu_enabled()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (!PyBool_Check(arg)) {
@@ -757,6 +772,7 @@ static PyMethodDef methods[] = { // NOLINT
nullptr},
{"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
{"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
{"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
{"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
{"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
{"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},