pytorch/aten
Angel Li a856a17799 bf16 support for per_channel bwd (#165325)
Follow-up to #165098, adding bf16 support for the backward pass. To avoid BC-breaking changes and precision loss, we upcast the parameters to fp32 after the op gets called and downcast the gradients to bf16 before returning.
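
A minimal sketch of that upcast/downcast pattern, using a toy per-channel op wrapped in a custom autograd `Function` in place of the actual ATen kernel (all names below are illustrative, not the real implementation):

```python
import torch

class ToyPerChannelOp(torch.autograd.Function):
    """Toy stand-in: scales each channel of x; not the actual per_channel op."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x, scale)
        ctx.orig_dtype = x.dtype
        return x * scale.view(1, -1)

    @staticmethod
    def backward(ctx, grad_out):
        x, scale = ctx.saved_tensors
        # Upcast to fp32 so the backward math does not lose precision in bf16.
        g = grad_out.to(torch.float32)
        x32 = x.to(torch.float32)
        s32 = scale.to(torch.float32)
        grad_x = g * s32.view(1, -1)
        grad_scale = (g * x32).sum(dim=0)
        # Downcast the gradients back to the original (bf16) dtype before returning.
        return grad_x.to(ctx.orig_dtype), grad_scale.to(ctx.orig_dtype)

x = torch.randn(4, 3, dtype=torch.bfloat16, requires_grad=True)
scale = torch.ones(3, dtype=torch.bfloat16, requires_grad=True)
ToyPerChannelOp.apply(x, scale).sum().backward()
print(x.grad.dtype, scale.grad.dtype)  # torch.bfloat16 torch.bfloat16
```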

For testing, we upcast to fp32 before calling the reference function. We increase the tolerance to 1e-2 for bf16 inputs because of a difference in casting behavior between Python's `x.to(torch.bfloat16)` and C++'s `x.to(at::kBFloat16)` (after comparing intermediate tensors, we found that the numerics diverge after the final cast). We don't explicitly cast in the C++ op but rather let autograd/the optimizer handle it.
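
A rough sketch of that testing setup, again with a toy stand-in for the real op (names, shapes, and the reference function below are illustrative):

```python
import torch

def toy_per_channel(x, scale):
    # Stand-in for the op under test; the real kernel lives in ATen.
    return x * scale.view(1, -1)

# Run the op on bf16 inputs.
x_bf16 = torch.randn(4, 3, dtype=torch.bfloat16, requires_grad=True)
s_bf16 = torch.ones(3, dtype=torch.bfloat16, requires_grad=True)
toy_per_channel(x_bf16, s_bf16).sum().backward()

# Upcast to fp32 before calling the reference function.
x_fp32 = x_bf16.detach().to(torch.float32).requires_grad_()
s_fp32 = s_bf16.detach().to(torch.float32).requires_grad_()
toy_per_channel(x_fp32, s_fp32).sum().backward()

# Compare bf16 gradients against the fp32 reference with the loosened 1e-2 tolerance.
torch.testing.assert_close(x_bf16.grad.to(torch.float32), x_fp32.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(s_bf16.grad.to(torch.float32), s_fp32.grad, rtol=1e-2, atol=1e-2)
```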

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165325
Approved by: https://github.com/andrewor14
2025-10-14 05:47:32 +00:00