Yeounoh Chung e2e9d15726 Unblock float16 dtype for xla autocasting (#109554)
Previously, `torch.autocast` with the `xla` backend was restricted to `torch.bfloat16`. This restriction is no longer necessary.

With this change, the cast lowers to `xla::cast(..., type=f16)`, as shown in the IR below:
```
IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %1 = f32[3,2]{1,0} aten::expand(%0), xla_shape=f32[3,2]{1,0}, size=(3, 2), dynamic_dims=(0, 0)
  %2 = f16[3,2]{1,0} xla::cast(%1), xla_shape=f16[3,2]{1,0}, type=f16, dtype=Half, stype=Float
  %3 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %4 = f32[2,3]{1,0} aten::expand(%3), xla_shape=f32[2,3]{1,0}, size=(2, 3), dynamic_dims=(0, 0)
  %5 = f16[2,3]{1,0} xla::cast(%4), xla_shape=f16[2,3]{1,0}, type=f16, dtype=Half, stype=Float
  %6 = f16[2,2]{1,0} aten::mm(%5, %2), xla_shape=f16[2,2]{1,0}, ROOT=0
}
```

This allows PyTorch/XLA to extend its autocast implementation to support the `float16` dtype on the `xla` backend as well.
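
For reference, a minimal usage sketch (assuming `torch_xla` is installed and an XLA device is available) that reproduces the IR dump above:

```
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.ones(3, 2, device=device)
b = torch.ones(2, 3, device=device)

# Autocast with the xla backend and float16, which this change unblocks.
with torch.autocast("xla", dtype=torch.float16):
    out = torch.mm(b, a)  # autocast-eligible op runs in float16, as in the IR above

print(out.dtype)  # expected: torch.float16
```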

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109554
Approved by: https://github.com/JackCaoG, https://github.com/bdhirsh
2023-09-21 03:19:44 +00:00