pytorch/tools/experimental/torchfuzz/operators/constant.py
Maggie Moss f02e3947f6 Expand type checking to mypy strict files (#165697)
Expands Pyrefly type checking to cover the files outlined in the mypy-strict.ini configuration file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697
Approved by: https://github.com/ezyang
2025-10-18 04:34:45 +00:00


"""Constant operator implementation."""
from typing import Optional
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import (
fuzz_scalar,
fuzz_tensor_simple,
ScalarSpec,
Spec,
TensorSpec,
)

class ConstantOperator(Operator):
    """Operator for generating constants."""

    def __init__(self):
        super().__init__("constant")
        self.template = "default"  # Track template for DTensor compatibility

    @property
    def torch_op_name(self) -> Optional[str]:
        """Constant is not a torch operation, it generates constant values."""
        return None

    def set_template(self, template: str):
        """Set the template for context-aware code generation."""
        self.template = template

    def can_produce(self, output_spec: Spec) -> bool:
        """Constant can produce any type of output."""
        return True

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Constant requires no inputs for fuzzing."""
        return []
    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for constant creation."""
        # Create constant by calling fuzzing functions during codegen with deterministic seed
        # Use a deterministic hash based on the variable name to ensure reproducibility across processes
        import hashlib

        var_seed = int(hashlib.md5(output_name.encode()).hexdigest()[:8], 16) % (2**31)  # noqa: S324
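        # For example, the name "var_0" always hashes to the same 31-bit seed,
        # so regenerating a fuzzed program reproduces its constants exactly.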
        if isinstance(output_spec, ScalarSpec):
            # Call fuzz_scalar during codegen and embed the result
            actual_value = fuzz_scalar(output_spec, seed=var_seed)

            # Format the value for embedding in code
            if isinstance(actual_value, bool):
                value_str = str(actual_value)
            elif isinstance(actual_value, (int, float)):
                value_str = repr(actual_value)
            elif isinstance(actual_value, complex):
                value_str = f"complex({actual_value.real}, {actual_value.imag})"
            else:
                value_str = repr(actual_value)

            return f"{output_name} = {value_str}"

        elif isinstance(output_spec, TensorSpec):
            # Call fuzz_tensor_simple during codegen and embed the result
            actual_tensor = fuzz_tensor_simple(
                output_spec.size, output_spec.stride, output_spec.dtype, seed=var_seed
            )

            # Convert tensor to code representation
            size_str = str(output_spec.size)
            dtype_str = f"torch.{output_spec.dtype}".replace("torch.torch.", "torch.")

            # Handle empty tensors (with 0 elements)
            if actual_tensor.numel() == 0:
                # For empty tensors, use a default fill value based on dtype
                import torch

                default_values = {
                    torch.float16: 1.0,
                    torch.float32: 1.0,
                    torch.float64: 1.0,
                    torch.bfloat16: 1.0,
                    torch.int8: 1,
                    torch.int16: 1,
                    torch.int32: 1,
                    torch.int64: 1,
                    torch.bool: True,
                    torch.complex64: 1.0,
                    torch.complex128: 1.0,
                }
                fill_value = default_values.get(output_spec.dtype, 1)
                tensor_creation = (
                    f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                )
            else:
                # For non-empty tensors, use the first element as fill value
                fill_value = actual_tensor.flatten()[0].item()

                # For integer types, clamp the value to a smaller range to avoid
                # issues when used in arithmetic with embedding indices
                import torch

                if output_spec.dtype in [
                    torch.int8,
                    torch.int16,
                    torch.int32,
                    torch.int64,
                ]:
                    # Clamp integer values to [0, 3] to avoid index overflow in multiplication
                    # Even with multiplication, indices should stay in reasonable range
                    # pyrefly: ignore # bad-argument-type
                    fill_value = max(0, min(3, abs(fill_value)))

                tensor_creation = (
                    f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                )

            # For DTensor template, convert to DTensor
            if self.template == "dtensor":
                return (
                    f"{output_name}_local = {tensor_creation}.to('cuda')\n"
                    f"    {output_name} = DTensor.from_local({output_name}_local, mesh, placements)"
                )
            else:
                return f"{output_name} = {tensor_creation}"
        else:
            return f"# Unknown output spec type for constant: {type(output_spec)}"