Compare commits

...

18 Commits

Author SHA1 Message Date
2cf45b6f69 Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-25 21:15:35 -07:00
ddfc673eb0 Update base for Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-25 21:15:35 -07:00
dc8a2bd616 Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-25 01:15:33 -07:00
ae8a79da0f Update base for Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-25 01:15:33 -07:00
1e78cd5b54 Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-25 00:49:14 -07:00
17b6284c01 Update base for Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-25 00:49:14 -07:00
5fd05d7de0 Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-24 23:32:36 -07:00
6c92365875 Update base for Update on "[torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator"
[ghstack-poisoned]
2025-10-24 23:32:36 -07:00
5671b95d14 [torchfuzz] move all torch_op_name defs to leaf classes in MatrixMultiplyOperator
[ghstack-poisoned]
2025-10-24 21:40:17 -07:00
3209c495e2 add split operator
[ghstack-poisoned]
2025-10-24 21:40:12 -07:00
42acfbad67 [torchfuzz] make pointwise subclasses defined torch_op_name
[ghstack-poisoned]
2025-10-24 21:40:07 -07:00
6c2dcc69d4 Update on "[torchfuzz] more regexes"
[ghstack-poisoned]
2025-10-24 20:53:20 -07:00
da8b8d0209 Update base for Update on "[torchfuzz] more regexes"
[ghstack-poisoned]
2025-10-24 20:53:20 -07:00
a55111a102 [torchfuzz] more regexes
[ghstack-poisoned]
2025-10-24 12:51:08 -07:00
f7d4f84e18 [torchfuzz] mhaf
[ghstack-poisoned]
2025-10-24 12:51:03 -07:00
ff6fcce8d5 [torchfuzz] add sdpa operator
[ghstack-poisoned]
2025-10-24 12:50:59 -07:00
7dec4f9b98 [torchfuzz] fix group norm operator
[ghstack-poisoned]
2025-10-24 12:50:54 -07:00
e2a51f4354 [torchfuzz] check in more ignore regexes
[ghstack-poisoned]
2025-10-24 12:50:49 -07:00
15 changed files with 350 additions and 477 deletions

View File

@ -41,7 +41,6 @@ torch.distributed.fsdp.fully_shard <distributed.fsdp.fully_shard>
torch.distributed.tensor.parallel <distributed.tensor.parallel>
torch.distributed.optim <distributed.optim>
torch.distributed.pipelining <distributed.pipelining>
torch.distributed._symmetric_memory <symmetric_memory>
torch.distributed.checkpoint <distributed.checkpoint>
torch.distributions <distributions>
torch.compiler <torch.compiler>

View File

@ -1,380 +0,0 @@
```{eval-rst}
.. role:: hidden
:class: hidden-section
```
# PyTorch Symmetric Memory
:::{note}
`torch.distributed._symmetric_memory` is currently in an alpha state and under
development. API changes are possible.
:::
## Why Symmetric Memory?
With rapidly evolving parallelization techniques, existing frameworks and
libraries often struggle to keep up, and developers increasingly rely on custom
implementations that directly schedule communications and computations. In
recent years we've witnessed a shift from primarily one-dimensional
data-parallelism techniques to multi-dimensional parallelism techniques. These
have different latency requirements for different types of communication and
thus require fine-grained overlapping of compute and communication.
To minimize compute interference, they also require the use of copy engines and
network interface cards (NICs) to drive communication. Network transport
protocols such as remote direct memory access (RDMA) enhance the performance by
enabling direct, high-speed, and low-latency communication between processors
and memory. This increase in variety indicates the need for finer-grained
communication primitives than are offered today by high-level collective APIs,
ones that would enable developers to implement specific algorithms tailored for
their use cases, such as low-latency collectives, fine-grained
compute-communications overlap, or custom fusions.
Furthermore, today's advanced AI systems connect GPUs with high-bandwidth links
(such as NVLinks, InfiniBand or RoCE), making GPU global memory directly
accessible to peers. Such connections present a great opportunity for
programmers to program the system as a single, gigantic GPU with vast accessible
memory, instead of programming singular “GPU islands.”
In this document, we will show how you can use PyTorch Symmetric Memory to
program modern GPU systems as a “single GPU” and achieve fine-grained remote
access.
## What does PyTorch Symmetric Memory unlock?
PyTorch Symmetric Memory unlocks three new capabilities:
- **Customized communication patterns**: Increased flexibility in kernel writing
allows developers to write custom kernels that implement their own
computations and communications, tailored directly to the needs of the
application. It also makes it straightforward to add support for new data types,
along with any special compute those data types might require, even if it is
not yet present in the standard libraries.
- **In-kernel compute-comm fusion**: Device-initiated communication capability
allows developers to write kernels with both computation and communication
instructions, fusing computation and data movement at the smallest possible
granularity.
- **Low-latency remote access**: Network transport protocols like RDMA enhance the
performance of symmetric memory in networked environments by enabling direct,
high-speed, and low-latency communication between processors and memory. RDMA
eliminates the overhead associated with the traditional network stack and CPU
involvement. It also offloads data transfer from the compute units to the NICs,
freeing up compute resources for computational tasks.
Next, we will show you how PyTorch Symmetric Memory (SymmMem) enables new
applications with the above capabilities.
## A “Hello World” example
The PyTorch SymmMem programming model involves two key elements:
- creating symmetric tensors
- creating SymmMem kernels
To create symmetric tensors, one can use the
`torch.distributed._symmetric_memory` package:
```python
import torch.distributed._symmetric_memory as symm_mem
t = symm_mem.empty(128, device=torch.device("cuda", rank))
hdl = symm_mem.rendezvous(t, group)
```
The `symm_mem.empty` function creates a tensor that is backed by a symmetric
memory allocation. The `rendezvous` function establishes a rendezvous with peers
in the group and returns a handle to the symmetric memory allocation. The
handle provides methods to access information about the symmetric memory
allocation, such as the pointers to the symmetric buffers on peer ranks, the
multicast pointer (if supported), and the signal pads.
The `empty` and `rendezvous` functions must be called in the same order on all
ranks in the group.
Then, collectives can be called on these tensors. For example, to perform a
one-shot all-reduce:
```python
# Most SymmMem ops are under the torch.ops.symm_mem namespace
torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group)
```
Please note that `torch.ops.symm_mem` is an "op namespace" rather than a Python
module. Therefore, you cannot import it with `import torch.ops.symm_mem`, nor
can you import an op with `from torch.ops.symm_mem import one_shot_all_reduce`.
You can call the op directly as in the example above.
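Putting the pieces together, a minimal end-to-end sketch might look like the
following (assuming one GPU per process, a launcher such as `torchrun`, and
that the op's `group_name` argument accepts the group's name string):
```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes one process per GPU, launched e.g. via torchrun.
dist.init_process_group()
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Allocate a symmetric tensor and establish the rendezvous on every rank.
t = symm_mem.empty(128, device=f"cuda:{rank}")
hdl = symm_mem.rendezvous(t, dist.group.WORLD)

# Fill with rank-specific data, then perform a one-shot all-reduce.
t.fill_(rank)
reduced = torch.ops.symm_mem.one_shot_all_reduce(
    t, "sum", dist.group.WORLD.group_name  # assumption: pass the group's name string
)
```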
## Write your own kernel
To write your own kernel that performs communication over symmetric memory,
you'll need access to the addresses of the mapped peer buffers and to the
signal pads required for synchronization. Inside the kernel you'll also need to
perform the correct synchronizations to make sure peers are ready for
communication, and to signal to them that this GPU is ready.
PyTorch Symmetric Memory provides CUDA Graph-compatible synchronization
primitives that operate on the signal pad accompanying each symmetric memory
allocation. Kernels using symmetric memory can be written both in CUDA and in
Triton. Here's an example of allocating a symmetric tensor and exchanging handles:
```python
import torch.distributed._symmetric_memory as symm_mem
dist.init_process_group()
rank = dist.get_rank()
# Allocate a tensor
t = symm_mem.empty(4096, device=f"cuda:{rank}")
# Establish symmetric memory and obtain the handle
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
```
Access to the buffer pointers, the multicast pointer, and the signal pads is provided via:
```python
hdl.buffer_ptrs
hdl.multicast_ptr
hdl.signal_pad_ptrs
```
Data pointed to by `buffer_ptrs` can be accessed just like regular local data,
and any necessary compute can also be performed in the usual ways. As with local
data, you can and should use vectorized accesses to improve efficiency.
Symmetric memory is especially convenient for writing kernels in Triton. Just
as Triton lowered the barrier to writing efficient CUDA-level compute, it is
now just as easy to add communication to Triton kernels. The kernel below
demonstrates a low-latency all-reduce written in Triton.
```python
@triton.jit
def one_shot_all_reduce_kernel(
buf_tuple,
signal_pad_ptrs,
output_ptr,
numel: tl.constexpr,
rank: tl.constexpr,
world_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
ptx_utils.symm_mem_sync(
signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
)
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
while block_start < numel:
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < numel
acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)
for i in tl.static_range(world_size):
buffer_rank = buf_tuple[i]
x = tl.load(buffer_rank + offsets, mask=mask)
acc += x
tl.store(output_ptr + offsets, acc, mask=mask)
block_start += tl.num_programs(axis=0) * BLOCK_SIZE
ptx_utils.symm_mem_sync(
signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
)
```
Synchronizations at the beginning and the end of the kernel above guarantee that
all the processes see consistent data. The bulk of the kernel is recognizable
Triton code, and Triton will optimize it behind the scenes, making sure memory
accesses are performed efficiently with vectorization and unrolling. As
with all Triton kernels, it is easily modifiable to add extra computations or
change the communication algorithm. Visit
https://github.com/meta-pytorch/kraken/blob/main/kraken to see additional
utilities and examples of using symmetric memory to implement common patterns in
Triton.
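A launch of the kernel above might look like the sketch below. It is only a
sketch: the `get_buffer` accessor and the `signal_pad_ptrs_dev` attribute are
assumed names for per-rank buffer views and the device-side array of signal-pad
pointers; check the SymmMem handle API in your PyTorch build, or see the kraken
repository, for a complete version.
```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import triton

rank = dist.get_rank()
world_size = dist.get_world_size()

# The kernel above accumulates in bfloat16, so allocate a bfloat16 buffer.
t = symm_mem.empty(4096, dtype=torch.bfloat16, device=f"cuda:{rank}")
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
t.fill_(rank)

output = torch.empty_like(t)
numel = t.numel()
BLOCK_SIZE = 2048

# One view of each peer's buffer, in rank order (assumed accessor name).
buf_tuple = tuple(hdl.get_buffer(r, t.shape, t.dtype) for r in range(world_size))

grid = (triton.cdiv(numel, BLOCK_SIZE),)
one_shot_all_reduce_kernel[grid](
    buf_tuple,
    hdl.signal_pad_ptrs_dev,  # assumed: device-side array of signal-pad pointers
    output,
    numel=numel,
    rank=rank,
    world_size=world_size,
    BLOCK_SIZE=BLOCK_SIZE,
)
```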
## Scale out
Large language models distribute experts across more than 8 GPUs and therefore
require multi-node access. RDMA-capable NICs provide it. In addition, software
libraries such as NVSHMEM or rocSHMEM abstract away the programming difference
between intra-node and inter-node access with primitives that are slightly
higher level than raw pointer access, such as put and get.
PyTorch provides NVSHMEM plugins that give Triton kernels cross-node
capabilities. As shown in the code snippet below, one can initiate a cross-node
put command from within the kernel.
```python
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem
@requires_nvshmem
@triton.jit
def my_put_kernel(
dest,
src,
nelems,
pe,
):
nvshmem.put(dest, src, nelems, pe)
```
The `requires_nvshmem` decorator is used to indicate that the kernel requires
the NVSHMEM device library as an external dependency. When Triton compiles the
kernel, the decorator will search your system paths for the NVSHMEM device
library. If it is available, Triton will include the necessary device assembly
to use the NVSHMEM functions.
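A rough usage sketch for the put kernel above might look like the following,
assuming the process group is already initialized, the NVSHMEM backend is
selected, and each rank sends its buffer to its ring neighbor:
```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

symm_mem.set_backend("NVSHMEM")  # see the API reference below
rank = dist.get_rank()
world_size = dist.get_world_size()

# Both source and destination must be symmetric tensors.
src = symm_mem.empty(1024, device=f"cuda:{rank}")
dst = symm_mem.empty(1024, device=f"cuda:{rank}")
symm_mem.rendezvous(src, dist.group.WORLD)
symm_mem.rendezvous(dst, dist.group.WORLD)

src.fill_(rank)
peer = (rank + 1) % world_size  # ring neighbor
my_put_kernel[(1,)](dst, src, src.numel(), peer)
```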
## API Reference
```{eval-rst}
.. currentmodule:: torch.distributed._symmetric_memory
```
```{eval-rst}
.. autofunction:: empty
```
```{eval-rst}
.. autofunction:: rendezvous
```
```{eval-rst}
.. autofunction:: is_nvshmem_available
```
```{eval-rst}
.. autofunction:: set_backend
```
```{eval-rst}
.. autofunction:: get_backend
```
## Op Reference
:::{note}
The following ops are hosted in the `torch.ops.symm_mem` namespace. You can call
them directly via `torch.ops.symm_mem.<op_name>`.
:::
```{eval-rst}
.. currentmodule:: torch.ops.symm_mem
```
```{eval-rst}
.. py:function:: multimem_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor
Performs a multimem all-reduce operation on the input tensor. This operation
requires hardware support for multimem operations. On NVIDIA GPUs, NVLink
SHARP is required.
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
:param str group_name: Name of the group to perform all-reduce on.
.. py:function:: multimem_all_gather_out(input: Tensor, group_name: str, out: Tensor) -> Tensor
Performs a multimem all-gather operation on the input tensor. This operation requires hardware support for multimem operations. On NVIDIA GPUs, NVLink SHARP is required.
:param Tensor input: Input tensor to perform all-gather on.
:param str group_name: Name of the group to perform all-gather on.
:param Tensor out: Output tensor to store the result of the all-gather operation. Must be symmetric.
.. py:function:: one_shot_all_reduce(input: Tensor, reduce_op: str, group_name: str) -> Tensor
Performs a one-shot all-reduce operation on the input tensor.
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
:param str group_name: Name of the group to perform all-reduce on.
.. py:function:: one_shot_all_reduce_out(input: Tensor, reduce_op: str, group_name: str, out: Tensor) -> Tensor
Performs a one-shot all-reduce operation based on the input tensor and writes the result to the output tensor.
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
:param str group_name: Name of the group to perform all-reduce on.
:param Tensor out: Output tensor to store the result of the all-reduce operation. Can be a regular tensor.
.. py:function:: two_shot_all_reduce_(input: Tensor, reduce_op: str, group_name: str) -> Tensor
Performs a two-shot all-reduce operation on the input tensor.
:param Tensor input: Input tensor to perform all-reduce on. Must be symmetric.
:param str reduce_op: Reduction operation to perform. Currently only "sum" is supported.
:param str group_name: Name of the group to perform all-reduce on.
.. py:function:: all_to_all_vdev(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str) -> None
Performs an all-to-all-v operation using NVSHMEM, with split information provided on device.
:param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
:param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
:param Tensor in_splits: Tensor containing splits of data to send to each peer. Must be symmetric. Must be of size (group_size,). The splits are in the unit of elements in the 1st dimension.
:param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size). The rows are (in order): output splits and output offsets.
:param str group_name: Name of the group to perform all-to-all on.
.. py:function:: all_to_all_vdev_2d(input: Tensor, out: Tensor, in_splits: Tensor, out_splits_offsets: Tensor, group_name: str, [major_align: int = None]) -> None
Perform a 2D all-to-all-v operation using NVSHMEM, with split information provided on device. In Mixture of Experts models, this operation can be used to dispatch tokens.
:param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
:param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
:param Tensor in_splits: Tensor containing the splits of data to send to each expert. Must be symmetric. Must be of size (group_size * ne,), where ne is the number of experts per rank. The splits are in the unit of elements in the 1st dimension.
:param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
:param str group_name: Name of the group to perform all-to-all on.
:param int major_align: Optional alignment for the major dimension of the output chunk for each expert. If not provided, the alignment is assumed to be 1. Any alignment adjustment will be reflected in the output offsets.
A 2D AllToAllv shuffle is illustrated below:
(world_size = 2, ne = 2, total number of experts = 4)::
Source: | Rank 0 | Rank 1 |
| c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |
Dest : | Rank 0 | Rank 1 |
| c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
where each `c_i` / `d_i` is a slice of the `input` tensor targeting expert
`i`, with length indicated by the input splits. That is, the 2D AllToAllv
shuffle achieves a transpose from rank-major order at input to expert-major
order at output.
If `major_align` is not 1, the output offsets of c1, c2, c3 will be
up-aligned to this value. For example, if c0 has length 5 and d0 has
length 7 (making a total of 12), and if the `major_align` is set to 16,
the output offset of c1 will be 16, and similarly for c2 and c3. This value has
no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
Note: since cutlass does not support empty bins, we set the aligned length
to `major_align` if it is 0. See
https://github.com/pytorch/pytorch/issues/152668.
.. py:function:: all_to_all_vdev_2d_offset(Tensor input, Tensor out, Tensor in_splits_offsets, Tensor out_splits_offsets, str group_name) -> None
Perform a 2D AllToAllv shuffle operation, with input split and offset
information provided on device. The input offsets are not required to be an
exact prefix sum of the input splits, i.e. padding is allowed between the
split chunks. The padding, however, will not be transferred to peer ranks.
In Mixture of Experts models, this operation can be used to combine tokens
processed by experts on parallel ranks. This operation can be viewed as the
"reverse" of the `all_to_all_vdev_2d` operation (which shuffles
tokens to experts).
:param Tensor input: Input tensor to perform all-to-all on. Must be symmetric.
:param Tensor out: Output tensor to store the result of the all-to-all operation. Must be symmetric.
:param Tensor in_splits_offsets: Tensor containing the splits and offsets of data to send to each expert. Must be symmetric. Must be of size (2, group_size * ne), where `ne` is the number of experts. The rows are (in order): input splits and input offsets. The splits are in the unit of elements in the 1st dimension.
:param Tensor out_splits_offsets: Tensor containing the splits and offsets of data received from each peer. Must be symmetric. Must be of size (2, group_size * ne). The rows are (in order): output splits and output offsets.
:param str group_name: Name of the group to perform all-to-all on.
```
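To make the parameter layout concrete, a hedged usage sketch of
`all_to_all_vdev` might look like the following (assuming an
already-initialized process group, the NVSHMEM backend, and that the ops accept
the group's name string):
```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

symm_mem.set_backend("NVSHMEM")
rank = dist.get_rank()
world_size = dist.get_world_size()
group_name = dist.group.WORLD.group_name  # assumption: ops take the group's name string

rows_per_peer = 4
# Every tensor involved, including the split/offset tensors, is symmetric.
inp = symm_mem.empty(rows_per_peer * world_size, 8, device=f"cuda:{rank}")
out = symm_mem.empty(rows_per_peer * world_size, 8, device=f"cuda:{rank}")
in_splits = symm_mem.empty(world_size, dtype=torch.int64, device=f"cuda:{rank}")
out_splits_offsets = symm_mem.empty(2, world_size, dtype=torch.int64, device=f"cuda:{rank}")
for t in (inp, out, in_splits, out_splits_offsets):
    symm_mem.rendezvous(t, dist.group.WORLD)

inp.normal_()
in_splits.fill_(rows_per_peer)  # send `rows_per_peer` rows to every peer
torch.ops.symm_mem.all_to_all_vdev(inp, out, in_splits, out_splits_offsets, group_name)
# out_splits_offsets[0] now holds the received splits, [1] the corresponding offsets.
```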

View File

@ -55,7 +55,7 @@ torch.fx.node.Node.append(self, x: 'Node') -> None
torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.prepend(self, x: 'Node') -> None
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Optional[Callable[[Node], bool]] = None, propagate_meta: bool = False) -> List[Node]
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = <function <lambda>>, propagate_meta: bool = False) -> List[Node]
torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None
torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None

View File

@ -219,6 +219,8 @@ class DefaultFuzzTemplate(FuzzTemplate):
# Neural network operations
"torch.nn.functional.embedding",
"torch.nn.functional.linear",
"torch.nn.functional.scaled_dot_product_attention",
"torch.nn.functional.multi_head_attention_forward",
# Activation functions
"torch.nn.functional.relu",
"torch.nn.functional.leaky_relu",

View File

@ -72,8 +72,21 @@ IGNORE_PATTERNS: list[re.Pattern] = [
re.compile(
r"dimensionality of sizes \(0\) must match dimensionality of strides \(1\)"
), # https://github.com/pytorch/pytorch/issues/164814
re.compile(
r"self and mat2 must have the same dtype"
), # https://github.com/pytorch/pytorch/issues/165718
re.compile(
r"free\(\): invalid next size \(fast\)"
), # TODO: figure out why sometimes heap metadata gets corrupted on program exit (checks actually pass successfully)
re.compile(
r'assert "int" in str\(indices\.get_dtype\(\)\)'
), # https://github.com/pytorch/pytorch/issues/166042
re.compile(
r'self\.shape_env\.guard_or_defer_runtime_assert\(expr, "guard_equals"\)'
), # https://github.com/pytorch/pytorch/issues/166245
# Add more patterns here as needed, e.g.:
# re.compile(r"Some other error message"),
]

View File

@ -22,7 +22,9 @@ from torchfuzz.operators.nn_functional import (
EmbeddingOperator,
LayerNormOperator,
LinearOperator,
MultiHeadAttentionForwardOperator,
ReLUOperator,
ScaledDotProductAttentionOperator,
SoftmaxOperator,
)
from torchfuzz.operators.registry import (
@ -76,7 +78,9 @@ __all__ = [
"MatmulOperator",
"EmbeddingOperator",
"LinearOperator",
"MultiHeadAttentionForwardOperator",
"ReLUOperator",
"ScaledDotProductAttentionOperator",
"SoftmaxOperator",
"DropoutOperator",
"LayerNormOperator",

View File

@ -15,14 +15,8 @@ from torchfuzz.tensor_fuzzer import Spec, TensorSpec
class MatrixMultiplyOperator(Operator):
"""Base class for matrix multiplication operations."""
def __init__(self, name: str, torch_op: str):
def __init__(self, name: str):
super().__init__(name)
self._torch_op = torch_op
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return self._torch_op
def can_produce(self, output_spec: Spec) -> bool:
"""Matrix multiply operations can produce float/complex tensors of dimension >= 2."""
@ -47,12 +41,6 @@ class MatrixMultiplyOperator(Operator):
def _get_compatible_dtype(self, output_dtype):
"""Get a compatible dtype for matrix multiplication."""
# For matrix multiplication, we need to be flexible with input dtypes
# since earlier operations may have performed type promotion.
# We'll let the fuzzer generate whatever dtypes result from earlier operations
# and rely on the operation graph to ensure compatibility.
# Return the output dtype as a starting point, but this may be overridden
# by the actual tensor specs generated by the fuzzer.
return [output_dtype, output_dtype]
@ -60,9 +48,14 @@ class MMOperator(MatrixMultiplyOperator):
"""Operator for matrix multiplication (torch.mm)."""
def __init__(self):
super().__init__("mm", "torch.mm")
super().__init__("mm")
self.weight = 5.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.mm"
def can_produce(self, output_spec: Spec) -> bool:
"""MM requires exactly 2D tensors."""
if not isinstance(output_spec, TensorSpec):
@ -96,7 +89,6 @@ class MMOperator(MatrixMultiplyOperator):
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# First tensor: [m, k]
@ -141,9 +133,14 @@ class AddmmOperator(MatrixMultiplyOperator):
"""Operator for additive matrix multiplication (torch.addmm)."""
def __init__(self):
super().__init__("addmm", "torch.addmm")
super().__init__("addmm")
self.weight = 5.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.addmm"
def can_produce(self, output_spec: Spec) -> bool:
"""Addmm requires exactly 2D tensors."""
if not isinstance(output_spec, TensorSpec):
@ -177,7 +174,6 @@ class AddmmOperator(MatrixMultiplyOperator):
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# Bias tensor: [m, n] (same shape as output)
@ -230,9 +226,14 @@ class BmmOperator(MatrixMultiplyOperator):
"""Operator for batch matrix multiplication (torch.bmm)."""
def __init__(self):
super().__init__("bmm", "torch.bmm")
super().__init__("bmm")
self.weight = 5.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.bmm"
def can_produce(self, output_spec: Spec) -> bool:
"""Batch matrix multiply requires 3D tensors."""
if not isinstance(output_spec, TensorSpec):
@ -266,7 +267,6 @@ class BmmOperator(MatrixMultiplyOperator):
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# First tensor: [b, m, k]
@ -311,9 +311,14 @@ class MatmulOperator(MatrixMultiplyOperator):
"""Operator for general matrix multiplication (torch.matmul)."""
def __init__(self):
super().__init__("matmul", "torch.matmul")
super().__init__("matmul")
self.weight = 500.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.matmul"
def can_produce(self, output_spec: Spec) -> bool:
"""Matmul can handle various tensor dimensions >= 1."""
if not isinstance(output_spec, TensorSpec):
@ -343,7 +348,6 @@ class MatmulOperator(MatrixMultiplyOperator):
output_size = output_spec.size
output_dims = len(output_size)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
if output_dims == 1:

View File

@ -1,5 +1,6 @@
"""Neural network functional operator implementations."""
import math
import random
from typing import Optional
@ -752,6 +753,17 @@ class GroupNormOperator(Operator):
# GroupNorm needs at least 2 dimensions (batch, channels)
if len(output_spec.size) < 2:
return False
# GroupNorm requires more than 1 value per channel
# For shape (N, C, *), num_values_per_channel = N * prod(*)
# We need N * prod(*) > 1
batch_size = output_spec.size[0]
spatial_size = math.prod(output_spec.size[2:])
num_values_per_channel = batch_size * spatial_size
if num_values_per_channel <= 1:
return False
return is_float_dtype(output_spec.dtype)
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
@ -951,3 +963,221 @@ class SiLUOperator(Operator):
input_name = input_names[0]
return f"{output_name} = torch.nn.functional.silu({input_name})"
class ScaledDotProductAttentionOperator(Operator):
"""Operator for torch.nn.functional.scaled_dot_product_attention."""
def __init__(self):
super().__init__("torch.nn.functional.scaled_dot_product_attention")
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.nn.functional.scaled_dot_product_attention"
def can_produce(self, output_spec: Spec) -> bool:
"""Scaled dot product attention can produce tensor outputs with floating point dtypes."""
if not isinstance(output_spec, TensorSpec):
return False
# SDPA needs at least 3 dimensions (batch, seq_len, embed_dim)
if len(output_spec.size) < 3:
return False
return is_float_dtype(output_spec.dtype)
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for scaled_dot_product_attention.
SDPA requires:
- query: (batch, seq_len, embed_dim) or (batch, num_heads, seq_len, head_dim)
- key: (batch, seq_len, embed_dim) or (batch, num_heads, seq_len_kv, head_dim)
- value: (batch, seq_len, embed_dim) or (batch, num_heads, seq_len_kv, head_dim)
Output shape matches query shape.
"""
if not isinstance(output_spec, TensorSpec):
raise ValueError(
"ScaledDotProductAttentionOperator can only produce TensorSpec outputs"
)
if len(output_spec.size) < 3:
raise ValueError("SDPA output must have at least 3 dimensions")
# Query has the same shape as output
query_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
# Key and value: match query shape for simplicity
# In practice, seq_len for key/value can differ, but we'll keep it simple
key_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
value_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
return [query_spec, key_spec, value_spec]
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for scaled_dot_product_attention operation."""
if len(input_names) != 3:
raise ValueError("SDPA requires exactly 3 inputs: query, key, value")
# Ensure dtype compatibility by converting all inputs to the expected output dtype
target_dtype = str(output_spec.dtype)
query_name, key_name, value_name = input_names
return f"{output_name} = torch.nn.functional.scaled_dot_product_attention({query_name}.to({target_dtype}), {key_name}.to({target_dtype}), {value_name}.to({target_dtype}))"
class MultiHeadAttentionForwardOperator(Operator):
"""Operator for torch.nn.functional.multi_head_attention_forward."""
def __init__(self):
super().__init__("torch.nn.functional.multi_head_attention_forward")
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.nn.functional.multi_head_attention_forward"
def can_produce(self, output_spec: Spec) -> bool:
"""Multi-head attention forward can produce tensor outputs with floating point dtypes."""
if not isinstance(output_spec, TensorSpec):
return False
# MHA needs at least 3 dimensions (seq_len, batch, embed_dim)
if len(output_spec.size) < 3:
return False
# MHA cannot handle 0-sized dimensions (seq_len, batch, or embed_dim must be > 0)
if any(dim == 0 for dim in output_spec.size):
return False
return is_float_dtype(output_spec.dtype)
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for multi_head_attention_forward.
MHA requires:
- query, key, value: (seq_len, batch, embed_dim)
- in_proj_weight: (3*embed_dim, embed_dim) for combined QKV projection
- in_proj_bias: (3*embed_dim,) optional
- out_proj_weight: (embed_dim, embed_dim)
- out_proj_bias: (embed_dim,) optional
For simplicity, we'll use the combined in_proj_weight path.
IMPORTANT: The order of optional parameters matters for codegen!
We must ensure that when we have 6 inputs, they are in the order:
query, key, value, in_proj_weight, in_proj_bias, out_proj_weight
NOT: query, key, value, in_proj_weight, out_proj_weight, out_proj_bias
"""
if not isinstance(output_spec, TensorSpec):
raise ValueError(
"MultiHeadAttentionForwardOperator can only produce TensorSpec outputs"
)
if len(output_spec.size) < 3:
raise ValueError("MHA output must have at least 3 dimensions")
# Output shape: (seq_len, batch, embed_dim)
seq_len, batch, embed_dim = output_spec.size[:3]
# Query, key, value have the same shape as output
query_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
key_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
value_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
# in_proj_weight: (3*embed_dim, embed_dim)
in_proj_weight_spec = TensorSpec(
size=(3 * embed_dim, embed_dim),
stride=(embed_dim, 1),
dtype=output_spec.dtype,
)
# out_proj_weight: (embed_dim, embed_dim)
out_proj_weight_spec = TensorSpec(
size=(embed_dim, embed_dim),
stride=(embed_dim, 1),
dtype=output_spec.dtype,
)
# For simplicity and correctness, always generate all required tensors
# This avoids ambiguity in the codegen about which optional parameters are present
# We'll use a simplified signature: query, key, value, in_proj_weight, out_proj_weight only
specs = [
query_spec,
key_spec,
value_spec,
in_proj_weight_spec,
out_proj_weight_spec,
]
from typing import cast
return cast(list[Spec], specs)
def _calculate_stride(self, size):
"""Calculate stride for a given size."""
if not size:
return ()
stride = []
current_stride = 1
for dim_size in reversed(size):
stride.append(current_stride)
current_stride *= dim_size
return tuple(reversed(stride))
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for multi_head_attention_forward operation."""
if len(input_names) != 5:
raise ValueError(
"MHA requires exactly 5 inputs: query, key, value, in_proj_weight, out_proj_weight"
)
if not isinstance(output_spec, TensorSpec):
raise ValueError(
"MultiHeadAttentionForwardOperator can only produce TensorSpec outputs"
)
target_dtype = str(output_spec.dtype)
embed_dim = output_spec.size[-1]
# Determine number of heads (must divide embed_dim evenly)
# Common choices: 8, 4, 2, 1
possible_heads = [h for h in [8, 4, 2, 1] if embed_dim % h == 0]
num_heads = possible_heads[0] if possible_heads else 1
query_name = input_names[0]
key_name = input_names[1]
value_name = input_names[2]
in_proj_weight_name = input_names[3]
out_proj_weight_name = input_names[4]
# Build the function call without optional biases
code = f"""{output_name}, _ = torch.nn.functional.multi_head_attention_forward(
{query_name}.to({target_dtype}),
{key_name}.to({target_dtype}),
{value_name}.to({target_dtype}),
{embed_dim},
{num_heads},
{in_proj_weight_name}.to({target_dtype}),
None, # in_proj_bias
None, # bias_k
None, # bias_v
False, # add_zero_attn
0.0, # dropout_p (no dropout for testing)
{out_proj_weight_name}.to({target_dtype}),
None, # out_proj_bias
training=False, # Use eval mode for deterministic behavior
need_weights=False, # Don't compute attention weights for performance
)"""
return code

View File

@ -30,8 +30,10 @@ from torchfuzz.operators.nn_functional import (
LayerNormOperator,
LeakyReLUOperator,
LinearOperator,
MultiHeadAttentionForwardOperator,
ReLUOperator,
RMSNormOperator,
ScaledDotProductAttentionOperator,
SigmoidOperator,
SiLUOperator,
SoftmaxOperator,
@ -101,6 +103,8 @@ class OperatorRegistry:
# Neural network functional operators
self.register(EmbeddingOperator())
self.register(LinearOperator())
self.register(ScaledDotProductAttentionOperator())
self.register(MultiHeadAttentionForwardOperator())
# Activation functions
self.register(ReLUOperator())

View File

@ -1,7 +1,6 @@
"""Tensor pointwise operator implementation."""
import random
from typing import Optional
import torch
@ -17,16 +16,10 @@ from torchfuzz.type_promotion import (
class PointwiseOperator(Operator):
"""Base class for element-wise pointwise operations."""
def __init__(self, name: str, torch_op: str, symbol: str):
def __init__(self, name: str, symbol: str):
super().__init__(name)
self._torch_op = torch_op
self.symbol = symbol
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return self._torch_op
def can_produce(self, output_spec: Spec) -> bool:
"""Tensor pointwise operations can produce tensors but not scalars."""
if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
@ -74,9 +67,7 @@ class PointwiseOperator(Operator):
) -> str:
"""Generate code for pointwise operation."""
if len(input_names) == 2:
return (
f"{output_name} = {self._torch_op}({input_names[0]}, {input_names[1]})"
)
return f"{output_name} = {self.torch_op_name}({input_names[0]}, {input_names[1]})"
else:
# Chain operations using symbols for readability
expr = f" {self.symbol} ".join(input_names)
@ -87,26 +78,42 @@ class AddOperator(PointwiseOperator):
"""Operator for element-wise addition."""
def __init__(self, weight: float = 1.0):
super().__init__("add", "torch.add", "+")
super().__init__("add", "+")
self.weight = float(weight)
@property
def torch_op_name(self) -> str:
return "torch.add"
class MulOperator(PointwiseOperator):
"""Operator for element-wise multiplication."""
def __init__(self):
super().__init__("mul", "torch.mul", "*")
super().__init__("mul", "*")
@property
def torch_op_name(self) -> str:
return "torch.mul"
class SubOperator(PointwiseOperator):
"""Operator for element-wise subtraction."""
def __init__(self):
super().__init__("sub", "torch.sub", "-")
super().__init__("sub", "-")
@property
def torch_op_name(self) -> str:
return "torch.sub"
class DivOperator(PointwiseOperator):
"""Operator for element-wise division."""
def __init__(self):
super().__init__("div", "torch.div", "/")
super().__init__("div", "/")
@property
def torch_op_name(self) -> str:
return "torch.div"

View File

@ -2759,7 +2759,6 @@ class _NodeBase:
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _replace_input_with(self, old_input: FxNode, new_input: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...

View File

@ -1274,8 +1274,17 @@ def maybe_inline_graph_saved_tensors_hooks(
else:
# Keep usages of bw_g_input in inserted unpacked hook graph.
# Replace other usages of bw_g_input with unpack_saved_tensor_n.
from torch._C import _fx_map_arg
def maybe_replace_node(n):
return unpack_saved_tensor_n if n == bw_g_input else n
for use_node in original_bw_g_input_users:
use_node._replace_input_with(bw_g_input, unpack_saved_tensor_n)
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
use_node._update_args_kwargs(new_args, new_kwargs)
bw_g.erase_node(bw_unpack_out_n)
# Changing forward graph outputs,

View File

@ -365,43 +365,6 @@ static PyObject* NodeBase__remove_from_list(
Py_RETURN_NONE;
}
static PyObject* NodeBase__replace_input_with(
PyObject* self,
PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 2) {
PyErr_SetString(
PyExc_TypeError,
"_replace_input_with() requires exactly 2 arguments (old_input, new_input)");
return nullptr;
}
PyObject* old_input = args[0];
PyObject* new_input = args[1];
auto replace_fn = [old_input, new_input](PyObject* maybe_node) {
if (maybe_node == old_input) {
return Py_NewRef(new_input);
}
return Py_NewRef(maybe_node);
};
auto node = reinterpret_cast<NodeBase*>(self);
try {
THPObjectPtr new_args(map_aggregate(node->_args, replace_fn));
if (!new_args) {
return nullptr;
}
THPObjectPtr new_kwargs(map_aggregate(node->_kwargs, replace_fn));
if (!new_kwargs) {
return nullptr;
}
PyObject* update_args[2] = {new_args.get(), new_kwargs.get()};
return NodeBase__update_args_kwargs(self, update_args, 2);
} catch (const PythonError& e) {
return nullptr;
}
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
@ -551,10 +514,6 @@ static PyMethodDef NodeBase_methods[] = {
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_replace_input_with",
(PyCFunction)(void*)(NodeBase__replace_input_with),
METH_FASTCALL,
"Internal method: replace occurrences of one input Node with another."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,

View File

@ -1863,6 +1863,8 @@ def empty( # type: ignore[misc]
device: _device | None = None,
) -> torch.Tensor:
r"""
empty(*size, *, dtype=None, device=None) -> Tensor
Similar to :func:`torch.empty()`. The returned tensor can be used by
:func:`torch._distributed._symmetric_memory.rendezvous()` to establish a
symmetric memory tensor among participating processes.
@ -1952,7 +1954,7 @@ def set_backend(name: Literal["NVSHMEM", "CUDA", "NCCL"]) -> None:
Args:
backend (str): the backend for symmetric memory allocation. Currently,
only `"NVSHMEM"`, `"CUDA"`, `"NCCL"` are supported.
only "NVSHMEM", "CUDA", "NCCL" are supported.
"""
_SymmetricMemory.set_backend(name)
@ -1963,7 +1965,8 @@ def get_backend(device: _device) -> str | None:
found, return None.
Args:
device (`torch.device` or str): the device for which to get the backend.
device (class:`torch.device` or str): the device for which to get the
backend.
"""
return _SymmetricMemory.get_backend(torch.device(device))
@ -1971,10 +1974,9 @@ def get_backend(device: _device) -> str | None:
def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def]
r"""
Get the MemPool allocator for symmetric memory for a given device.
Args:
device (`torch.device` or str): the device for which to get the MemPool
allocator.
device (class:`torch.device` or str): the device for which to get the
MemPool allocator.
"""
return _SymmetricMemory.get_mempool_allocator(torch.device(device))

View File

@ -658,7 +658,7 @@ class Node(_NodeBase):
def replace_all_uses_with(
self,
replace_with: "Node",
delete_user_cb: Optional[Callable[["Node"], bool]] = None,
delete_user_cb: Callable[["Node"], bool] = lambda user: True,
*,
propagate_meta: bool = False,
) -> list["Node"]:
@ -686,18 +686,32 @@ class Node(_NodeBase):
)
for k, v in self.meta.items():
replace_with.meta[k] = v
to_process = [*self.users]
replace_hooks = getattr(self.graph.owning_module, "_replace_hooks", None)
result = []
to_process = list(self.users)
skipped = []
m = self.graph.owning_module
for use_node in to_process:
if delete_user_cb is not None and not delete_user_cb(use_node):
if not delete_user_cb(use_node):
skipped.append(use_node)
continue
result.append(use_node)
if replace_hooks:
for replace_hook in replace_hooks:
def maybe_replace_node(n: Node) -> Node:
if n == self:
return replace_with
else:
return n
if getattr(m, "_replace_hooks", None):
for replace_hook in m._replace_hooks:
replace_hook(old=self, new=replace_with.name, user=use_node)
use_node._replace_input_with(self, replace_with)
return result
new_args = _fx_map_arg(use_node.args, maybe_replace_node)
new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
use_node._update_args_kwargs(new_args, new_kwargs)
assert len(self.users) - len(skipped) == 0
return [n for n in to_process if n not in skipped]
@compatibility(is_backward_compatible=False)
def is_impure(self, impure_random: bool = True) -> bool:
@ -828,12 +842,19 @@ class Node(_NodeBase):
new_input (Node): The new input node to replace ``old_input``.
"""
def maybe_replace_node(n: Node) -> Node:
return new_input if n == old_input else n
m = self.graph.owning_module
if getattr(m, "_replace_hooks", None):
for replace_hook in m._replace_hooks:
replace_hook(old=old_input, new=new_input.name, user=self)
self._replace_input_with(old_input, new_input)
new_args = _fx_map_arg(self.args, maybe_replace_node)
new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
self._update_args_kwargs(new_args, new_kwargs)
def _rename(self, candidate: str) -> None:
if candidate == self.name: