`torch.autocast` with the `xla` backend has been restricted to `torch.bfloat16`. This restriction is no longer necessary, because `float16` already works through `xla::cast(..., type=f16)`, as the following IR shows:

```
IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %1 = f32[3,2]{1,0} aten::expand(%0), xla_shape=f32[3,2]{1,0}, size=(3, 2), dynamic_dims=(0, 0)
  %2 = f16[3,2]{1,0} xla::cast(%1), xla_shape=f16[3,2]{1,0}, type=f16, dtype=Half, stype=Float
  %3 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %4 = f32[2,3]{1,0} aten::expand(%3), xla_shape=f32[2,3]{1,0}, size=(2, 3), dynamic_dims=(0, 0)
  %5 = f16[2,3]{1,0} xla::cast(%4), xla_shape=f16[2,3]{1,0}, type=f16, dtype=Half, stype=Float
  %6 = f16[2,2]{1,0} aten::mm(%5, %2), xla_shape=f16[2,2]{1,0}, ROOT=0
}
```

This will allow PyTorch/XLA to extend its autocast implementation to use the `xla` backend for the `float16` type as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109554
Approved by: https://github.com/JackCaoG, https://github.com/bdhirsh
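To make the IR above concrete, here is a minimal pure-Python reference model of what the graph computes: two scalar constants are expanded to `f32[3,2]` and `f32[2,3]`, cast to `f16`, and multiplied with an `f16` `aten::mm`. The helper names (`to_f16`, `expand`, `cast_f16`, `mm_f16`) are hypothetical and are not part of PyTorch or PyTorch/XLA. `f16` rounding is emulated with the standard library's IEEE 754 binary16 `struct` format (`"e"`). Rounding only the final dot product to `f16` is an assumption about accumulation precision, not a description of what XLA does internally.

```python
from __future__ import annotations

import struct


def to_f16(x: float) -> float:
    """Round a float to the nearest IEEE 754 binary16 value (models xla::cast to f16)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def expand(value: float, rows: int, cols: int) -> list[list[float]]:
    """Broadcast a scalar to a rows x cols matrix (models aten::expand)."""
    return [[value] * cols for _ in range(rows)]


def cast_f16(m: list[list[float]]) -> list[list[float]]:
    """Round every element of a matrix to f16."""
    return [[to_f16(v) for v in row] for row in m]


def mm_f16(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Matrix multiply, rounding each output element to f16 (models an f16 aten::mm).

    Assumption: products are accumulated in Python float precision and only the
    final sum is rounded to f16.
    """
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    cols = len(b[0])
    return [
        [to_f16(sum(a[i][k] * b[k][j] for k in range(inner))) for j in range(cols)]
        for i in range(len(a))
    ]


# Mirror the IR nodes: %2 is f16[3,2], %5 is f16[2,3], %6 = mm(%5, %2) is f16[2,2].
node2 = cast_f16(expand(1.0, 3, 2))
node5 = cast_f16(expand(1.0, 2, 3))
node6 = mm_f16(node5, node2)

if __name__ == "__main__":
    print(node6)
```

Each output element sums three products of `1.0`, so every entry of `%6` is `3.0`, which `f16` represents exactly. On the PyTorch side, this change is what lets a region such as `with torch.autocast("xla", dtype=torch.float16):` be accepted; whether ops inside it are actually lowered to `f16` depends on the PyTorch/XLA autocast implementation in use.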