Disable autocast in aot autograd (#86515)
Fix for https://github.com/pytorch/torchdynamo/issues/1368.

From the comment:

> When we invoke a CompositeImplicitAutograd operator that has an autocast rule, such as einsum, autocast is disabled during its invocation. When we trace out the operators in such an implicit op, re-applying autocast rules to those operators might yield divergence from what was executed at runtime. This pass checks for divergence; if divergence is found, we disable autocast. We would like to avoid disabling autocast if possible, because accessing TLS is slow.

Concretely, the problem was found when `einsum` invoked `sum`, as seen in the following divergence:

```
>>> with torch.cuda.amp.autocast(enabled=True):
...     print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
...
torch.float32
>>> print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
torch.float16
```

Edit: we've decided to accept the overhead of universally disabling autocast instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86515
Approved by: https://github.com/bdhirsh, https://github.com/Chillee
committed by PyTorch MergeBot
parent d598290baa
commit c4f0b93f86
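For a concrete sense of the failure mode described above, here is a rough end-to-end sketch in the spirit of the `test_nonidempotent_amp` test added later in this diff. It assumes a CUDA device and the functorch frontend of this era (`functorch.compile.aot_function` with the `nop` compiler); treat it as an illustration of the behavior, not part of the change itself.

```python
import torch
from functorch.compile import aot_function, nop  # assumed importable in this functorch version

def f(self_s_emb, add_3):
    # einsum is CompositeImplicitAutograd: under autocast its matmul runs in fp16, but its
    # decomposition contains aten.sum, which autocast would push back to fp32 if the rules
    # were re-applied to the already-traced graph.
    out = torch.einsum('ah,th->t', self_s_emb, add_3)
    return out.log_softmax(-1)

a = torch.rand(1, 256, device="cuda", dtype=torch.float32)
b = torch.rand(30, 256, device="cuda", dtype=torch.float16)

with torch.cuda.amp.autocast(enabled=True):
    eager = f(a, b)
    compiled = aot_function(f, fw_compiler=nop)(a, b)

# With autocast excluded while the compiled graph runs, both paths should agree.
print(eager.dtype, compiled.dtype)
torch.testing.assert_close(eager, compiled)
```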
```diff
@@ -243,22 +243,28 @@ def make_boxed_compiler(compiler):
     return f


-def call_func_with_args(f, args, steal_args=False):
+def call_func_with_args(f, args, steal_args=False, disable_amp=False):
     if not steal_args:
         args = list(args)
     assert isinstance(args, list)

-    if hasattr(f, "_boxed_call"):
-        out = normalize_as_list(f(args))
-    else:
-        # TODO: Please remove soon
-        # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
-        warnings.warn(
-            "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
-            "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
-            "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
-        )
-        out = normalize_as_list(f(*args))
+    if disable_amp:
+        guard = torch._C._DisableAutocast()
+    try:
+        if hasattr(f, "_boxed_call"):
+            out = normalize_as_list(f(args))
+        else:
+            # TODO: Please remove soon
+            # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
+            warnings.warn(
+                "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
+                "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
+                "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
+            )
+            out = normalize_as_list(f(*args))
+    finally:
+        if disable_amp:
+            del guard
     return out

```
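A note on the guard object used above: `torch._C._DisableAutocast` (bound further down in this diff) excludes every autocast dispatch key for as long as the Python object stays alive, so creating it and later `del`-ing it brackets a region in which autocast is ignored. A minimal sketch of that lifetime semantic, using CPU autocast only so it runs without a GPU (the added `test_autocast_disable_guard` exercises the same guard under CUDA):

```python
import torch

a, b = torch.rand(8, 8), torch.rand(8, 8)

with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    print(torch.mm(a, b).dtype)          # torch.bfloat16: autocast rewrites mm

    guard = torch._C._DisableAutocast()  # autocast keys excluded while `guard` is alive
    try:
        print(torch.mm(a, b).dtype)      # torch.float32: runs as if autocast were off
    finally:
        del guard                        # dropping the guard restores autocast
```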
```diff
@@ -279,17 +285,31 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
     if config.debug_graphs:
         print("====== Forward (only) graph ======")
         fw_module.print_readable()
-    with track_graph_compiling("inference"):
+
+    disable_amp = torch._C._is_any_autocast_enabled()
+    context = disable_autocast_manager if disable_amp else nullcontext
+
+    with context(), track_graph_compiling("inference"):
         compiled_fw = aot_config.fw_compiler(fw_module, flat_args)

     @wraps(compiled_fw)
     def new_fn(args):
-        fw_outs = call_func_with_args(compiled_fw, args)
+        fw_outs = call_func_with_args(compiled_fw, args, disable_amp=disable_amp)
         return fw_outs

     return new_fn


+@contextmanager
+def disable_autocast_manager():
+    guard = torch._C._DisableAutocast()
+    try:
+        yield
+    finally:
+        del guard
+
+
 def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
     # Deduplicate inputs. Suppose you have:
     #
```
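The `disable_amp` flag above is computed once, at compile time, via `torch._C._is_any_autocast_enabled()`, which this PR also adds: it reports whether the CUDA, CPU, or XPU autocast state is currently on, so the TLS guard is only entered when some autocast region is actually active. A quick sketch of its behavior (assuming a build that already contains this commit):

```python
import torch

print(torch._C._is_any_autocast_enabled())      # False outside any autocast region

with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    print(torch._C._is_any_autocast_enabled())  # True: CPU autocast counts too

print(torch._C._is_any_autocast_enabled())      # False again after exiting
```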
```diff
@@ -360,12 +380,16 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):

     joint_inputs = (deduped_flat_args, out)

+    disable_amp = torch._C._is_any_autocast_enabled()
+
     if config.use_functionalize:
         # Trace once without decompositions, into a graph of ATen ops.
         # NB: tracing_mode is real, as it's assumed the calling context setup
         # fake tensor mode / symbolic shapes if that is needed
         fx_g = make_fx(joint_forward_backward)(*joint_inputs)

+        context = disable_autocast_manager if disable_amp else nullcontext
+
         def fake_fn(primals, tangents):
             with torch.fx.traceback.override_stack_trace():
                 return torch.fx.Interpreter(fx_g).run(primals, tangents)
@@ -375,7 +399,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
         # view and inplace ops that come from primtorch.
         # Eventually, functionalization should support primtorch view/inplace ops,
         # which will make it ok to run decompositions before functionalization.
-        fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
+        with context():
+            fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
         fx_g.graph.eliminate_dead_code()
         fx_g.recompile()
     else:
@@ -414,7 +439,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
         @disable_torchdynamo
         def forward(ctx, *deduped_flat_tensor_args):
             fw_outs = call_func_with_args(
-                CompiledFunction.compiled_fw, deduped_flat_tensor_args
+                CompiledFunction.compiled_fw, deduped_flat_tensor_args, disable_amp=disable_amp
             )
             num_outs = CompiledFunction.num_outs
             num_symints = CompiledFunction.num_symints
@@ -433,15 +458,15 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
             contiguous_args = [t.contiguous() if torch.is_tensor(t) else t for t in flat_args]
             all_args = list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args)
             if CompiledFunction.compiled_bw is None:
-                with track_graph_compiling("backward", True):
+                context = disable_autocast_manager if disable_amp else nullcontext
+                with context(), track_graph_compiling("backward", True):
                     CompiledFunction.compiled_bw = aot_config.bw_compiler(
                         bw_module, all_args
                     )
             ctx.maybe_clear_saved_tensors()
             out = call_func_with_args(
-                CompiledFunction.compiled_bw, all_args, steal_args=True
+                CompiledFunction.compiled_bw, all_args, steal_args=True, disable_amp=disable_amp
             )

             return tuple(out)

     @wraps(CompiledFunction.apply)
```
```diff
@@ -402,6 +402,31 @@ class TestAOTAutograd(AOTTestCase):

         self.assertEqual(ref_out, test_out)

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    def test_autocast_disable_guard(self):
+        guard = torch._C._DisableAutocast()
+        try:
+            x = torch.rand([4, 4]).cuda()
+            y = x @ x
+            self.assertEqual(y.dtype, torch.float32)
+        finally:
+            del guard
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    def test_nonidempotent_amp(self):
+        def f(self_s_emb, add_3):
+            einsum_2 = torch.functional.einsum('ah,th->t', self_s_emb, add_3)
+            log_softmax_2 = einsum_2.log_softmax(-1)
+            return (log_softmax_2,)
+
+        args = [torch.rand((1, 256), dtype=torch.float32, device='cuda'), torch.rand((30, 256), dtype=torch.float16, device='cuda')]
+        with torch.cuda.amp.autocast(enabled=True):
+            self.verify_aot_autograd(f, args)
+
+        args = [e.requires_grad_(True) for e in args]
+        with torch.cuda.amp.autocast(enabled=True):
+            self.verify_aot_autograd(f, args)
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_batch_norm_amp(self):
         device = "cuda"
```
```diff
@@ -54,6 +54,10 @@ struct MultithreadingEnabled {
   bool old_;
 };

+struct DisableAutocast {
+  c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
+};
+
 struct EnableTorchFunction {
   EnableTorchFunction()
       : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {
@@ -367,7 +371,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
   py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
   py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
       .def(py::init<bool>());

+  py::class_<DisableAutocast>(_C_m, "_DisableAutocast").def(py::init<>());
   py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
       .def(py::init([]() -> torch::autograd::SavedVariable {
         TORCH_CHECK(
@@ -413,6 +417,17 @@ static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() ||
+      at::autocast::is_xpu_enabled()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   if (!PyBool_Check(arg)) {
@@ -757,6 +772,7 @@ static PyMethodDef methods[] = { // NOLINT
      nullptr},
     {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
     {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
+    {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
     {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
     {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
     {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
```