mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add scaled_mm python API, test (#164142)
Summary: * Add `torch.nn.functional.scaled_mm` as an abstraction around the C++ methods * Wraps `torch._scaled_mm_v2` API by default, but user can force use of the older `torch._scaled_mm` interface. * Scaled MM tests now run on the new API Test Plan: `pytest test/test_scaled_matmul_cuda.py` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlaytonmeta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142 Approved by: https://github.com/drisspg ghstack dependencies: #164141
This commit is contained in:
committed by
PyTorch MergeBot
parent
512b6b59f0
commit
6a7f5c0d21
@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
|
||||
torch.nn.functional.has_torch_function_unary,
|
||||
torch.nn.functional.has_torch_function_variadic,
|
||||
torch.nn.functional.handle_torch_function,
|
||||
torch.nn.functional.scaled_mm,
|
||||
torch.nn.functional.sigmoid,
|
||||
torch.nn.functional.hardsigmoid,
|
||||
torch.nn.functional.tanh,
|
||||
|
Reference in New Issue
Block a user