From 5bb8f04d3eb7295a8ba96e2d0f5124a77f4babea Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Thu, 2 Oct 2025 10:43:30 -0700 Subject: [PATCH] [torchfuzz] add nn functional ops (#164434) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164434 Approved by: https://github.com/pianpwk ghstack dependencies: #164432 --- .lintrunner.toml | 1 + .../dynamic_shapes/torchfuzz/checks.py | 4 +- .../dynamic_shapes/torchfuzz/codegen.py | 28 +- .../torchfuzz/operators/__init__.py | 14 + .../torchfuzz/operators/constant.py | 15 + .../torchfuzz/operators/nn_functional.py | 446 ++++++++++++++++++ .../torchfuzz/operators/registry.py | 16 + 7 files changed, 519 insertions(+), 5 deletions(-) create mode 100644 tools/experimental/dynamic_shapes/torchfuzz/operators/nn_functional.py diff --git a/.lintrunner.toml b/.lintrunner.toml index b0ef669aeb53..5ccff27559a4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -28,6 +28,7 @@ exclude_patterns = [ 'torch/lib/**', 'venv/**', '**/*.pyi', + "tools/experimental/dynamic_shapes/torchfuzz/**", 'tools/test/test_selective_build.py', ] command = [ diff --git a/tools/experimental/dynamic_shapes/torchfuzz/checks.py b/tools/experimental/dynamic_shapes/torchfuzz/checks.py index 86e5393f3606..5b7b2e9da0e9 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/checks.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/checks.py @@ -43,8 +43,8 @@ class EagerVsFullGraphDynamicCompileWithNumericsCheck(Check): "diff = (out_eager_sum - out_compiled_sum).abs().item()", "rel_diff = diff / (out_eager_sum.abs().item() + 1e-12) * 100", "print(f'Relative diff (sum): {rel_diff:.6f}%')", - "if rel_diff > 5:", - " print(f'❌ Forward output sums differ significantly (relative)!')", + "if rel_diff > 5 and diff > 1:", + " print(f'❌ Forward output sums differ significantly (relative and absolute)!')", " print('out_eager_sum:', out_eager_sum.item())", " print('out_compiled_sum:', out_compiled_sum.item())", " print('Absolute diff:', diff)", diff --git a/tools/experimental/dynamic_shapes/torchfuzz/codegen.py b/tools/experimental/dynamic_shapes/torchfuzz/codegen.py index 8ac1492c45e1..1a97a1fed85c 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/codegen.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/codegen.py @@ -124,6 +124,12 @@ class DefaultFuzzTemplate(FuzzTemplate): "torch.addmm", "torch.bmm", "torch.matmul", + "torch.nn.functional.embedding", + "torch.nn.functional.linear", + "torch.nn.functional.relu", + "torch.nn.functional.softmax", + "torch.nn.functional.dropout", + "torch.nn.functional.layer_norm", ], check=EagerVsFullGraphDynamicCompileWithNumericsCheck(), ) @@ -184,9 +190,25 @@ class DefaultFuzzTemplate(FuzzTemplate): storage_size = 1 stride_str = str(spec.stride) - code_lines.append( - f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})" - ) + + # Special handling for integer tensors which might be used as indices + if spec.dtype in [torch.int32, torch.int64]: + # For integer tensors, generate valid indices with headroom for arithmetic + # Use smaller range [5, 30] to allow for multiplication and other operations + # This prevents indices from becoming too large after arithmetic + min_val = ( + 5 # Minimum to avoid negative results after subtraction + ) + max_val = ( + 30 # Maximum to avoid out-of-bounds after multiplication + ) + code_lines.append( + f"{arg_name} = torch.as_strided(torch.randint({min_val}, {max_val}, ({storage_size},)).to({dtype_str}), {size_str}, {stride_str})" + ) + else: + 
code_lines.append( + f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})" + ) return code_lines diff --git a/tools/experimental/dynamic_shapes/torchfuzz/operators/__init__.py b/tools/experimental/dynamic_shapes/torchfuzz/operators/__init__.py index f2daefde42f0..441fd764b300 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/operators/__init__.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/operators/__init__.py @@ -17,6 +17,14 @@ from torchfuzz.operators.matrix_multiply import ( MatmulOperator, MMOperator, ) +from torchfuzz.operators.nn_functional import ( + DropoutOperator, + EmbeddingOperator, + LayerNormOperator, + LinearOperator, + ReLUOperator, + SoftmaxOperator, +) from torchfuzz.operators.registry import get_operator, list_operators, register_operator from torchfuzz.operators.scalar_pointwise import ( ScalarAddOperator, @@ -58,6 +66,12 @@ __all__ = [ "AddmmOperator", "BmmOperator", "MatmulOperator", + "EmbeddingOperator", + "LinearOperator", + "ReLUOperator", + "SoftmaxOperator", + "DropoutOperator", + "LayerNormOperator", "get_operator", "register_operator", "list_operators", diff --git a/tools/experimental/dynamic_shapes/torchfuzz/operators/constant.py b/tools/experimental/dynamic_shapes/torchfuzz/operators/constant.py index 9988d3851200..8fb0b33a4c1a 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/operators/constant.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/operators/constant.py @@ -98,6 +98,21 @@ class ConstantOperator(Operator): else: # For non-empty tensors, use the first element as fill value fill_value = actual_tensor.flatten()[0].item() + + # For integer types, clamp the value to a smaller range to avoid + # issues when used in arithmetic with embedding indices + import torch + + if output_spec.dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + # Clamp integer values to [0, 3] to avoid index overflow in multiplication + # Even with multiplication, indices should stay in reasonable range + fill_value = max(0, min(3, abs(fill_value))) + tensor_creation = ( f"torch.full({size_str}, {fill_value}, dtype={dtype_str})" ) diff --git a/tools/experimental/dynamic_shapes/torchfuzz/operators/nn_functional.py b/tools/experimental/dynamic_shapes/torchfuzz/operators/nn_functional.py new file mode 100644 index 000000000000..ac974dd30e34 --- /dev/null +++ b/tools/experimental/dynamic_shapes/torchfuzz/operators/nn_functional.py @@ -0,0 +1,446 @@ +"""Neural network functional operator implementations.""" + +import random +from typing import Optional + +import torch + +from torchfuzz.operators.base import Operator +from torchfuzz.tensor_fuzzer import Spec, TensorSpec + + +class EmbeddingOperator(Operator): + """Operator for torch.nn.functional.embedding.""" + + def __init__(self): + super().__init__("torch.nn.functional.embedding") + + @property + def torch_op_name(self) -> Optional[str]: + """Return the torch operation name.""" + return "torch.nn.functional.embedding" + + def can_produce(self, output_spec: Spec) -> bool: + """Embedding can produce tensor outputs with floating point dtypes.""" + if not isinstance(output_spec, TensorSpec): + return False + # Embedding needs at least 1 dimension (embedding_dim) + if len(output_spec.size) == 0: + return False + # Embedding outputs are typically float tensors + return output_spec.dtype in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + ] + + def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: + """Generate 
input specs for embedding operation. + + Embedding requires: + - weight tensor: (num_embeddings, embedding_dim) + - input tensor: integer indices (any shape, but output shape + [embedding_dim]) + """ + if not isinstance(output_spec, TensorSpec): + raise ValueError("EmbeddingOperator can only produce TensorSpec outputs") + + # Output shape should be input_shape + [embedding_dim] + if len(output_spec.size) == 0: + raise ValueError("Embedding output must have at least 1 dimension") + + embedding_dim = output_spec.size[-1] + input_shape = output_spec.size[:-1] # Remove last dimension for embedding_dim + + # Generate reasonable vocab size that's larger than our index generation range + # This ensures that indices generated in range [0, 100) will always be valid + num_embeddings = random.randint(150, 500) # Always larger than max index (100) + + # Weight tensor: (num_embeddings, embedding_dim) + weight_spec = TensorSpec( + size=(num_embeddings, embedding_dim), + stride=(embedding_dim, 1), + dtype=output_spec.dtype, + ) + + # Input tensor: integer indices with shape that produces the output shape + input_spec = TensorSpec( + size=input_shape, + stride=self._calculate_stride(input_shape), + dtype=torch.int64, # Indices are typically int64 + ) + + return [weight_spec, input_spec] + + def _calculate_stride(self, size): + """Calculate stride for a given size.""" + if not size: + return () + stride = [] + current_stride = 1 + for dim_size in reversed(size): + stride.append(current_stride) + current_stride *= dim_size + return tuple(reversed(stride)) + + def codegen( + self, output_name: str, input_names: list[str], output_spec: Spec + ) -> str: + """Generate code for embedding operation.""" + if len(input_names) != 2: + raise ValueError("Embedding requires exactly 2 inputs: weight and input") + + weight_name, input_name = input_names + # Ensure indices are integer type and clamped to valid range + # This handles any arithmetic operations that might produce out-of-bounds indices + return f"{output_name} = torch.nn.functional.embedding(torch.clamp({input_name}.to(torch.int64), 0, {weight_name}.size(0)-1), {weight_name})" + + +class LinearOperator(Operator): + """Operator for torch.nn.functional.linear.""" + + def __init__(self): + super().__init__("torch.nn.functional.linear") + + @property + def torch_op_name(self) -> Optional[str]: + """Return the torch operation name.""" + return "torch.nn.functional.linear" + + def can_produce(self, output_spec: Spec) -> bool: + """Linear can produce tensor outputs with floating point dtypes.""" + if not isinstance(output_spec, TensorSpec): + return False + # Linear needs at least 1 dimension (output features) + if len(output_spec.size) == 0: + return False + return output_spec.dtype in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + ] + + def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: + """Generate input specs for linear operation. 
+ + Linear transformation: y = xW^T + b + - input: (..., in_features) + - weight: (out_features, in_features) + - bias: (out_features,) [optional] + - output: (..., out_features) + """ + if not isinstance(output_spec, TensorSpec): + raise ValueError("LinearOperator can only produce TensorSpec outputs") + + if len(output_spec.size) == 0: + raise ValueError("Linear output must have at least 1 dimension") + + out_features = output_spec.size[-1] + batch_shape = output_spec.size[:-1] + + # Generate reasonable input features size + in_features = random.randint(8, 256) + + # Input tensor: (..., in_features) + input_shape = batch_shape + (in_features,) + input_spec = TensorSpec( + size=input_shape, + stride=self._calculate_stride(input_shape), + dtype=output_spec.dtype, + ) + + # Weight tensor: (out_features, in_features) + weight_spec = TensorSpec( + size=(out_features, in_features), + stride=(in_features, 1), + dtype=output_spec.dtype, + ) + + # Bias tensor: (out_features,) - make bias optional with 50% probability + if random.random() < 0.5: + bias_spec = TensorSpec( + size=(out_features,), stride=(1,), dtype=output_spec.dtype + ) + return [input_spec, weight_spec, bias_spec] + else: + return [input_spec, weight_spec] + + def _calculate_stride(self, size): + """Calculate stride for a given size.""" + if not size: + return () + stride = [] + current_stride = 1 + for dim_size in reversed(size): + stride.append(current_stride) + current_stride *= dim_size + return tuple(reversed(stride)) + + def codegen( + self, output_name: str, input_names: list[str], output_spec: Spec + ) -> str: + """Generate code for linear operation.""" + if not isinstance(output_spec, TensorSpec): + raise ValueError("LinearOperator can only produce TensorSpec outputs") + + # Ensure dtype compatibility by converting all inputs to the expected output dtype + target_dtype = str(output_spec.dtype) + + if len(input_names) == 2: + input_name, weight_name = input_names + return f"{output_name} = torch.nn.functional.linear({input_name}.to({target_dtype}), {weight_name}.to({target_dtype}))" + elif len(input_names) == 3: + input_name, weight_name, bias_name = input_names + return f"{output_name} = torch.nn.functional.linear({input_name}.to({target_dtype}), {weight_name}.to({target_dtype}), {bias_name}.to({target_dtype}))" + else: + raise ValueError( + "Linear requires 2 or 3 inputs: input, weight, and optional bias" + ) + + +class ReLUOperator(Operator): + """Operator for torch.nn.functional.relu.""" + + def __init__(self): + super().__init__("torch.nn.functional.relu") + + @property + def torch_op_name(self) -> Optional[str]: + """Return the torch operation name.""" + return "torch.nn.functional.relu" + + def can_produce(self, output_spec: Spec) -> bool: + """ReLU can produce tensor outputs with floating point dtypes.""" + if not isinstance(output_spec, TensorSpec): + return False + return output_spec.dtype in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + ] + + def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: + """Generate input specs for ReLU operation. + + ReLU is element-wise, so input shape matches output shape. 
+ """ + if not isinstance(output_spec, TensorSpec): + raise ValueError("ReLUOperator can only produce TensorSpec outputs") + + # Input tensor has same shape and dtype as output + input_spec = TensorSpec( + size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype + ) + + return [input_spec] + + def codegen( + self, output_name: str, input_names: list[str], output_spec: Spec + ) -> str: + """Generate code for ReLU operation.""" + if len(input_names) != 1: + raise ValueError("ReLU requires exactly 1 input") + + input_name = input_names[0] + return f"{output_name} = torch.nn.functional.relu({input_name})" + + +class SoftmaxOperator(Operator): + """Operator for torch.nn.functional.softmax.""" + + def __init__(self): + super().__init__("torch.nn.functional.softmax") + + @property + def torch_op_name(self) -> Optional[str]: + """Return the torch operation name.""" + return "torch.nn.functional.softmax" + + def can_produce(self, output_spec: Spec) -> bool: + """Softmax can produce tensor outputs with floating point dtypes.""" + if not isinstance(output_spec, TensorSpec): + return False + # Softmax needs at least 1 dimension to apply softmax along a dimension + if len(output_spec.size) == 0: + return False + return output_spec.dtype in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + ] + + def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: + """Generate input specs for softmax operation. + + Softmax is element-wise along a dimension, input shape matches output shape. + """ + if not isinstance(output_spec, TensorSpec): + raise ValueError("SoftmaxOperator can only produce TensorSpec outputs") + + # Input tensor has same shape and dtype as output + input_spec = TensorSpec( + size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype + ) + + return [input_spec] + + def codegen( + self, output_name: str, input_names: list[str], output_spec: Spec + ) -> str: + """Generate code for softmax operation.""" + if len(input_names) != 1: + raise ValueError("Softmax requires exactly 1 input") + + input_name = input_names[0] + # Use dim=-1 as default (last dimension) + return f"{output_name} = torch.nn.functional.softmax({input_name}, dim=-1)" + + +class DropoutOperator(Operator): + """Operator for torch.nn.functional.dropout.""" + + def __init__(self): + super().__init__("torch.nn.functional.dropout") + + @property + def torch_op_name(self) -> Optional[str]: + """Return the torch operation name.""" + return "torch.nn.functional.dropout" + + def can_produce(self, output_spec: Spec) -> bool: + """Dropout can produce tensor outputs with floating point dtypes.""" + if not isinstance(output_spec, TensorSpec): + return False + return output_spec.dtype in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + ] + + def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: + """Generate input specs for dropout operation. + + Dropout is element-wise, input shape matches output shape. 
+ """ + if not isinstance(output_spec, TensorSpec): + raise ValueError("DropoutOperator can only produce TensorSpec outputs") + + # Input tensor has same shape and dtype as output + input_spec = TensorSpec( + size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype + ) + + return [input_spec] + + def codegen( + self, output_name: str, input_names: list[str], output_spec: Spec + ) -> str: + """Generate code for dropout operation.""" + if len(input_names) != 1: + raise ValueError("Dropout requires exactly 1 input") + + input_name = input_names[0] + # Use training=False to make it deterministic for testing + return f"{output_name} = torch.nn.functional.dropout({input_name}, p=0.1, training=False)" + + +class LayerNormOperator(Operator): + """Operator for torch.nn.functional.layer_norm.""" + + def __init__(self): + super().__init__("torch.nn.functional.layer_norm") + + @property + def torch_op_name(self) -> Optional[str]: + """Return the torch operation name.""" + return "torch.nn.functional.layer_norm" + + def can_produce(self, output_spec: Spec) -> bool: + """LayerNorm can produce tensor outputs with floating point dtypes.""" + if not isinstance(output_spec, TensorSpec): + return False + # LayerNorm needs at least 1 dimension to normalize over + if len(output_spec.size) == 0: + return False + return output_spec.dtype in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + ] + + def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]: + """Generate input specs for layer_norm operation. + + LayerNorm normalizes over the last dimensions specified by normalized_shape. + - input: input tensor + - weight: (normalized_shape,) [optional] + - bias: (normalized_shape,) [optional] + """ + if not isinstance(output_spec, TensorSpec): + raise ValueError("LayerNormOperator can only produce TensorSpec outputs") + + if len(output_spec.size) == 0: + raise ValueError("LayerNorm output must have at least 1 dimension") + + # Input tensor has same shape and dtype as output + input_spec = TensorSpec( + size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype + ) + + # For simplicity, normalize over the last dimension + normalized_shape = output_spec.size[-1:] + + # Weight and bias tensors (optional with 70% probability each) + specs = [input_spec] + if random.random() < 0.7: + # LayerNorm weight and bias parameters should match input tensor dtype + # for compatibility (conversion will be handled in codegen) + weight_spec = TensorSpec( + size=normalized_shape, stride=(1,), dtype=output_spec.dtype + ) + specs.append(weight_spec) + + if random.random() < 0.7: + bias_spec = TensorSpec( + size=normalized_shape, stride=(1,), dtype=output_spec.dtype + ) + specs.append(bias_spec) + + # Cast to list[Spec] to fix type checking + from typing import cast + + return cast(list[Spec], specs) + + def codegen( + self, output_name: str, input_names: list[str], output_spec: Spec + ) -> str: + """Generate code for layer_norm operation.""" + if len(input_names) < 1 or len(input_names) > 3: + raise ValueError( + "LayerNorm requires 1-3 inputs: input, optional weight, optional bias" + ) + + if not isinstance(output_spec, TensorSpec): + raise ValueError("LayerNormOperator can only produce TensorSpec outputs") + + # Normalize over the last dimension + normalized_shape = f"({output_spec.size[-1]},)" + + # Ensure dtype compatibility by converting all inputs to the expected output dtype + target_dtype = str(output_spec.dtype) + + input_name = input_names[0] + + if len(input_names) == 
1: + return f"{output_name} = torch.nn.functional.layer_norm({input_name}.to({target_dtype}), {normalized_shape})" + elif len(input_names) == 2: + weight_name = input_names[1] + return f"{output_name} = torch.nn.functional.layer_norm({input_name}.to({target_dtype}), {normalized_shape}, weight={weight_name}.to({target_dtype}))" + else: # len(input_names) == 3 + weight_name, bias_name = input_names[1], input_names[2] + return f"{output_name} = torch.nn.functional.layer_norm({input_name}.to({target_dtype}), {normalized_shape}, weight={weight_name}.to({target_dtype}), bias={bias_name}.to({target_dtype}))" diff --git a/tools/experimental/dynamic_shapes/torchfuzz/operators/registry.py b/tools/experimental/dynamic_shapes/torchfuzz/operators/registry.py index 6b262a0c2100..f76ec5d9d07e 100644 --- a/tools/experimental/dynamic_shapes/torchfuzz/operators/registry.py +++ b/tools/experimental/dynamic_shapes/torchfuzz/operators/registry.py @@ -20,6 +20,14 @@ from torchfuzz.operators.matrix_multiply import ( MatmulOperator, MMOperator, ) +from torchfuzz.operators.nn_functional import ( + DropoutOperator, + EmbeddingOperator, + LayerNormOperator, + LinearOperator, + ReLUOperator, + SoftmaxOperator, +) from torchfuzz.operators.nonzero import NonzeroOperator from torchfuzz.operators.scalar_pointwise import ( ScalarAddOperator, @@ -81,6 +89,14 @@ class OperatorRegistry: self.register(BmmOperator()) self.register(MatmulOperator()) + # Neural network functional operators + self.register(EmbeddingOperator()) + self.register(LinearOperator()) + self.register(ReLUOperator()) + self.register(SoftmaxOperator()) + self.register(DropoutOperator()) + self.register(LayerNormOperator()) + def register(self, operator: Operator): """Register an operator in the registry.""" self._operators[operator.name] = operator
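
A minimal sketch (not part of the patch) of how one of the new operators is driven by the fuzzer, using only names introduced or imported in this PR (TensorSpec, LinearOperator, can_produce, fuzz_inputs_specs, codegen); it assumes tools/experimental/dynamic_shapes/torchfuzz is on sys.path so the torchfuzz package resolves:

    import torch

    from torchfuzz.operators.nn_functional import LinearOperator
    from torchfuzz.tensor_fuzzer import TensorSpec

    op = LinearOperator()

    # Request a (4, 16) float32 output: batch of 4, out_features = 16.
    out_spec = TensorSpec(size=(4, 16), stride=(16, 1), dtype=torch.float32)
    assert op.can_produce(out_spec)

    # Returns [input_spec, weight_spec] or [input_spec, weight_spec, bias_spec];
    # bias is included with 50% probability and in_features is randomized.
    input_specs = op.fuzz_inputs_specs(out_spec)

    # Emits e.g.:
    #   out0 = torch.nn.functional.linear(arg0.to(torch.float32), arg1.to(torch.float32))
    arg_names = [f"arg{i}" for i in range(len(input_specs))]
    print(op.codegen("out0", arg_names, out_spec))

The same three-step pattern (can_produce -> fuzz_inputs_specs -> codegen) applies to the other five operators registered above in registry.py.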