mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchfuzz] add matmuls (#164284)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164284 Approved by: https://github.com/pianpwk ghstack dependencies: #164034, #164209, #164211, #164210, #164397
This commit is contained in:
committed by
PyTorch MergeBot
parent
144378615a
commit
0fbe3f19c7
@ -117,7 +117,13 @@ class DefaultFuzzTemplate(FuzzTemplate):
|
||||
"torch.div",
|
||||
"torch.Tensor.view",
|
||||
"torch.reshape",
|
||||
"torch.flattentorch.squeezetorch.unsqueeze",
|
||||
"torch.flatten",
|
||||
"torch.squeeze",
|
||||
"torch.unsqueeze",
|
||||
"torch.mm",
|
||||
"torch.addmm",
|
||||
"torch.bmm",
|
||||
"torch.matmul",
|
||||
],
|
||||
check=EagerVsFullGraphDynamicCompileCheck(),
|
||||
)
|
||||
@ -198,6 +204,10 @@ class DTensorFuzzTemplate(FuzzTemplate):
|
||||
"torch.sub",
|
||||
"torch.mul",
|
||||
"torch.div",
|
||||
"torch.mm",
|
||||
"torch.addmm",
|
||||
"torch.bmm",
|
||||
"torch.matmul",
|
||||
],
|
||||
check=EagerVsFullGraphDynamicCompileCheck(),
|
||||
)
|
||||
|
@ -156,6 +156,7 @@ def fuzz_and_execute(
|
||||
traceback.print_exc()
|
||||
error_message = str(e)
|
||||
print(f"Error: {error_message}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -11,6 +11,12 @@ from torchfuzz.operators.layout import (
|
||||
UnsqueezeOperator,
|
||||
ViewOperator,
|
||||
)
|
||||
from torchfuzz.operators.matrix_multiply import (
|
||||
AddmmOperator,
|
||||
BmmOperator,
|
||||
MatmulOperator,
|
||||
MMOperator,
|
||||
)
|
||||
from torchfuzz.operators.registry import get_operator, list_operators, register_operator
|
||||
from torchfuzz.operators.scalar_pointwise import (
|
||||
ScalarAddOperator,
|
||||
@ -48,6 +54,10 @@ __all__ = [
|
||||
"FlattenOperator",
|
||||
"SqueezeOperator",
|
||||
"UnsqueezeOperator",
|
||||
"MMOperator",
|
||||
"AddmmOperator",
|
||||
"BmmOperator",
|
||||
"MatmulOperator",
|
||||
"get_operator",
|
||||
"register_operator",
|
||||
"list_operators",
|
||||
|
@ -0,0 +1,430 @@
|
||||
"""Matrix multiplication operator implementations."""
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from torchfuzz.operators.base import Operator
|
||||
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
|
||||
|
||||
|
||||
# Type promotion imports removed since we now use explicit casting in codegen
|
||||
|
||||
|
||||
class MatrixMultiplyOperator(Operator):
|
||||
"""Base class for matrix multiplication operations."""
|
||||
|
||||
def __init__(self, name: str, torch_op: str):
|
||||
super().__init__(name)
|
||||
self._torch_op = torch_op
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> Optional[str]:
|
||||
"""Return the torch operation name."""
|
||||
return self._torch_op
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Matrix multiply operations can produce float/complex tensors of dimension >= 2."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
return False
|
||||
|
||||
# Must have at least 2 dimensions for matrix multiplication
|
||||
if len(output_spec.size) < 2:
|
||||
return False
|
||||
|
||||
# Matrix multiply doesn't work with bool or integer types for gradients
|
||||
if output_spec.dtype in [
|
||||
torch.bool,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_compatible_dtype(self, output_dtype):
|
||||
"""Get a compatible dtype for matrix multiplication."""
|
||||
# For matrix multiplication, we need to be flexible with input dtypes
|
||||
# since earlier operations may have performed type promotion.
|
||||
# We'll let the fuzzer generate whatever dtypes result from earlier operations
|
||||
# and rely on the operation graph to ensure compatibility.
|
||||
# Return the output dtype as a starting point, but this may be overridden
|
||||
# by the actual tensor specs generated by the fuzzer.
|
||||
return [output_dtype, output_dtype]
|
||||
|
||||
|
||||
class MMOperator(MatrixMultiplyOperator):
|
||||
"""Operator for matrix multiplication (torch.mm)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("mm", "torch.mm")
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""MM requires exactly 2D tensors."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
return False
|
||||
|
||||
# Must have exactly 2 dimensions for torch.mm
|
||||
if len(output_spec.size) != 2:
|
||||
return False
|
||||
|
||||
# Matrix multiply doesn't work with bool or integer types for gradients
|
||||
if output_spec.dtype in [
|
||||
torch.bool,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
|
||||
"""Generate input specs for matrix multiplication."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
raise ValueError("MMOperator can only produce TensorSpec outputs")
|
||||
|
||||
if len(output_spec.size) != 2:
|
||||
raise ValueError("torch.mm requires 2D tensors")
|
||||
|
||||
m, n = output_spec.size
|
||||
# Choose a random inner dimension k
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
# First tensor: [m, k]
|
||||
input1_spec = TensorSpec(
|
||||
size=(m, k),
|
||||
stride=(k, 1), # Contiguous stride
|
||||
dtype=dtypes[0],
|
||||
)
|
||||
|
||||
# Second tensor: [k, n]
|
||||
input2_spec = TensorSpec(
|
||||
size=(k, n),
|
||||
stride=(n, 1), # Contiguous stride
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
|
||||
return [input1_spec, input2_spec]
|
||||
|
||||
def codegen(
|
||||
self, output_name: str, input_names: list[str], output_spec: Spec
|
||||
) -> str:
|
||||
"""Generate code for matrix multiplication."""
|
||||
if len(input_names) != 2:
|
||||
raise ValueError("torch.mm requires exactly 2 inputs")
|
||||
|
||||
# Get target dtype
|
||||
if isinstance(output_spec, TensorSpec):
|
||||
target_dtype_str = f"torch.{output_spec.dtype}".replace(
|
||||
"torch.torch.", "torch."
|
||||
)
|
||||
# Cast inputs to ensure compatible types
|
||||
return (
|
||||
f"{output_name} = torch.mm("
|
||||
f"{input_names[0]}.to({target_dtype_str}), "
|
||||
f"{input_names[1]}.to({target_dtype_str}))"
|
||||
)
|
||||
else:
|
||||
return f"{output_name} = torch.mm({input_names[0]}, {input_names[1]})"
|
||||
|
||||
|
||||
class AddmmOperator(MatrixMultiplyOperator):
|
||||
"""Operator for additive matrix multiplication (torch.addmm)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("addmm", "torch.addmm")
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Addmm requires exactly 2D tensors."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
return False
|
||||
|
||||
# Must have exactly 2 dimensions for torch.addmm
|
||||
if len(output_spec.size) != 2:
|
||||
return False
|
||||
|
||||
# Matrix multiply doesn't work with bool or integer types for gradients
|
||||
if output_spec.dtype in [
|
||||
torch.bool,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
|
||||
"""Generate input specs for additive matrix multiplication."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
raise ValueError("AddmmOperator can only produce TensorSpec outputs")
|
||||
|
||||
if len(output_spec.size) != 2:
|
||||
raise ValueError("torch.addmm requires 2D output tensor")
|
||||
|
||||
m, n = output_spec.size
|
||||
# Choose a random inner dimension k
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
# Bias tensor: [m, n] (same shape as output)
|
||||
bias_spec = TensorSpec(
|
||||
size=(m, n),
|
||||
stride=(n, 1), # Contiguous stride
|
||||
dtype=dtypes[0],
|
||||
)
|
||||
|
||||
# First matrix: [m, k]
|
||||
input1_spec = TensorSpec(
|
||||
size=(m, k),
|
||||
stride=(k, 1), # Contiguous stride
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
|
||||
# Second matrix: [k, n]
|
||||
input2_spec = TensorSpec(
|
||||
size=(k, n),
|
||||
stride=(n, 1), # Contiguous stride
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
|
||||
return [bias_spec, input1_spec, input2_spec]
|
||||
|
||||
def codegen(
|
||||
self, output_name: str, input_names: list[str], output_spec: Spec
|
||||
) -> str:
|
||||
"""Generate code for additive matrix multiplication."""
|
||||
if len(input_names) != 3:
|
||||
raise ValueError("torch.addmm requires exactly 3 inputs")
|
||||
|
||||
# Get target dtype
|
||||
if isinstance(output_spec, TensorSpec):
|
||||
target_dtype_str = f"torch.{output_spec.dtype}".replace(
|
||||
"torch.torch.", "torch."
|
||||
)
|
||||
# Cast inputs to ensure compatible types
|
||||
return (
|
||||
f"{output_name} = torch.addmm("
|
||||
f"{input_names[0]}.to({target_dtype_str}), "
|
||||
f"{input_names[1]}.to({target_dtype_str}), "
|
||||
f"{input_names[2]}.to({target_dtype_str}))"
|
||||
)
|
||||
else:
|
||||
return f"{output_name} = torch.addmm({input_names[0]}, {input_names[1]}, {input_names[2]})"
|
||||
|
||||
|
||||
class BmmOperator(MatrixMultiplyOperator):
|
||||
"""Operator for batch matrix multiplication (torch.bmm)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("bmm", "torch.bmm")
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Batch matrix multiply requires 3D tensors."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
return False
|
||||
|
||||
# Must have exactly 3 dimensions for batch matrix multiplication
|
||||
if len(output_spec.size) != 3:
|
||||
return False
|
||||
|
||||
# Matrix multiply doesn't work with bool or integer types for gradients
|
||||
if output_spec.dtype in [
|
||||
torch.bool,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
|
||||
"""Generate input specs for batch matrix multiplication."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
raise ValueError("BmmOperator can only produce TensorSpec outputs")
|
||||
|
||||
if len(output_spec.size) != 3:
|
||||
raise ValueError("torch.bmm requires 3D tensors")
|
||||
|
||||
b, m, n = output_spec.size
|
||||
# Choose a random inner dimension k
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
# First tensor: [b, m, k]
|
||||
input1_spec = TensorSpec(
|
||||
size=(b, m, k),
|
||||
stride=(m * k, k, 1), # Contiguous stride
|
||||
dtype=dtypes[0],
|
||||
)
|
||||
|
||||
# Second tensor: [b, k, n]
|
||||
input2_spec = TensorSpec(
|
||||
size=(b, k, n),
|
||||
stride=(k * n, n, 1), # Contiguous stride
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
|
||||
return [input1_spec, input2_spec]
|
||||
|
||||
def codegen(
|
||||
self, output_name: str, input_names: list[str], output_spec: Spec
|
||||
) -> str:
|
||||
"""Generate code for batch matrix multiplication."""
|
||||
if len(input_names) != 2:
|
||||
raise ValueError("torch.bmm requires exactly 2 inputs")
|
||||
|
||||
# Get target dtype
|
||||
if isinstance(output_spec, TensorSpec):
|
||||
target_dtype_str = f"torch.{output_spec.dtype}".replace(
|
||||
"torch.torch.", "torch."
|
||||
)
|
||||
# Cast inputs to ensure compatible types
|
||||
return (
|
||||
f"{output_name} = torch.bmm("
|
||||
f"{input_names[0]}.to({target_dtype_str}), "
|
||||
f"{input_names[1]}.to({target_dtype_str}))"
|
||||
)
|
||||
else:
|
||||
return f"{output_name} = torch.bmm({input_names[0]}, {input_names[1]})"
|
||||
|
||||
|
||||
class MatmulOperator(MatrixMultiplyOperator):
|
||||
"""Operator for general matrix multiplication (torch.matmul)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("matmul", "torch.matmul")
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Matmul can handle various tensor dimensions >= 1."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
return False
|
||||
|
||||
# Must have at least 1 dimension
|
||||
if len(output_spec.size) < 1:
|
||||
return False
|
||||
|
||||
# Matrix multiply doesn't work with bool or integer types for gradients
|
||||
if output_spec.dtype in [
|
||||
torch.bool,
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
|
||||
"""Generate input specs for general matrix multiplication."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
raise ValueError("MatmulOperator can only produce TensorSpec outputs")
|
||||
|
||||
output_size = output_spec.size
|
||||
output_dims = len(output_size)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
if output_dims == 1:
|
||||
# Matrix-vector multiplication: (n,) = (k,) @ (k, n) or (n,) = (n, k) @ (k,)
|
||||
n = output_size[0]
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Randomly choose between two valid patterns
|
||||
if random.choice([True, False]):
|
||||
# Pattern 1: (n,) = (k,) @ (k, n)
|
||||
input1_spec = TensorSpec(size=(k,), stride=(1,), dtype=dtypes[0])
|
||||
input2_spec = TensorSpec(
|
||||
size=(k, n),
|
||||
stride=(n, 1),
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
else:
|
||||
# Pattern 2: (n,) = (n, k) @ (k,)
|
||||
input1_spec = TensorSpec(size=(n, k), stride=(k, 1), dtype=dtypes[0])
|
||||
input2_spec = TensorSpec(
|
||||
size=(k,),
|
||||
stride=(1,),
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
|
||||
elif output_dims == 2:
|
||||
# Matrix multiplication: (m, n) = (m, k) @ (k, n)
|
||||
m, n = output_size
|
||||
k = random.randint(1, 16)
|
||||
|
||||
input1_spec = TensorSpec(size=(m, k), stride=(k, 1), dtype=dtypes[0])
|
||||
input2_spec = TensorSpec(
|
||||
size=(k, n),
|
||||
stride=(n, 1),
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
|
||||
else:
|
||||
# Batched matrix multiplication: (..., m, n) = (..., m, k) @ (..., k, n)
|
||||
*batch_dims, m, n = output_size
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Calculate strides for contiguous tensors
|
||||
input1_size = tuple(batch_dims + [m, k])
|
||||
input2_size = tuple(batch_dims + [k, n])
|
||||
|
||||
# Contiguous strides
|
||||
input1_stride = [1]
|
||||
for i in reversed(range(len(input1_size) - 1)):
|
||||
input1_stride.append(input1_stride[-1] * input1_size[i + 1])
|
||||
input1_stride = tuple(reversed(input1_stride))
|
||||
|
||||
input2_stride = [1]
|
||||
for i in reversed(range(len(input2_size) - 1)):
|
||||
input2_stride.append(input2_stride[-1] * input2_size[i + 1])
|
||||
input2_stride = tuple(reversed(input2_stride))
|
||||
|
||||
input1_spec = TensorSpec(
|
||||
size=input1_size, stride=input1_stride, dtype=dtypes[0]
|
||||
)
|
||||
input2_spec = TensorSpec(
|
||||
size=input2_size,
|
||||
stride=input2_stride,
|
||||
dtype=dtypes[1] if len(dtypes) > 1 else dtypes[0],
|
||||
)
|
||||
|
||||
return [input1_spec, input2_spec]
|
||||
|
||||
def codegen(
|
||||
self, output_name: str, input_names: list[str], output_spec: Spec
|
||||
) -> str:
|
||||
"""Generate code for general matrix multiplication."""
|
||||
if len(input_names) != 2:
|
||||
raise ValueError("torch.matmul requires exactly 2 inputs")
|
||||
|
||||
# Get target dtype
|
||||
if isinstance(output_spec, TensorSpec):
|
||||
target_dtype_str = f"torch.{output_spec.dtype}".replace(
|
||||
"torch.torch.", "torch."
|
||||
)
|
||||
# Cast inputs to ensure compatible types
|
||||
return (
|
||||
f"{output_name} = torch.matmul("
|
||||
f"{input_names[0]}.to({target_dtype_str}), "
|
||||
f"{input_names[1]}.to({target_dtype_str}))"
|
||||
)
|
||||
else:
|
||||
return f"{output_name} = torch.matmul({input_names[0]}, {input_names[1]})"
|
@ -14,6 +14,12 @@ from torchfuzz.operators.layout import (
|
||||
ViewOperator,
|
||||
)
|
||||
from torchfuzz.operators.masked_select import MaskedSelectOperator
|
||||
from torchfuzz.operators.matrix_multiply import (
|
||||
AddmmOperator,
|
||||
BmmOperator,
|
||||
MatmulOperator,
|
||||
MMOperator,
|
||||
)
|
||||
from torchfuzz.operators.nonzero import NonzeroOperator
|
||||
from torchfuzz.operators.scalar_pointwise import (
|
||||
ScalarAddOperator,
|
||||
@ -69,6 +75,12 @@ class OperatorRegistry:
|
||||
self.register(SqueezeOperator())
|
||||
self.register(UnsqueezeOperator())
|
||||
|
||||
# Matrix multiplication operators
|
||||
self.register(MMOperator())
|
||||
self.register(AddmmOperator())
|
||||
self.register(BmmOperator())
|
||||
self.register(MatmulOperator())
|
||||
|
||||
def register(self, operator: Operator):
|
||||
"""Register an operator in the registry."""
|
||||
self._operators[operator.name] = operator
|
||||
|
@ -68,7 +68,7 @@ def fuzz_torch_tensor_type(template: str = "default") -> torch.dtype:
|
||||
return random.choice(tensor_dtypes)
|
||||
|
||||
|
||||
def fuzz_tensor_size(max_dims: int = 6, max_size_per_dim: int = 30) -> tuple[int, ...]:
|
||||
def fuzz_tensor_size(max_dims: int = 3, max_size_per_dim: int = 30) -> tuple[int, ...]:
|
||||
"""
|
||||
Fuzzes PyTorch tensor sizes by generating random tensor shapes.
|
||||
|
||||
|
Reference in New Issue
Block a user