With https://github.com/pytorch/xla/pull/5148 and https://github.com/pytorch/xla/pull/4740, XLA:GPU users should use `torch.cuda.amp.autocast()` for AMP with float16, and XLA:TPU users should use `torch.amp.autocast('xla')` for AMP with bfloat16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96370
Approved by: https://github.com/bdhirsh, https://github.com/malfet
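A minimal sketch of the TPU-side usage described above, assuming `torch_xla` is installed and an XLA device is available; the model, tensors, and loss here are hypothetical placeholders, not from the PR:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 2).to(device)
inputs = torch.randn(4, 8, device=device)
targets = torch.randn(4, 2, device=device)
loss_fn = torch.nn.MSELoss()

# XLA:TPU path: autocast to bfloat16 via the 'xla' device type.
with torch.amp.autocast('xla', dtype=torch.bfloat16):
    loss = loss_fn(model(inputs), targets)

# XLA:GPU path: use the CUDA autocast API for float16 instead:
# with torch.cuda.amp.autocast():
#     loss = loss_fn(model(inputs), targets)
```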