This PR relaxes two restrictions on `torch.autocast` in the DeepSpeed engine:

1) Nesting `torch.autocast`

Currently, we do not expect `torch.autocast` to be used outside the DeepSpeed engine. Here is the current behavior:

- If `torch.autocast` is enabled in the DeepSpeed config and the engine detects that it is also enabled outside, a warning is displayed.
- If it is disabled in the config, the engine raises an error.

This design prevents the following usage:

```python
with torch.autocast(...):
    logits = deepspeed_model(...)
    loss = criteria_fn(logits)
```

In this case, we also want to apply autocast to `criteria_fn`. With the current behavior, we would need to move `deepspeed_model(...)` outside the `torch.autocast` context, leading to inconsistent code between DeepSpeed and non-DeepSpeed setups. (This cannot be handled with the `enabled` argument of `torch.autocast`.)

Change in this PR: `torch.autocast` outside the DeepSpeed engine is now ignored, and

- If `torch_autocast` is enabled in the config, DeepSpeed follows that setting.
- If it is disabled, DeepSpeed falls back to its own mixed-precision support (or FP32).

In both cases, the DeepSpeed engine logs a message explaining the behavior (see the sketch at the end of this description).

2) Model's dtype

Previously, DeepSpeed assumed the model's dtype must be FP32 when `torch.autocast` was enabled. However, models with lower-precision parameters (e.g., BF16) can also be used with autocast. For example, if both the model and `torch.autocast` use BF16, autocast will still upcast precision-sensitive ops as needed.

Change in this PR: Removed the assertion that restricted the model's dtype to FP32.

This PR also adds and updates tests to cover these new behaviors.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
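
Below is a minimal sketch of the usage this PR enables. It is illustrative rather than taken from the PR: the config fields other than `torch_autocast.enabled`, the toy model, and the dummy data are assumptions.

```python
import torch
import deepspeed

# Plain FP32 module; with this PR, lower-precision (e.g. BF16) params are also accepted.
model = torch.nn.Linear(16, 4)

# Illustrative config; only the "torch_autocast" section is relevant to this PR.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "torch_autocast": {"enabled": True},  # the engine follows this setting
}

engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

inputs = torch.randn(8, 16, device=engine.device)
targets = torch.randint(0, 4, (8,), device=engine.device)
criteria_fn = torch.nn.CrossEntropyLoss()

# The outer autocast context is no longer rejected: the engine ignores it for its
# own forward pass (applying autocast per its config), while the loss computation
# outside the engine still runs under the outer context.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = engine(inputs)
    loss = criteria_fn(logits, targets)

engine.backward(loss)
engine.step()
```

This mirrors the pattern from the description: the model's forward pass and `criteria_fn` sit inside a single `torch.autocast` context, so DeepSpeed and non-DeepSpeed code paths can stay identical.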