Add scaled_mm python API, test (#164142)

Summary:

* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++
  methods
* Wraps `torch._scaled_mm_v2` API by default, but user can force use of
  the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API

Test Plan:

`pytest test/test_scaled_matmul_cuda.py`

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
This commit is contained in:
Simon Layton
2025-10-08 16:44:39 -07:00
committed by PyTorch MergeBot
parent 512b6b59f0
commit 6a7f5c0d21
11 changed files with 363 additions and 78 deletions

View File

@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
torch.nn.functional.has_torch_function_unary,
torch.nn.functional.has_torch_function_variadic,
torch.nn.functional.handle_torch_function,
torch.nn.functional.scaled_mm,
torch.nn.functional.sigmoid,
torch.nn.functional.hardsigmoid,
torch.nn.functional.tanh,