[torchfuzz] add matmuls (#164284)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164284
Approved by: https://github.com/pianpwk
ghstack dependencies: #164034, #164209, #164211, #164210, #164397
This commit is contained in:
bobrenjc93
2025-10-01 15:59:15 -07:00
committed by PyTorch MergeBot
parent 144378615a
commit 0fbe3f19c7
6 changed files with 465 additions and 2 deletions

View File

@ -117,7 +117,13 @@ class DefaultFuzzTemplate(FuzzTemplate):
"torch.div",
"torch.Tensor.view",
"torch.reshape",
"torch.flattentorch.squeezetorch.unsqueeze",
"torch.flatten",
"torch.squeeze",
"torch.unsqueeze",
"torch.mm",
"torch.addmm",
"torch.bmm",
"torch.matmul",
],
check=EagerVsFullGraphDynamicCompileCheck(),
)
@ -198,6 +204,10 @@ class DTensorFuzzTemplate(FuzzTemplate):
"torch.sub",
"torch.mul",
"torch.div",
"torch.mm",
"torch.addmm",
"torch.bmm",
"torch.matmul",
],
check=EagerVsFullGraphDynamicCompileCheck(),
)

View File

@ -156,6 +156,7 @@ def fuzz_and_execute(
traceback.print_exc()
error_message = str(e)
print(f"Error: {error_message}")
sys.exit(1)
if __name__ == "__main__":

View File

@ -11,6 +11,12 @@ from torchfuzz.operators.layout import (
UnsqueezeOperator,
ViewOperator,
)
from torchfuzz.operators.matrix_multiply import (
AddmmOperator,
BmmOperator,
MatmulOperator,
MMOperator,
)
from torchfuzz.operators.registry import get_operator, list_operators, register_operator
from torchfuzz.operators.scalar_pointwise import (
ScalarAddOperator,
@ -48,6 +54,10 @@ __all__ = [
"FlattenOperator",
"SqueezeOperator",
"UnsqueezeOperator",
"MMOperator",
"AddmmOperator",
"BmmOperator",
"MatmulOperator",
"get_operator",
"register_operator",
"list_operators",

View File

@ -0,0 +1,430 @@
"""Matrix multiplication operator implementations."""
import random
from typing import Optional
import torch
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
# Type promotion imports removed since we now use explicit casting in codegen
class MatrixMultiplyOperator(Operator):
"""Base class for matrix multiplication operations."""
def __init__(self, name: str, torch_op: str):
super().__init__(name)
self._torch_op = torch_op
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return self._torch_op
def can_produce(self, output_spec: Spec) -> bool:
"""Matrix multiply operations can produce float/complex tensors of dimension >= 2."""
if not isinstance(output_spec, TensorSpec):
return False
# Must have at least 2 dimensions for matrix multiplication
if len(output_spec.size) < 2:
return False
# Matrix multiply doesn't work with bool or integer types for gradients
if output_spec.dtype in [
torch.bool,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]:
return False
return True
def _get_compatible_dtype(self, output_dtype):
"""Get a compatible dtype for matrix multiplication."""
# For matrix multiplication, we need to be flexible with input dtypes
# since earlier operations may have performed type promotion.
# We'll let the fuzzer generate whatever dtypes result from earlier operations
# and rely on the operation graph to ensure compatibility.
# Return the output dtype as a starting point, but this may be overridden
# by the actual tensor specs generated by the fuzzer.
return [output_dtype, output_dtype]
class MMOperator(MatrixMultiplyOperator):
"""Operator for matrix multiplication (torch.mm)."""
def __init__(self):
super().__init__("mm", "torch.mm")
def can_produce(self, output_spec: Spec) -> bool:
"""MM requires exactly 2D tensors."""
if not isinstance(output_spec, TensorSpec):
return False
# Must have exactly 2 dimensions for torch.mm
if len(output_spec.size) != 2:
return False
# Matrix multiply doesn't work with bool or integer types for gradients
if output_spec.dtype in [
torch.bool,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]:
return False
return True
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for matrix multiplication."""
if not isinstance(output_spec, TensorSpec):
raise ValueError("MMOperator can only produce TensorSpec outputs")
if len(output_spec.size) != 2:
raise ValueError("torch.mm requires 2D tensors")
m, n = output_spec.size
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# First tensor: [m, k]
input1_spec = TensorSpec(
size=(m, k),
stride=(k, 1), # Contiguous stride
dtype=dtypes[0],
)
# Second tensor: [k, n]
input2_spec = TensorSpec(
size=(k, n),
stride=(n, 1), # Contiguous stride
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
return [input1_spec, input2_spec]
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for matrix multiplication."""
if len(input_names) != 2:
raise ValueError("torch.mm requires exactly 2 inputs")
# Get target dtype
if isinstance(output_spec, TensorSpec):
target_dtype_str = f"torch.{output_spec.dtype}".replace(
"torch.torch.", "torch."
)
# Cast inputs to ensure compatible types
return (
f"{output_name} = torch.mm("
f"{input_names[0]}.to({target_dtype_str}), "
f"{input_names[1]}.to({target_dtype_str}))"
)
else:
return f"{output_name} = torch.mm({input_names[0]}, {input_names[1]})"
class AddmmOperator(MatrixMultiplyOperator):
"""Operator for additive matrix multiplication (torch.addmm)."""
def __init__(self):
super().__init__("addmm", "torch.addmm")
def can_produce(self, output_spec: Spec) -> bool:
"""Addmm requires exactly 2D tensors."""
if not isinstance(output_spec, TensorSpec):
return False
# Must have exactly 2 dimensions for torch.addmm
if len(output_spec.size) != 2:
return False
# Matrix multiply doesn't work with bool or integer types for gradients
if output_spec.dtype in [
torch.bool,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]:
return False
return True
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for additive matrix multiplication."""
if not isinstance(output_spec, TensorSpec):
raise ValueError("AddmmOperator can only produce TensorSpec outputs")
if len(output_spec.size) != 2:
raise ValueError("torch.addmm requires 2D output tensor")
m, n = output_spec.size
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# Bias tensor: [m, n] (same shape as output)
bias_spec = TensorSpec(
size=(m, n),
stride=(n, 1), # Contiguous stride
dtype=dtypes[0],
)
# First matrix: [m, k]
input1_spec = TensorSpec(
size=(m, k),
stride=(k, 1), # Contiguous stride
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
# Second matrix: [k, n]
input2_spec = TensorSpec(
size=(k, n),
stride=(n, 1), # Contiguous stride
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
return [bias_spec, input1_spec, input2_spec]
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for additive matrix multiplication."""
if len(input_names) != 3:
raise ValueError("torch.addmm requires exactly 3 inputs")
# Get target dtype
if isinstance(output_spec, TensorSpec):
target_dtype_str = f"torch.{output_spec.dtype}".replace(
"torch.torch.", "torch."
)
# Cast inputs to ensure compatible types
return (
f"{output_name} = torch.addmm("
f"{input_names[0]}.to({target_dtype_str}), "
f"{input_names[1]}.to({target_dtype_str}), "
f"{input_names[2]}.to({target_dtype_str}))"
)
else:
return f"{output_name} = torch.addmm({input_names[0]}, {input_names[1]}, {input_names[2]})"
class BmmOperator(MatrixMultiplyOperator):
"""Operator for batch matrix multiplication (torch.bmm)."""
def __init__(self):
super().__init__("bmm", "torch.bmm")
def can_produce(self, output_spec: Spec) -> bool:
"""Batch matrix multiply requires 3D tensors."""
if not isinstance(output_spec, TensorSpec):
return False
# Must have exactly 3 dimensions for batch matrix multiplication
if len(output_spec.size) != 3:
return False
# Matrix multiply doesn't work with bool or integer types for gradients
if output_spec.dtype in [
torch.bool,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]:
return False
return True
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for batch matrix multiplication."""
if not isinstance(output_spec, TensorSpec):
raise ValueError("BmmOperator can only produce TensorSpec outputs")
if len(output_spec.size) != 3:
raise ValueError("torch.bmm requires 3D tensors")
b, m, n = output_spec.size
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# First tensor: [b, m, k]
input1_spec = TensorSpec(
size=(b, m, k),
stride=(m * k, k, 1), # Contiguous stride
dtype=dtypes[0],
)
# Second tensor: [b, k, n]
input2_spec = TensorSpec(
size=(b, k, n),
stride=(k * n, n, 1), # Contiguous stride
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
return [input1_spec, input2_spec]
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for batch matrix multiplication."""
if len(input_names) != 2:
raise ValueError("torch.bmm requires exactly 2 inputs")
# Get target dtype
if isinstance(output_spec, TensorSpec):
target_dtype_str = f"torch.{output_spec.dtype}".replace(
"torch.torch.", "torch."
)
# Cast inputs to ensure compatible types
return (
f"{output_name} = torch.bmm("
f"{input_names[0]}.to({target_dtype_str}), "
f"{input_names[1]}.to({target_dtype_str}))"
)
else:
return f"{output_name} = torch.bmm({input_names[0]}, {input_names[1]})"
class MatmulOperator(MatrixMultiplyOperator):
"""Operator for general matrix multiplication (torch.matmul)."""
def __init__(self):
super().__init__("matmul", "torch.matmul")
def can_produce(self, output_spec: Spec) -> bool:
"""Matmul can handle various tensor dimensions >= 1."""
if not isinstance(output_spec, TensorSpec):
return False
# Must have at least 1 dimension
if len(output_spec.size) < 1:
return False
# Matrix multiply doesn't work with bool or integer types for gradients
if output_spec.dtype in [
torch.bool,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]:
return False
return True
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for general matrix multiplication."""
if not isinstance(output_spec, TensorSpec):
raise ValueError("MatmulOperator can only produce TensorSpec outputs")
output_size = output_spec.size
output_dims = len(output_size)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
if output_dims == 1:
# Matrix-vector multiplication: (n,) = (k,) @ (k, n) or (n,) = (n, k) @ (k,)
n = output_size[0]
k = random.randint(1, 16)
# Randomly choose between two valid patterns
if random.choice([True, False]):
# Pattern 1: (n,) = (k,) @ (k, n)
input1_spec = TensorSpec(size=(k,), stride=(1,), dtype=dtypes[0])
input2_spec = TensorSpec(
size=(k, n),
stride=(n, 1),
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
else:
# Pattern 2: (n,) = (n, k) @ (k,)
input1_spec = TensorSpec(size=(n, k), stride=(k, 1), dtype=dtypes[0])
input2_spec = TensorSpec(
size=(k,),
stride=(1,),
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
elif output_dims == 2:
# Matrix multiplication: (m, n) = (m, k) @ (k, n)
m, n = output_size
k = random.randint(1, 16)
input1_spec = TensorSpec(size=(m, k), stride=(k, 1), dtype=dtypes[0])
input2_spec = TensorSpec(
size=(k, n),
stride=(n, 1),
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
else:
# Batched matrix multiplication: (..., m, n) = (..., m, k) @ (..., k, n)
*batch_dims, m, n = output_size
k = random.randint(1, 16)
# Calculate strides for contiguous tensors
input1_size = tuple(batch_dims + [m, k])
input2_size = tuple(batch_dims + [k, n])
# Contiguous strides
input1_stride = [1]
for i in reversed(range(len(input1_size) - 1)):
input1_stride.append(input1_stride[-1] * input1_size[i + 1])
input1_stride = tuple(reversed(input1_stride))
input2_stride = [1]
for i in reversed(range(len(input2_size) - 1)):
input2_stride.append(input2_stride[-1] * input2_size[i + 1])
input2_stride = tuple(reversed(input2_stride))
input1_spec = TensorSpec(
size=input1_size, stride=input1_stride, dtype=dtypes[0]
)
input2_spec = TensorSpec(
size=input2_size,
stride=input2_stride,
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
)
return [input1_spec, input2_spec]
def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for general matrix multiplication."""
if len(input_names) != 2:
raise ValueError("torch.matmul requires exactly 2 inputs")
# Get target dtype
if isinstance(output_spec, TensorSpec):
target_dtype_str = f"torch.{output_spec.dtype}".replace(
"torch.torch.", "torch."
)
# Cast inputs to ensure compatible types
return (
f"{output_name} = torch.matmul("
f"{input_names[0]}.to({target_dtype_str}), "
f"{input_names[1]}.to({target_dtype_str}))"
)
else:
return f"{output_name} = torch.matmul({input_names[0]}, {input_names[1]})"

View File

@ -14,6 +14,12 @@ from torchfuzz.operators.layout import (
ViewOperator,
)
from torchfuzz.operators.masked_select import MaskedSelectOperator
from torchfuzz.operators.matrix_multiply import (
AddmmOperator,
BmmOperator,
MatmulOperator,
MMOperator,
)
from torchfuzz.operators.nonzero import NonzeroOperator
from torchfuzz.operators.scalar_pointwise import (
ScalarAddOperator,
@ -69,6 +75,12 @@ class OperatorRegistry:
self.register(SqueezeOperator())
self.register(UnsqueezeOperator())
# Matrix multiplication operators
self.register(MMOperator())
self.register(AddmmOperator())
self.register(BmmOperator())
self.register(MatmulOperator())
def register(self, operator: Operator):
"""Register an operator in the registry."""
self._operators[operator.name] = operator

View File

@ -68,7 +68,7 @@ def fuzz_torch_tensor_type(template: str = "default") -> torch.dtype:
return random.choice(tensor_dtypes)
def fuzz_tensor_size(max_dims: int = 6, max_size_per_dim: int = 30) -> tuple[int, ...]:
def fuzz_tensor_size(max_dims: int = 3, max_size_per_dim: int = 30) -> tuple[int, ...]:
"""
Fuzzes PyTorch tensor sizes by generating random tensor shapes.