Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
make torch.amp.autocast more generic (#125103)
# Motivation

As discussed in [#124479](https://github.com/pytorch/pytorch/pull/124479), `torch.amp.autocast` cannot be fully equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, because `torch.amp.autocast` does not pick up the per-backend default `dtype` (`torch.bfloat16` for CPU and `torch.float16` for CUDA). We would like `torch.amp.autocast` to be generic enough that developers can write device-agnostic code, since there is little justification for adding a device-specific `torch.xxx.amp.autocast` for every backend.

# Solution

When `None` is passed as `dtype`, use `torch.get_autocast_dtype` to obtain the default dtype for the given backend. `torch.get_autocast_dtype` also needs to be supported on the JIT path for backward compatibility.

# Additional Context

With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`. Two new unit tests cover this change on the eager and JIT paths, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
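A minimal sketch of the device-agnostic usage this enables (the tensors and the device selection below are illustrative, not part of the PR):

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# With dtype=None (the new default), autocast resolves the backend default via
# torch.get_autocast_dtype: torch.float16 on CUDA, torch.bfloat16 on CPU.
with torch.amp.autocast(device_type=device_type):
    a = torch.randn(8, 8, device=device_type)
    b = torch.randn(8, 8, device=device_type)
    out = torch.mm(a, b)  # runs in the backend's default autocast dtype
```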
Committed by PyTorch MergeBot
Parent: 320af5eaa6
Commit: d17be10df1
@@ -227,6 +227,7 @@ namespace c10 {
   _(aten, is_autocast_enabled) \
   _(aten, is_autocast_cpu_enabled) \
   _(aten, is_autocast_xla_enabled) \
+  _(aten, get_autocast_dtype) \
   FORALL_ATEN_BASE_SYMBOLS(_) \
   _(onnx, Add) \
   _(onnx, Concat) \
@@ -2086,7 +2086,6 @@ class BenchmarkRunner:

         devices = [current_device] if current_device else self.args.devices
         if self.args.amp:
-            if devices == ["cuda"]:
                 # AMP training can lead to small loss values which can underflow
                 # gradient values returning in zero gradients. To solve this
                 # problem, PyTorch introduces GradScaler. GradScaler is a stateful
@@ -2108,9 +2107,9 @@ class BenchmarkRunner:
                 # factor between eager and dynamo run, making accuracy check
                 # harder.
                 # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-                self.autocast = torch.cuda.amp.autocast
-            if devices == ["cpu"]:
-                self.autocast = torch.cpu.amp.autocast
+            self.autocast = functools.partial(
+                torch.amp.autocast, device_type=devices[0]
+            )
             if self.args.amp_dtype:
                 amp_dtype = (
                     torch.float16
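As an aside, a tiny sketch of the `functools.partial` pattern the benchmark runner now uses to stay device-agnostic (the `"cpu"` device and tensor shapes here are illustrative):

```python
import functools

import torch

# Bind the device once; callers then enter autocast without repeating device_type.
make_autocast = functools.partial(torch.amp.autocast, device_type="cpu")

with make_autocast():  # dtype=None -> backend default (torch.bfloat16 on CPU)
    x = torch.randn(4, 4)
    y = torch.mm(x, x)
```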
@@ -244,6 +244,15 @@ class TestAutocastCPU(TestCase):
         with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
             _ = torch.ones(10)

+    def test_generic_autocast(self):
+        for op_with_args in self.autocast_lists.torch_16:
+            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
+            with torch.amp.autocast(device_type="cpu"):
+                generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
+            with torch.cpu.amp.autocast():
+                cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
+            self.assertEqual(generic_autocast_output, cpu_autocast_output)
+

 class CustomLinear(torch.autograd.Function):
     @staticmethod
@@ -33,6 +33,23 @@ class TestAutocast(JitTestCase):
         torch._C._jit_set_autocast_mode(self.old_value)
         super().tearDown()

+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_jit_generic_autocast(self):
+        @torch.jit.script
+        def fn_cuda_autocast(a, b):
+            with autocast():
+                x = torch.mm(a, b)
+                y = torch.sum(x)
+                return x, y
+
+        @torch.jit.script
+        def fn_generic_autocast(a, b):
+            with torch.amp.autocast(device_type='cuda'):
+                x = torch.mm(a, b)
+                y = torch.sum(x)
+                return x, y
+        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))
+
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_minimal(self):
         @torch.jit.script
@@ -599,28 +599,20 @@ class OutputGraph:
         )
         global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

-        def autocast_specific_backend(
-            device_type: str, func: Callable[[str, Any], None]
-        ):
-            def decorator(value):
-                return func(device_type, value)
-
-            return decorator
-
         global_state["autocast_enabled"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cuda"),
             torch.is_autocast_enabled("cuda"),
         )
         global_state["autocast_cpu_enabled"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cpu"),
             torch.is_autocast_enabled("cpu"),
         )
         global_state["autocast_gpu_dtype"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cuda"),
             torch.get_autocast_dtype("cuda"),
         )
         global_state["autocast_cpu_dtype"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cpu"),
             torch.get_autocast_dtype("cpu"),
         )
         global_state["autocast_cache_enabled"] = (
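For context, a minimal sketch showing that the removed closure helper and `functools.partial` bind the device argument the same way (the `"cuda"` literal simply mirrors the code above):

```python
import functools
from typing import Any, Callable

import torch


def autocast_specific_backend(device_type: str, func: Callable[[str, Any], None]):
    # The hand-rolled helper removed above: pre-binds the device argument.
    def decorator(value):
        return func(device_type, value)

    return decorator


# Both callables forward `value` as torch.set_autocast_enabled("cuda", value).
set_via_closure = autocast_specific_backend("cuda", torch.set_autocast_enabled)
set_via_partial = functools.partial(torch.set_autocast_enabled, "cuda")

set_via_partial(torch.is_autocast_enabled("cuda"))  # no-op: restores current state
```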
@@ -191,7 +191,10 @@ class autocast:
         Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
         enabled(bool, optional): Whether autocasting should be enabled in the region.
             Default: ``True``
-        dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16.
+        dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value
+            (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
+            :func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``.
+            Default: ``None``
         cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled.
             Default: ``True``
     """
@@ -207,11 +210,12 @@ class autocast:
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
+        if dtype is None:
+            dtype = torch.get_autocast_dtype(device_type)
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
             self.fast_dtype = dtype
-            # TODO: support get_autocast_gpu/cpu_dtype
             assert dtype is not None
             return
         self.device = device_type
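For reference, a quick illustration of the resolution step added above (the values shown are the stock per-backend defaults):

```python
import torch

# dtype=None in torch.amp.autocast resolves through this query.
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16 by default
print(torch.get_autocast_dtype("cuda"))  # torch.float16 by default
```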
@@ -96,17 +96,19 @@ c10::optional<AutocastScope> parseAutocast(
           use.user->s(attr::name) == "fast_dtype") {
         // Search for `prim::SetAttr[name="fast_dtype"]`
         auto ret = constant_as<c10::ScalarType>(use.user->input(1));
-        TORCH_CHECK(
-            ret.has_value() && ret.value() != c10::ScalarType::Undefined,
-            "Autocast dtype argument must be a constant and defined");
+        if (ret.has_value()) {
+          dtype = ret.value();
+        }
       }
     }
     TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
+    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
+    if (dtype == c10::ScalarType::Undefined) {
+      dtype = at::autocast::get_autocast_dtype(c10::Device(device).type());
+    }
     TORCH_CHECK(
         dtype != c10::ScalarType::Undefined,
-        "Autocast missing fast_dtype attribute");
-    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
+        "Autocast has invalid fast_dtype attribute");
     if (device == "cuda") {
       scope.context.gpu_enabled = enabled.value();
       scope.context.gpu_scalar_type = dtype;
@@ -815,6 +815,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
           push(stack, enabled);
         },
         aliasAnalysisConservative()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::get_autocast_dtype(str device_type) -> ScalarType"),
+        [](Stack& stack) {
+#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
+          // autocast is not supported.
+          at::ScalarType dtype = at::ScalarType::Undefined;
+#else
+          at::DeviceType device_type =
+              at::Device(pop(stack).toStringRef()).type();
+          at::ScalarType dtype = at::autocast::get_autocast_dtype(device_type);
+#endif
+          push(stack, dtype);
+        },
+        aliasAnalysisConservative()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
         unInitialized,
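A short sketch of what registering this operator enables from TorchScript (assuming a regular, non-lite build where autocast support is compiled in; the function name is illustrative):

```python
import torch

@torch.jit.script
def autocast_dtype_for(device_type: str):
    # Lowered to the aten::get_autocast_dtype operator registered above.
    return torch.get_autocast_dtype(device_type)

print(autocast_dtype_for("cpu"))  # torch.bfloat16 unless overridden
```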
@@ -288,11 +288,10 @@ class CheckpointFunction(torch.autograd.Function):
                 set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
             detached_inputs = detach_variable(tuple(inputs))

-            device_autocast_ctx = device_module.amp.autocast(
-                **ctx.device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=ctx.device, **ctx.device_autocast_kwargs
             ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
-            with torch.enable_grad(), device_autocast_ctx, \
-                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+            with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                 outputs = ctx.run_function(*detached_inputs)

             if isinstance(outputs, torch.Tensor):
@@ -1395,11 +1394,10 @@ def _checkpoint_without_reentrant_generator(
            if had_device_in_fwd:
                set_device_states(fwd_devices, fwd_device_states)

-            device_autocast_ctx = device_module.amp.autocast(
-                **device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=device, **device_autocast_kwargs
             ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
-            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-                 recompute_context:
+            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                 fn(*args, **kwargs)

         new_frame = _CheckpointFrame(