Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Summary:
* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++ methods.
* Wraps the `torch._scaled_mm_v2` API by default, but the user can force use of the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API.

Test Plan:
`pytest test/test_scaled_matmul_cuda.py`

Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
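The wrapper pattern in the summary uses the new backend by default and offers an opt-in fallback to the older one. It can be sketched in plain Python. This is a hypothetical illustration, not PyTorch's implementation. `scaled_mm`, its `use_legacy` flag, and both backend stubs are invented names, and the real `torch.nn.functional.scaled_mm` signature is not given in this commit message. The math assumes per-tensor dequantization scales, i.e. `out = scale_a * scale_b * (A @ B)`. The real kernels operate on low-precision (e.g. FP8) CUDA tensors rather than Python lists.

```python
from typing import List

Matrix = List[List[float]]


def _matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain triple-loop matrix product used as the reference kernel."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    cols = len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for row in a]


def _scaled_mm_v2_stub(a: Matrix, b: Matrix, scale_a: float, scale_b: float) -> Matrix:
    # Hypothetical stand-in for the newer backend: multiply first,
    # then apply the combined per-tensor scale to the output.
    return [[scale_a * scale_b * x for x in row] for row in _matmul(a, b)]


def _scaled_mm_legacy_stub(a: Matrix, b: Matrix, scale_a: float, scale_b: float) -> Matrix:
    # Hypothetical stand-in for the older backend: scale each input first.
    # For per-tensor scales this is mathematically equivalent to the v2 stub.
    a_deq = [[scale_a * x for x in row] for row in a]
    b_deq = [[scale_b * x for x in row] for row in b]
    return _matmul(a_deq, b_deq)


def scaled_mm(a: Matrix, b: Matrix, scale_a: float, scale_b: float,
              use_legacy: bool = False) -> Matrix:
    """Hypothetical wrapper: new backend by default, legacy path on request."""
    backend = _scaled_mm_legacy_stub if use_legacy else _scaled_mm_v2_stub
    return backend(a, b, scale_a, scale_b)


# Usage: with scale_a * scale_b == 1.0 and an identity B, the input comes back.
a = [[1.0, 2.0], [3.0, 4.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
print(scaled_mm(a, eye, 0.5, 2.0))
print(scaled_mm(a, eye, 0.5, 2.0, use_legacy=True))
```

With per-tensor scales, both paths return the same result, so the flag only selects which backend executes. Other scaling granularities (for example row-wise scales) would need different broadcasting. This sketch does not model them.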
Please see the "Writing documentation" section of CONTRIBUTING.md for details on both writing and building the docs.