Jagadish Krishnamoorthy 264e7f68a0 [ROCm] Fix mx fp8 and fp4 code after scaling refactor changes. (#163127)
PR #151360 added mx fp8 and fp4 support on ROCm.
1. However, on recent upstream, the scaling refactor in Blas.cpp, together with the accompanying test_matmul_cuda changes, triggered failures. This patch corrects the is_blockwise_1x32_scaling function (see the sketch after this list).

2. Fixes the m, n, k dimensions for the ROCm mx case.

3. Modifies FP4E2M1FN_LARGEST_POW2 (the largest power of 2 representable in `torch.float4_e2m1fn_x2`) to 2, which results in a higher SQNR value for the mx fp4 test (see the worked check after this list).
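
For context on item 1: mx (microscaling) formats carry one 8-bit e8m0 scale per 1x32 block of elements along the reduction dimension. The following is a minimal Python sketch of the shape/dtype invariant such a check enforces; it is an illustration only, since the actual is_blockwise_1x32_scaling is C++ code in Blas.cpp and its exact layout checks may differ.

```python
import torch

def is_blockwise_1x32_scaling(t: torch.Tensor, scale: torch.Tensor) -> bool:
    """Sketch: does `scale` hold one e8m0 scale per 1x32 block of `t`?

    Illustration only; the real check in Blas.cpp is C++ and may also
    validate layout/padding details not modeled here.
    """
    if scale.dtype != torch.float8_e8m0fnu:  # mx scales are e8m0
        return False
    m, k = t.shape  # 2-D operand of the scaled matmul
    # One scale per contiguous 1x32 block along the reduction (K) dim.
    return scale.numel() == m * ((k + 31) // 32)
```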

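Similarly, a worked check of item 3 (an illustration, not code from the PR): e2m1 has 1 sign bit, 2 exponent bits, and 1 mantissa bit, so its positive values are 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, and 6.0 (`torch.float4_e2m1fn_x2` packs two such values per byte). The largest power of two in that set is 4.0 = 2**2, hence the constant is 2.

```python
import math

# Positive values representable in fp4 e2m1: subnormal 0.5, then
# (1 + m/2) * 2**(E - 1) for E in {1, 2, 3} and mantissa bit m in {0, 1}.
E2M1_POSITIVE = [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

FP4E2M1FN_LARGEST_POW2 = max(
    int(math.log2(v)) for v in E2M1_POSITIVE if math.log2(v).is_integer()
)
assert FP4E2M1FN_LARGEST_POW2 == 2  # 4.0 == 2**2; 8.0 is not representable
```
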
Testing results on gfx950 with ROCm 7.0:

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 22.698s
OK (111 passed)

This matches the results from when PR #151360 was merged.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163127
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-09-19 12:29:52 +00:00