Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[torchfuzz] add norm operators (#164514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164514
Approved by: https://github.com/pianpwk
ghstack dependencies: #164432, #164434
Committed by: PyTorch MergeBot
Parent: 5bb8f04d3e
Commit: 3db2164341
@@ -111,25 +111,41 @@ class DefaultFuzzTemplate(FuzzTemplate):
        super().__init__(
            supported_ops=[
                # Basic arithmetic operations
                "torch.add",
                "torch.sub",
                "torch.mul",
                "torch.div",
                # Tensor shape operations
                "torch.Tensor.view",
                "torch.reshape",
                "torch.flatten",
                "torch.squeeze",
                "torch.unsqueeze",
                # Matrix operations
                "torch.mm",
                "torch.addmm",
                "torch.bmm",
                "torch.matmul",
                # Neural network operations
                "torch.nn.functional.embedding",
                "torch.nn.functional.linear",
                # Activation functions
                "torch.nn.functional.relu",
                "torch.nn.functional.leaky_relu",
                "torch.nn.functional.elu",
                "torch.nn.functional.gelu",
                "torch.nn.functional.silu",
                "torch.sigmoid",
                "torch.tanh",
                "torch.nn.functional.softmax",
                "torch.nn.functional.dropout",
                # Normalization layers
                "torch.nn.functional.layer_norm",
                "torch.nn.functional.rms_norm",
                "torch.nn.functional.batch_norm",
                "torch.nn.functional.group_norm",
                # Regularization
                "torch.nn.functional.dropout",
            ],
            check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
        )

@@ -9,6 +9,16 @@ from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec


def is_float_dtype(dtype: torch.dtype) -> bool:
    """Check if dtype is a floating point type."""
    return dtype in [
        torch.float32,
        torch.float64,
        torch.float16,
        torch.bfloat16,
    ]


class EmbeddingOperator(Operator):
    """Operator for torch.nn.functional.embedding."""

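As a quick sanity check, the new helper behaves as expected for common dtypes (a minimal sketch for illustration, not part of the commit):

    import torch
    from torchfuzz.operators.nn_functional import is_float_dtype

    assert is_float_dtype(torch.float16)       # half precision counts as float
    assert is_float_dtype(torch.bfloat16)      # so does bfloat16
    assert not is_float_dtype(torch.int64)     # integer dtypes do not
    assert not is_float_dtype(torch.bool)      # neither do booleans
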
@@ -28,12 +38,7 @@ class EmbeddingOperator(Operator):
         if len(output_spec.size) == 0:
             return False
         # Embedding outputs are typically float tensors
-        return output_spec.dtype in [
-            torch.float32,
-            torch.float64,
-            torch.float16,
-            torch.bfloat16,
-        ]
+        return is_float_dtype(output_spec.dtype)

     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
         """Generate input specs for embedding operation.
@@ -114,12 +119,7 @@ class LinearOperator(Operator):
         # Linear needs at least 1 dimension (output features)
         if len(output_spec.size) == 0:
             return False
-        return output_spec.dtype in [
-            torch.float32,
-            torch.float64,
-            torch.float16,
-            torch.bfloat16,
-        ]
+        return is_float_dtype(output_spec.dtype)

     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
         """Generate input specs for linear operation.
@@ -214,12 +214,7 @@ class ReLUOperator(Operator):
         """ReLU can produce tensor outputs with floating point dtypes."""
         if not isinstance(output_spec, TensorSpec):
             return False
-        return output_spec.dtype in [
-            torch.float32,
-            torch.float64,
-            torch.float16,
-            torch.bfloat16,
-        ]
+        return is_float_dtype(output_spec.dtype)

     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
         """Generate input specs for ReLU operation.
@@ -265,12 +260,7 @@ class SoftmaxOperator(Operator):
         # Softmax needs at least 1 dimension to apply softmax along a dimension
         if len(output_spec.size) == 0:
             return False
-        return output_spec.dtype in [
-            torch.float32,
-            torch.float64,
-            torch.float16,
-            torch.bfloat16,
-        ]
+        return is_float_dtype(output_spec.dtype)

     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
         """Generate input specs for softmax operation.
@@ -314,12 +304,7 @@ class DropoutOperator(Operator):
         """Dropout can produce tensor outputs with floating point dtypes."""
         if not isinstance(output_spec, TensorSpec):
             return False
-        return output_spec.dtype in [
-            torch.float32,
-            torch.float64,
-            torch.float16,
-            torch.bfloat16,
-        ]
+        return is_float_dtype(output_spec.dtype)

     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
         """Generate input specs for dropout operation.
@@ -366,12 +351,7 @@ class LayerNormOperator(Operator):
         # LayerNorm needs at least 1 dimension to normalize over
         if len(output_spec.size) == 0:
             return False
-        return output_spec.dtype in [
-            torch.float32,
-            torch.float64,
-            torch.float16,
-            torch.bfloat16,
-        ]
+        return is_float_dtype(output_spec.dtype)

     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
         """Generate input specs for layer_norm operation.
@@ -444,3 +424,529 @@ class LayerNormOperator(Operator):
        else:  # len(input_names) == 3
            weight_name, bias_name = input_names[1], input_names[2]
            return f"{output_name} = torch.nn.functional.layer_norm({input_name}.to({target_dtype}), {normalized_shape}, weight={weight_name}.to({target_dtype}), bias={bias_name}.to({target_dtype}))"

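To make the template concrete: with hypothetical names t0/t1/t2 for input, weight, and bias, a float32 output, and a last dimension of size 8, the generated line would look roughly like the following (illustrative only; normalized_shape is assumed to cover the last dimension):

    # Hypothetical output of LayerNormOperator.codegen; names t0-t3 are illustrative.
    t3 = torch.nn.functional.layer_norm(t0.to(torch.float32), (8,), weight=t1.to(torch.float32), bias=t2.to(torch.float32))
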
class RMSNormOperator(Operator):
    """Operator for torch.nn.functional.rms_norm (Root Mean Square Normalization).

    RMSNorm is commonly used in modern LLMs like LLaMA. It normalizes by the RMS of the input.
    """

    def __init__(self):
        super().__init__("torch.nn.functional.rms_norm")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.rms_norm"

    def can_produce(self, output_spec: Spec) -> bool:
        """RMSNorm can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        # RMSNorm needs at least 1 dimension to normalize over
        if len(output_spec.size) == 0:
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for RMSNorm operation.

        RMSNorm requires:
        - input: input tensor
        - weight: (normalized_shape,) [optional]
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("RMSNormOperator can only produce TensorSpec outputs")

        if len(output_spec.size) == 0:
            raise ValueError("RMSNorm output must have at least 1 dimension")

        # Input tensor has same shape and dtype as output
        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        # Weight tensor (optional with 70% probability)
        normalized_shape = output_spec.size[-1:]
        specs = [input_spec]
        if random.random() < 0.7:
            weight_spec = TensorSpec(
                size=normalized_shape, stride=(1,), dtype=output_spec.dtype
            )
            specs.append(weight_spec)

        from typing import cast

        return cast(list[Spec], specs)

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for RMSNorm operation."""
        if len(input_names) < 1 or len(input_names) > 2:
            raise ValueError("RMSNorm requires 1-2 inputs: input, optional weight")

        if not isinstance(output_spec, TensorSpec):
            raise ValueError("RMSNormOperator can only produce TensorSpec outputs")

        target_dtype = str(output_spec.dtype)
        input_name = input_names[0]

        # Normalize over the last dimension
        normalized_shape = f"({output_spec.size[-1]},)"

        if len(input_names) == 1:
            return f"{output_name} = torch.nn.functional.rms_norm({input_name}.to({target_dtype}), {normalized_shape})"
        else:  # len(input_names) == 2
            weight_name = input_names[1]
            return f"{output_name} = torch.nn.functional.rms_norm({input_name}.to({target_dtype}), {normalized_shape}, weight={weight_name}.to({target_dtype}))"

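The normalization itself follows the usual RMSNorm definition, y = x / sqrt(mean(x^2) + eps) * weight. A small numerical check, as a sketch only (assumes a PyTorch version where torch.nn.functional.rms_norm is available, with eps passed explicitly):

    import torch

    x = torch.randn(2, 8)
    w = torch.randn(8)
    eps = 1e-6

    # Manual RMS normalization over the last dimension
    ref = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w
    out = torch.nn.functional.rms_norm(x, (8,), weight=w, eps=eps)
    torch.testing.assert_close(out, ref)
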
class GELUOperator(Operator):
    """Operator for torch.nn.functional.gelu (Gaussian Error Linear Unit)."""

    def __init__(self):
        super().__init__("torch.nn.functional.gelu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.gelu"

    def can_produce(self, output_spec: Spec) -> bool:
        """GELU can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for GELU operation.

        GELU is element-wise, so input shape matches output shape.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("GELUOperator can only produce TensorSpec outputs")

        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for GELU operation."""
        if len(input_names) != 1:
            raise ValueError("GELU requires exactly 1 input")

        input_name = input_names[0]
        return f"{output_name} = torch.nn.functional.gelu({input_name})"

class SigmoidOperator(Operator):
    """Operator for torch.sigmoid."""

    def __init__(self):
        super().__init__("torch.sigmoid")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.sigmoid"

    def can_produce(self, output_spec: Spec) -> bool:
        """Sigmoid can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for sigmoid operation.

        Sigmoid is element-wise, so input shape matches output shape.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("SigmoidOperator can only produce TensorSpec outputs")

        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for sigmoid operation."""
        if len(input_names) != 1:
            raise ValueError("Sigmoid requires exactly 1 input")

        input_name = input_names[0]
        return f"{output_name} = torch.sigmoid({input_name})"

class TanhOperator(Operator):
    """Operator for torch.tanh."""

    def __init__(self):
        super().__init__("torch.tanh")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.tanh"

    def can_produce(self, output_spec: Spec) -> bool:
        """Tanh can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for tanh operation.

        Tanh is element-wise, so input shape matches output shape.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("TanhOperator can only produce TensorSpec outputs")

        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for tanh operation."""
        if len(input_names) != 1:
            raise ValueError("Tanh requires exactly 1 input")

        input_name = input_names[0]
        return f"{output_name} = torch.tanh({input_name})"

class BatchNormOperator(Operator):
    """Operator for torch.nn.functional.batch_norm."""

    def __init__(self):
        super().__init__("torch.nn.functional.batch_norm")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.batch_norm"

    def can_produce(self, output_spec: Spec) -> bool:
        """BatchNorm can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        # BatchNorm needs at least 2 dimensions (batch, features)
        if len(output_spec.size) < 2:
            return False
        # Channel dimension (second dimension) must be greater than 0
        if output_spec.size[1] == 0:
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for batch_norm operation.

        BatchNorm requires:
        - input: (N, C, ...) where N is batch and C is channels
        - running_mean: (C,) [optional]
        - running_var: (C,) [optional]
        - weight: (C,) [optional]
        - bias: (C,) [optional]
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("BatchNormOperator can only produce TensorSpec outputs")

        if len(output_spec.size) < 2:
            raise ValueError("BatchNorm output must have at least 2 dimensions")

        # Input tensor has same shape and dtype as output
        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        # Channel dimension is the second dimension
        num_features = output_spec.size[1]

        specs = [input_spec]

        # Add running_mean and running_var (required for inference mode)
        running_mean_spec = TensorSpec(
            size=(num_features,), stride=(1,), dtype=output_spec.dtype
        )
        running_var_spec = TensorSpec(
            size=(num_features,), stride=(1,), dtype=output_spec.dtype
        )
        specs.extend([running_mean_spec, running_var_spec])

        # Add weight and bias (optional with 70% probability)
        if random.random() < 0.7:
            weight_spec = TensorSpec(
                size=(num_features,), stride=(1,), dtype=output_spec.dtype
            )
            specs.append(weight_spec)

        if random.random() < 0.7:
            bias_spec = TensorSpec(
                size=(num_features,), stride=(1,), dtype=output_spec.dtype
            )
            specs.append(bias_spec)

        from typing import cast

        return cast(list[Spec], specs)

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for batch_norm operation."""
        if len(input_names) < 3 or len(input_names) > 5:
            raise ValueError(
                "BatchNorm requires 3-5 inputs: input, running_mean, running_var, optional weight, optional bias"
            )

        if not isinstance(output_spec, TensorSpec):
            raise ValueError("BatchNormOperator can only produce TensorSpec outputs")

        target_dtype = str(output_spec.dtype)
        input_name = input_names[0]
        running_mean_name = input_names[1]
        running_var_name = input_names[2]

        # Use training=False for deterministic behavior
        if len(input_names) == 3:
            return f"{output_name} = torch.nn.functional.batch_norm({input_name}.to({target_dtype}), {running_mean_name}.to({target_dtype}), {running_var_name}.to({target_dtype}), training=False)"
        elif len(input_names) == 4:
            weight_name = input_names[3]
            return f"{output_name} = torch.nn.functional.batch_norm({input_name}.to({target_dtype}), {running_mean_name}.to({target_dtype}), {running_var_name}.to({target_dtype}), weight={weight_name}.to({target_dtype}), training=False)"
        else:  # len(input_names) == 5
            weight_name = input_names[3]
            bias_name = input_names[4]
            return f"{output_name} = torch.nn.functional.batch_norm({input_name}.to({target_dtype}), {running_mean_name}.to({target_dtype}), {running_var_name}.to({target_dtype}), weight={weight_name}.to({target_dtype}), bias={bias_name}.to({target_dtype}), training=False)"

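With training=False, batch_norm reduces to an affine transform using the running statistics, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias. A minimal numerical check, as a sketch only (eps passed explicitly to match the reference):

    import torch

    x = torch.randn(4, 3)          # (N, C)
    mean = torch.randn(3)
    var = torch.rand(3) + 0.1      # keep variance positive
    weight = torch.randn(3)
    bias = torch.randn(3)

    ref = (x - mean) / torch.sqrt(var + 1e-5) * weight + bias
    out = torch.nn.functional.batch_norm(
        x, mean, var, weight=weight, bias=bias, training=False, eps=1e-5
    )
    torch.testing.assert_close(out, ref)
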
class GroupNormOperator(Operator):
    """Operator for torch.nn.functional.group_norm."""

    def __init__(self):
        super().__init__("torch.nn.functional.group_norm")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.group_norm"

    def can_produce(self, output_spec: Spec) -> bool:
        """GroupNorm can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        # GroupNorm needs at least 2 dimensions (batch, channels)
        if len(output_spec.size) < 2:
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for group_norm operation.

        GroupNorm requires:
        - input: (N, C, ...) where N is batch and C is channels
        - weight: (C,) [optional]
        - bias: (C,) [optional]
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("GroupNormOperator can only produce TensorSpec outputs")

        if len(output_spec.size) < 2:
            raise ValueError("GroupNorm output must have at least 2 dimensions")

        # Input tensor has same shape and dtype as output
        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        # Channel dimension is the second dimension
        num_channels = output_spec.size[1]

        specs = [input_spec]

        # Add weight and bias (optional with 70% probability)
        if random.random() < 0.7:
            weight_spec = TensorSpec(
                size=(num_channels,), stride=(1,), dtype=output_spec.dtype
            )
            specs.append(weight_spec)

        if random.random() < 0.7:
            bias_spec = TensorSpec(
                size=(num_channels,), stride=(1,), dtype=output_spec.dtype
            )
            specs.append(bias_spec)

        from typing import cast

        return cast(list[Spec], specs)

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for group_norm operation."""
        if len(input_names) < 1 or len(input_names) > 3:
            raise ValueError(
                "GroupNorm requires 1-3 inputs: input, optional weight, optional bias"
            )

        if not isinstance(output_spec, TensorSpec):
            raise ValueError("GroupNormOperator can only produce TensorSpec outputs")

        target_dtype = str(output_spec.dtype)
        input_name = input_names[0]

        # Determine number of groups (must divide num_channels evenly)
        num_channels = output_spec.size[1]
        # Common choices: 32, 16, 8, or equal to channels (instance norm)
        possible_groups = [g for g in [32, 16, 8, 4, 2, 1] if num_channels % g == 0]
        num_groups = possible_groups[0] if possible_groups else 1

        if len(input_names) == 1:
            return f"{output_name} = torch.nn.functional.group_norm({input_name}.to({target_dtype}), {num_groups})"
        elif len(input_names) == 2:
            weight_name = input_names[1]
            return f"{output_name} = torch.nn.functional.group_norm({input_name}.to({target_dtype}), {num_groups}, weight={weight_name}.to({target_dtype}))"
        else:  # len(input_names) == 3
            weight_name = input_names[1]
            bias_name = input_names[2]
            return f"{output_name} = torch.nn.functional.group_norm({input_name}.to({target_dtype}), {num_groups}, weight={weight_name}.to({target_dtype}), bias={bias_name}.to({target_dtype}))"

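The group-count heuristic above picks the largest divisor of the channel count from a fixed candidate list. For example (illustrative channel count only):

    # Mirrors the selection logic in GroupNormOperator.codegen.
    num_channels = 24
    possible_groups = [g for g in [32, 16, 8, 4, 2, 1] if num_channels % g == 0]
    num_groups = possible_groups[0] if possible_groups else 1
    assert possible_groups == [8, 4, 2, 1]
    assert num_groups == 8
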
class LeakyReLUOperator(Operator):
    """Operator for torch.nn.functional.leaky_relu."""

    def __init__(self):
        super().__init__("torch.nn.functional.leaky_relu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.leaky_relu"

    def can_produce(self, output_spec: Spec) -> bool:
        """LeakyReLU can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for LeakyReLU operation.

        LeakyReLU is element-wise, so input shape matches output shape.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("LeakyReLUOperator can only produce TensorSpec outputs")

        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for LeakyReLU operation."""
        if len(input_names) != 1:
            raise ValueError("LeakyReLU requires exactly 1 input")

        input_name = input_names[0]
        return f"{output_name} = torch.nn.functional.leaky_relu({input_name}, negative_slope=0.01)"

class ELUOperator(Operator):
    """Operator for torch.nn.functional.elu (Exponential Linear Unit)."""

    def __init__(self):
        super().__init__("torch.nn.functional.elu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.elu"

    def can_produce(self, output_spec: Spec) -> bool:
        """ELU can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for ELU operation.

        ELU is element-wise, so input shape matches output shape.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("ELUOperator can only produce TensorSpec outputs")

        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for ELU operation."""
        if len(input_names) != 1:
            raise ValueError("ELU requires exactly 1 input")

        input_name = input_names[0]
        return f"{output_name} = torch.nn.functional.elu({input_name})"

class SiLUOperator(Operator):
    """Operator for torch.nn.functional.silu (Sigmoid Linear Unit, also known as Swish)."""

    def __init__(self):
        super().__init__("torch.nn.functional.silu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.silu"

    def can_produce(self, output_spec: Spec) -> bool:
        """SiLU can produce tensor outputs with floating point dtypes."""
        if not isinstance(output_spec, TensorSpec):
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for SiLU operation.

        SiLU is element-wise, so input shape matches output shape.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("SiLUOperator can only produce TensorSpec outputs")

        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        return [input_spec]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for SiLU operation."""
        if len(input_names) != 1:
            raise ValueError("SiLU requires exactly 1 input")

        input_name = input_names[0]
        return f"{output_name} = torch.nn.functional.silu({input_name})"

@@ -21,12 +21,21 @@ from torchfuzz.operators.matrix_multiply import (
    MMOperator,
)
from torchfuzz.operators.nn_functional import (
    BatchNormOperator,
    DropoutOperator,
    ELUOperator,
    EmbeddingOperator,
    GELUOperator,
    GroupNormOperator,
    LayerNormOperator,
    LeakyReLUOperator,
    LinearOperator,
    ReLUOperator,
    RMSNormOperator,
    SigmoidOperator,
    SiLUOperator,
    SoftmaxOperator,
    TanhOperator,
)
from torchfuzz.operators.nonzero import NonzeroOperator
from torchfuzz.operators.scalar_pointwise import (

@@ -92,10 +101,25 @@ class OperatorRegistry:
        # Neural network functional operators
        self.register(EmbeddingOperator())
        self.register(LinearOperator())

        # Activation functions
        self.register(ReLUOperator())
        self.register(LeakyReLUOperator())
        self.register(ELUOperator())
        self.register(GELUOperator())
        self.register(SiLUOperator())
        self.register(SigmoidOperator())
        self.register(TanhOperator())
        self.register(SoftmaxOperator())
        self.register(DropoutOperator())

        # Normalization layers
        self.register(LayerNormOperator())
        self.register(RMSNormOperator())
        self.register(BatchNormOperator())
        self.register(GroupNormOperator())

        # Regularization
        self.register(DropoutOperator())

    def register(self, operator: Operator):
        """Register an operator in the registry."""
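Registering an additional operator follows the same pattern as the built-in registrations above. A minimal sketch (the registry module path is assumed, not confirmed by this diff):

    from torchfuzz.operators.nn_functional import RMSNormOperator
    from torchfuzz.operators.registry import OperatorRegistry  # assumed import path

    registry = OperatorRegistry()
    registry.register(RMSNormOperator())  # now available to the fuzzer's operator selection
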