mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[torchfuzz] add norm operators (#164514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164514 Approved by: https://github.com/pianpwk ghstack dependencies: #164432, #164434
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							5bb8f04d3e
						
					
				
				
					commit
					3db2164341
				
			| @ -111,25 +111,41 @@ class DefaultFuzzTemplate(FuzzTemplate): | ||||
|  | ||||
|         super().__init__( | ||||
|             supported_ops=[ | ||||
|                 # Basic arithmetic operations | ||||
|                 "torch.add", | ||||
|                 "torch.sub", | ||||
|                 "torch.mul", | ||||
|                 "torch.div", | ||||
|                 # Tensor shape operations | ||||
|                 "torch.Tensor.view", | ||||
|                 "torch.reshape", | ||||
|                 "torch.flatten", | ||||
|                 "torch.squeeze", | ||||
|                 "torch.unsqueeze", | ||||
|                 # Matrix operations | ||||
|                 "torch.mm", | ||||
|                 "torch.addmm", | ||||
|                 "torch.bmm", | ||||
|                 "torch.matmul", | ||||
|                 # Neural network operations | ||||
|                 "torch.nn.functional.embedding", | ||||
|                 "torch.nn.functional.linear", | ||||
|                 # Activation functions | ||||
|                 "torch.nn.functional.relu", | ||||
|                 "torch.nn.functional.leaky_relu", | ||||
|                 "torch.nn.functional.elu", | ||||
|                 "torch.nn.functional.gelu", | ||||
|                 "torch.nn.functional.silu", | ||||
|                 "torch.sigmoid", | ||||
|                 "torch.tanh", | ||||
|                 "torch.nn.functional.softmax", | ||||
|                 "torch.nn.functional.dropout", | ||||
|                 # Normalization layers | ||||
|                 "torch.nn.functional.layer_norm", | ||||
|                 "torch.nn.functional.rms_norm", | ||||
|                 "torch.nn.functional.batch_norm", | ||||
|                 "torch.nn.functional.group_norm", | ||||
|                 # Regularization | ||||
|                 "torch.nn.functional.dropout", | ||||
|             ], | ||||
|             check=EagerVsFullGraphDynamicCompileWithNumericsCheck(), | ||||
|         ) | ||||
|  | ||||
| @ -9,6 +9,16 @@ from torchfuzz.operators.base import Operator | ||||
| from torchfuzz.tensor_fuzzer import Spec, TensorSpec | ||||
|  | ||||
|  | ||||
def is_float_dtype(dtype: torch.dtype) -> bool:
    """Tell whether ``dtype`` is one of the floating point dtypes used by the fuzzer.

    Only float16, bfloat16, float32 and float64 count as floating point here;
    every other dtype yields False.
    """
    accepted = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    return any(dtype == candidate for candidate in accepted)
|  | ||||
|  | ||||
| class EmbeddingOperator(Operator): | ||||
|     """Operator for torch.nn.functional.embedding.""" | ||||
|  | ||||
| @ -28,12 +38,7 @@ class EmbeddingOperator(Operator): | ||||
|         if len(output_spec.size) == 0: | ||||
|             return False | ||||
|         # Embedding outputs are typically float tensors | ||||
|         return output_spec.dtype in [ | ||||
|             torch.float32, | ||||
|             torch.float64, | ||||
|             torch.float16, | ||||
|             torch.bfloat16, | ||||
|         ] | ||||
|         return is_float_dtype(output_spec.dtype) | ||||
|  | ||||
|     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: | ||||
|         """Generate input specs for embedding operation. | ||||
| @ -114,12 +119,7 @@ class LinearOperator(Operator): | ||||
|         # Linear needs at least 1 dimension (output features) | ||||
|         if len(output_spec.size) == 0: | ||||
|             return False | ||||
|         return output_spec.dtype in [ | ||||
|             torch.float32, | ||||
|             torch.float64, | ||||
|             torch.float16, | ||||
|             torch.bfloat16, | ||||
|         ] | ||||
|         return is_float_dtype(output_spec.dtype) | ||||
|  | ||||
|     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: | ||||
|         """Generate input specs for linear operation. | ||||
| @ -214,12 +214,7 @@ class ReLUOperator(Operator): | ||||
|         """ReLU can produce tensor outputs with floating point dtypes.""" | ||||
|         if not isinstance(output_spec, TensorSpec): | ||||
|             return False | ||||
|         return output_spec.dtype in [ | ||||
|             torch.float32, | ||||
|             torch.float64, | ||||
|             torch.float16, | ||||
|             torch.bfloat16, | ||||
|         ] | ||||
|         return is_float_dtype(output_spec.dtype) | ||||
|  | ||||
|     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: | ||||
|         """Generate input specs for ReLU operation. | ||||
| @ -265,12 +260,7 @@ class SoftmaxOperator(Operator): | ||||
|         # Softmax needs at least 1 dimension to apply softmax along a dimension | ||||
|         if len(output_spec.size) == 0: | ||||
|             return False | ||||
|         return output_spec.dtype in [ | ||||
|             torch.float32, | ||||
|             torch.float64, | ||||
|             torch.float16, | ||||
|             torch.bfloat16, | ||||
|         ] | ||||
|         return is_float_dtype(output_spec.dtype) | ||||
|  | ||||
|     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: | ||||
|         """Generate input specs for softmax operation. | ||||
| @ -314,12 +304,7 @@ class DropoutOperator(Operator): | ||||
|         """Dropout can produce tensor outputs with floating point dtypes.""" | ||||
|         if not isinstance(output_spec, TensorSpec): | ||||
|             return False | ||||
|         return output_spec.dtype in [ | ||||
|             torch.float32, | ||||
|             torch.float64, | ||||
|             torch.float16, | ||||
|             torch.bfloat16, | ||||
|         ] | ||||
|         return is_float_dtype(output_spec.dtype) | ||||
|  | ||||
|     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: | ||||
|         """Generate input specs for dropout operation. | ||||
| @ -366,12 +351,7 @@ class LayerNormOperator(Operator): | ||||
|         # LayerNorm needs at least 1 dimension to normalize over | ||||
|         if len(output_spec.size) == 0: | ||||
|             return False | ||||
|         return output_spec.dtype in [ | ||||
|             torch.float32, | ||||
|             torch.float64, | ||||
|             torch.float16, | ||||
|             torch.bfloat16, | ||||
|         ] | ||||
|         return is_float_dtype(output_spec.dtype) | ||||
|  | ||||
|     def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: | ||||
|         """Generate input specs for layer_norm operation. | ||||
| @ -444,3 +424,529 @@ class LayerNormOperator(Operator): | ||||
|         else:  # len(input_names) == 3 | ||||
|             weight_name, bias_name = input_names[1], input_names[2] | ||||
|             return f"{output_name} = torch.nn.functional.layer_norm({input_name}.to({target_dtype}), {normalized_shape}, weight={weight_name}.to({target_dtype}), bias={bias_name}.to({target_dtype}))" | ||||
|  | ||||
|  | ||||
class RMSNormOperator(Operator):
    """Fuzzer operator that emits calls to torch.nn.functional.rms_norm.

    Normalization is always applied over the last dimension of the output
    spec. A weight tensor of that length is added as a second input at random.
    """

    def __init__(self):
        super().__init__("torch.nn.functional.rms_norm")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.nn.functional.rms_norm"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept only floating point TensorSpecs with at least one dimension."""
        if not isinstance(output_spec, TensorSpec):
            return False
        # A 0-d tensor has no dimension to normalize over.
        has_dims = len(output_spec.size) != 0
        return has_dims and is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Build the input specs for one rms_norm call.

        Returns:
            ``[input]`` or ``[input, weight]``. The input mirrors the output's
            size, stride and dtype; the weight is 1-D with the length of the
            output's last dimension.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec or is 0-d.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("RMSNormOperator can only produce TensorSpec outputs")
        if len(output_spec.size) == 0:
            raise ValueError("RMSNorm output must have at least 1 dimension")

        out_dtype = output_spec.dtype
        result: list[Spec] = [
            TensorSpec(
                size=output_spec.size, stride=output_spec.stride, dtype=out_dtype
            )
        ]

        # One random draw decides whether the optional weight is included
        # (70% of the time).
        if random.random() < 0.7:
            result.append(
                TensorSpec(size=output_spec.size[-1:], stride=(1,), dtype=out_dtype)
            )

        return result

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the rms_norm assignment statement.

        Every tensor argument is cast to the output dtype in the emitted code.

        Raises:
            ValueError: if the number of inputs is not 1 or 2, or if
                ``output_spec`` is not a TensorSpec.
        """
        if not 1 <= len(input_names) <= 2:
            raise ValueError("RMSNorm requires 1-2 inputs: input, optional weight")

        if not isinstance(output_spec, TensorSpec):
            raise ValueError("RMSNormOperator can only produce TensorSpec outputs")

        dtype_text = str(output_spec.dtype)

        # Positional arguments: the cast input and the 1-tuple normalized shape.
        call_args = [
            f"{input_names[0]}.to({dtype_text})",
            f"({output_spec.size[-1]},)",
        ]
        # Any name after the input is the weight.
        call_args.extend(
            f"weight={extra}.to({dtype_text})" for extra in input_names[1:]
        )

        return f"{output_name} = torch.nn.functional.rms_norm({', '.join(call_args)})"
|  | ||||
|  | ||||
class GELUOperator(Operator):
    """Fuzzer operator that emits torch.nn.functional.gelu calls."""

    def __init__(self):
        super().__init__("torch.nn.functional.gelu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.nn.functional.gelu"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept any TensorSpec whose dtype is floating point."""
        return isinstance(output_spec, TensorSpec) and is_float_dtype(
            output_spec.dtype
        )

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Return a single input spec mirroring the output's size, stride and dtype.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("GELUOperator can only produce TensorSpec outputs")

        # Element-wise op: the operand looks exactly like the result.
        mirrored = TensorSpec(
            size=output_spec.size,
            stride=output_spec.stride,
            dtype=output_spec.dtype,
        )
        return [mirrored]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the gelu assignment statement.

        Raises:
            ValueError: if ``input_names`` does not hold exactly one name.
        """
        if len(input_names) != 1:
            raise ValueError("GELU requires exactly 1 input")

        (operand,) = input_names
        return f"{output_name} = torch.nn.functional.gelu({operand})"
|  | ||||
|  | ||||
class SigmoidOperator(Operator):
    """Fuzzer operator that emits torch.sigmoid calls."""

    def __init__(self):
        super().__init__("torch.sigmoid")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.sigmoid"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept any TensorSpec whose dtype is floating point."""
        return isinstance(output_spec, TensorSpec) and is_float_dtype(
            output_spec.dtype
        )

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Return a single input spec mirroring the output's size, stride and dtype.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("SigmoidOperator can only produce TensorSpec outputs")

        # Element-wise op: the operand looks exactly like the result.
        mirrored = TensorSpec(
            size=output_spec.size,
            stride=output_spec.stride,
            dtype=output_spec.dtype,
        )
        return [mirrored]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the sigmoid assignment statement.

        Raises:
            ValueError: if ``input_names`` does not hold exactly one name.
        """
        if len(input_names) != 1:
            raise ValueError("Sigmoid requires exactly 1 input")

        (operand,) = input_names
        return f"{output_name} = torch.sigmoid({operand})"
|  | ||||
|  | ||||
class TanhOperator(Operator):
    """Fuzzer operator that emits torch.tanh calls."""

    def __init__(self):
        super().__init__("torch.tanh")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.tanh"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept any TensorSpec whose dtype is floating point."""
        return isinstance(output_spec, TensorSpec) and is_float_dtype(
            output_spec.dtype
        )

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Return a single input spec mirroring the output's size, stride and dtype.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("TanhOperator can only produce TensorSpec outputs")

        # Element-wise op: the operand looks exactly like the result.
        mirrored = TensorSpec(
            size=output_spec.size,
            stride=output_spec.stride,
            dtype=output_spec.dtype,
        )
        return [mirrored]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the tanh assignment statement.

        Raises:
            ValueError: if ``input_names`` does not hold exactly one name.
        """
        if len(input_names) != 1:
            raise ValueError("Tanh requires exactly 1 input")

        (operand,) = input_names
        return f"{output_name} = torch.tanh({operand})"
|  | ||||
|  | ||||
class BatchNormOperator(Operator):
    """Fuzzer operator that emits torch.nn.functional.batch_norm calls.

    Dimension 1 of the output is treated as the channel dimension. The
    generated call always passes running statistics and ``training=False``.
    """

    def __init__(self):
        super().__init__("torch.nn.functional.batch_norm")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.nn.functional.batch_norm"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept floating point TensorSpecs with >= 2 dims and a non-empty channel dim."""
        if not isinstance(output_spec, TensorSpec):
            return False
        dims = output_spec.size
        # Need (batch, channels, ...) and at least one channel.
        if len(dims) < 2 or dims[1] == 0:
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Build the input specs for one batch_norm call.

        Returns:
            ``[input, running_mean, running_var]``, optionally followed by a
            weight, and then (only when a weight is present) optionally a
            bias. All per-channel tensors are 1-D with length ``size[1]`` and
            share the output dtype.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec or has fewer
                than 2 dimensions.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("BatchNormOperator can only produce TensorSpec outputs")

        if len(output_spec.size) < 2:
            raise ValueError("BatchNorm output must have at least 2 dimensions")

        out_dtype = output_spec.dtype
        channels = output_spec.size[1]

        def per_channel() -> TensorSpec:
            # Contiguous 1-D tensor with one entry per channel.
            return TensorSpec(size=(channels,), stride=(1,), dtype=out_dtype)

        result: list[Spec] = [
            TensorSpec(
                size=output_spec.size, stride=output_spec.stride, dtype=out_dtype
            ),
            per_channel(),  # running_mean
            per_channel(),  # running_var
        ]

        # Weight is included 70% of the time; the bias draw only happens when
        # the weight was chosen, so a bias never appears without a weight.
        if random.random() < 0.7:
            result.append(per_channel())
            if random.random() < 0.7:
                result.append(per_channel())

        return result

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the batch_norm assignment statement.

        Inputs are interpreted positionally as input, running_mean,
        running_var, then optional weight and bias. Every tensor argument is
        cast to the output dtype in the emitted code.

        Raises:
            ValueError: if the number of inputs is outside 3-5, or if
                ``output_spec`` is not a TensorSpec.
        """
        if not 3 <= len(input_names) <= 5:
            raise ValueError(
                "BatchNorm requires 3-5 inputs: input, running_mean, running_var, optional weight, optional bias"
            )

        if not isinstance(output_spec, TensorSpec):
            raise ValueError("BatchNormOperator can only produce TensorSpec outputs")

        dtype_text = str(output_spec.dtype)

        # input, running_mean and running_var are passed positionally.
        call_args = [f"{name}.to({dtype_text})" for name in input_names[:3]]
        # Whatever follows maps onto the weight/bias keywords, in that order.
        call_args.extend(
            f"{keyword}={name}.to({dtype_text})"
            for keyword, name in zip(("weight", "bias"), input_names[3:])
        )
        # training=False makes the call use the supplied running statistics.
        call_args.append("training=False")

        return f"{output_name} = torch.nn.functional.batch_norm({', '.join(call_args)})"
|  | ||||
|  | ||||
class GroupNormOperator(Operator):
    """Operator for torch.nn.functional.group_norm.

    The input has the same size, stride and dtype as the output. Dimension 1
    is treated as the channel dimension; the optional weight and bias inputs
    are 1-D tensors of that length. The number of groups is derived from the
    output shape at codegen time.
    """

    def __init__(self):
        super().__init__("torch.nn.functional.group_norm")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nn.functional.group_norm"

    @staticmethod
    def _values_per_group(size, num_groups: int) -> int:
        """Return ``size[0] * size[1] // num_groups`` times the product of ``size[2:]``."""
        count = size[0] * size[1] // num_groups
        for extent in size[2:]:
            count *= extent
        return count

    @classmethod
    def _select_num_groups(cls, size) -> int:
        """Choose the ``num_groups`` argument for an output of shape ``size``.

        Candidates are 32, 16, 8, 4, 2 and 1, restricted to those that divide
        the channel count. The largest candidate is preferred, but candidates
        that would leave exactly one value per group are skipped because
        torch.nn.functional.group_norm rejects such inputs with a ValueError.
        """
        num_channels = size[1]
        # 1 divides every channel count, so this list is never empty.
        divisors = [g for g in (32, 16, 8, 4, 2, 1) if num_channels % g == 0]
        for groups in divisors:
            if cls._values_per_group(size, groups) != 1:
                return groups
        # Only reachable for single-element shapes, which can_produce rejects;
        # fall back to the largest divisor.
        return divisors[0]

    def can_produce(self, output_spec: Spec) -> bool:
        """GroupNorm can produce tensor outputs with floating point dtypes.

        Rejects non-tensor specs, tensors with fewer than 2 dimensions,
        tensors with an empty channel dimension, and single-element tensors.
        """
        if not isinstance(output_spec, TensorSpec):
            return False
        # GroupNorm needs at least 2 dimensions (batch, channels)
        if len(output_spec.size) < 2:
            return False
        # Channel dimension (second dimension) must be greater than 0,
        # the same requirement BatchNormOperator enforces.
        if output_spec.size[1] == 0:
            return False
        # A single-element tensor leaves one value per group for every
        # possible group count, which group_norm refuses to normalize.
        if self._values_per_group(output_spec.size, 1) == 1:
            return False
        return is_float_dtype(output_spec.dtype)

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Generate input specs for group_norm operation.

        GroupNorm requires:
        - input: (N, C, ...) where N is batch and C is channels
        - weight: (C,) [optional]
        - bias: (C,) [optional, only generated together with weight]

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec or has fewer
                than 2 dimensions.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("GroupNormOperator can only produce TensorSpec outputs")

        if len(output_spec.size) < 2:
            raise ValueError("GroupNorm output must have at least 2 dimensions")

        # Input tensor has same shape and dtype as output
        input_spec = TensorSpec(
            size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
        )

        # Channel dimension is the second dimension
        num_channels = output_spec.size[1]

        specs: list[Spec] = [input_spec]

        # Weight is added with 70% probability; bias is only considered (again
        # with 70% probability) when a weight was added, so codegen can map
        # positions 1 and 2 to weight and bias unambiguously.
        if random.random() < 0.7:
            specs.append(
                TensorSpec(size=(num_channels,), stride=(1,), dtype=output_spec.dtype)
            )

            if random.random() < 0.7:
                specs.append(
                    TensorSpec(
                        size=(num_channels,), stride=(1,), dtype=output_spec.dtype
                    )
                )

        return specs

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for group_norm operation.

        Inputs are interpreted positionally as input, optional weight and
        optional bias. Every tensor argument is cast to the output dtype.

        Raises:
            ValueError: if the number of inputs is outside 1-3, or if
                ``output_spec`` is not a TensorSpec.
        """
        if len(input_names) < 1 or len(input_names) > 3:
            raise ValueError(
                "GroupNorm requires 1-3 inputs: input, optional weight, optional bias"
            )

        if not isinstance(output_spec, TensorSpec):
            raise ValueError("GroupNormOperator can only produce TensorSpec outputs")

        target_dtype = str(output_spec.dtype)
        input_name = input_names[0]

        # Number of groups must divide the channel count evenly.
        num_groups = self._select_num_groups(output_spec.size)

        if len(input_names) == 1:
            return f"{output_name} = torch.nn.functional.group_norm({input_name}.to({target_dtype}), {num_groups})"
        elif len(input_names) == 2:
            weight_name = input_names[1]
            return f"{output_name} = torch.nn.functional.group_norm({input_name}.to({target_dtype}), {num_groups}, weight={weight_name}.to({target_dtype}))"
        else:  # len(input_names) == 3
            weight_name = input_names[1]
            bias_name = input_names[2]
            return f"{output_name} = torch.nn.functional.group_norm({input_name}.to({target_dtype}), {num_groups}, weight={weight_name}.to({target_dtype}), bias={bias_name}.to({target_dtype}))"
|  | ||||
|  | ||||
class LeakyReLUOperator(Operator):
    """Fuzzer operator that emits torch.nn.functional.leaky_relu calls."""

    def __init__(self):
        super().__init__("torch.nn.functional.leaky_relu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.nn.functional.leaky_relu"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept any TensorSpec whose dtype is floating point."""
        return isinstance(output_spec, TensorSpec) and is_float_dtype(
            output_spec.dtype
        )

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Return a single input spec mirroring the output's size, stride and dtype.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("LeakyReLUOperator can only produce TensorSpec outputs")

        # Element-wise op: the operand looks exactly like the result.
        mirrored = TensorSpec(
            size=output_spec.size,
            stride=output_spec.stride,
            dtype=output_spec.dtype,
        )
        return [mirrored]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the leaky_relu assignment statement with a fixed slope of 0.01.

        Raises:
            ValueError: if ``input_names`` does not hold exactly one name.
        """
        if len(input_names) != 1:
            raise ValueError("LeakyReLU requires exactly 1 input")

        (operand,) = input_names
        return f"{output_name} = torch.nn.functional.leaky_relu({operand}, negative_slope=0.01)"
|  | ||||
|  | ||||
class ELUOperator(Operator):
    """Fuzzer operator that emits torch.nn.functional.elu calls."""

    def __init__(self):
        super().__init__("torch.nn.functional.elu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.nn.functional.elu"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept any TensorSpec whose dtype is floating point."""
        return isinstance(output_spec, TensorSpec) and is_float_dtype(
            output_spec.dtype
        )

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Return a single input spec mirroring the output's size, stride and dtype.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("ELUOperator can only produce TensorSpec outputs")

        # Element-wise op: the operand looks exactly like the result.
        mirrored = TensorSpec(
            size=output_spec.size,
            stride=output_spec.stride,
            dtype=output_spec.dtype,
        )
        return [mirrored]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the elu assignment statement.

        Raises:
            ValueError: if ``input_names`` does not hold exactly one name.
        """
        if len(input_names) != 1:
            raise ValueError("ELU requires exactly 1 input")

        (operand,) = input_names
        return f"{output_name} = torch.nn.functional.elu({operand})"
|  | ||||
|  | ||||
class SiLUOperator(Operator):
    """Fuzzer operator that emits torch.nn.functional.silu (a.k.a. Swish) calls."""

    def __init__(self):
        super().__init__("torch.nn.functional.silu")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Fully qualified name of the torch function this operator wraps."""
        return "torch.nn.functional.silu"

    def can_produce(self, output_spec: Spec) -> bool:
        """Accept any TensorSpec whose dtype is floating point."""
        return isinstance(output_spec, TensorSpec) and is_float_dtype(
            output_spec.dtype
        )

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Return a single input spec mirroring the output's size, stride and dtype.

        Raises:
            ValueError: if ``output_spec`` is not a TensorSpec.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("SiLUOperator can only produce TensorSpec outputs")

        # Element-wise op: the operand looks exactly like the result.
        mirrored = TensorSpec(
            size=output_spec.size,
            stride=output_spec.stride,
            dtype=output_spec.dtype,
        )
        return [mirrored]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Render the silu assignment statement.

        Raises:
            ValueError: if ``input_names`` does not hold exactly one name.
        """
        if len(input_names) != 1:
            raise ValueError("SiLU requires exactly 1 input")

        (operand,) = input_names
        return f"{output_name} = torch.nn.functional.silu({operand})"
|  | ||||
| @ -21,12 +21,21 @@ from torchfuzz.operators.matrix_multiply import ( | ||||
|     MMOperator, | ||||
| ) | ||||
| from torchfuzz.operators.nn_functional import ( | ||||
|     BatchNormOperator, | ||||
|     DropoutOperator, | ||||
|     ELUOperator, | ||||
|     EmbeddingOperator, | ||||
|     GELUOperator, | ||||
|     GroupNormOperator, | ||||
|     LayerNormOperator, | ||||
|     LeakyReLUOperator, | ||||
|     LinearOperator, | ||||
|     ReLUOperator, | ||||
|     RMSNormOperator, | ||||
|     SigmoidOperator, | ||||
|     SiLUOperator, | ||||
|     SoftmaxOperator, | ||||
|     TanhOperator, | ||||
| ) | ||||
| from torchfuzz.operators.nonzero import NonzeroOperator | ||||
| from torchfuzz.operators.scalar_pointwise import ( | ||||
| @ -92,10 +101,25 @@ class OperatorRegistry: | ||||
|         # Neural network functional operators | ||||
|         self.register(EmbeddingOperator()) | ||||
|         self.register(LinearOperator()) | ||||
|  | ||||
|         # Activation functions | ||||
|         self.register(ReLUOperator()) | ||||
|         self.register(LeakyReLUOperator()) | ||||
|         self.register(ELUOperator()) | ||||
|         self.register(GELUOperator()) | ||||
|         self.register(SiLUOperator()) | ||||
|         self.register(SigmoidOperator()) | ||||
|         self.register(TanhOperator()) | ||||
|         self.register(SoftmaxOperator()) | ||||
|         self.register(DropoutOperator()) | ||||
|  | ||||
|         # Normalization layers | ||||
|         self.register(LayerNormOperator()) | ||||
|         self.register(RMSNormOperator()) | ||||
|         self.register(BatchNormOperator()) | ||||
|         self.register(GroupNormOperator()) | ||||
|  | ||||
|         # Regularization | ||||
|         self.register(DropoutOperator()) | ||||
|  | ||||
|     def register(self, operator: Operator): | ||||
|         """Register an operator in the registry.""" | ||||
|  | ||||
		Reference in New Issue
	
	Block a user