"""Constant operator implementation."""
|
|
|
|
from typing import Optional
|
|
|
|
from torchfuzz.operators.base import Operator
|
|
from torchfuzz.tensor_fuzzer import (
|
|
fuzz_scalar,
|
|
fuzz_tensor_simple,
|
|
ScalarSpec,
|
|
Spec,
|
|
TensorSpec,
|
|
)
|
|
|
|
|
|
class ConstantOperator(Operator):
    """Operator for generating constants."""

    def __init__(self):
        super().__init__("constant")
        self.template = "default"  # Track template for DTensor compatibility

    @property
    def torch_op_name(self) -> Optional[str]:
        """Constant is not a torch operation; it generates constant values."""
        return None

    def set_template(self, template: str):
        """Set the template for context-aware code generation."""
        self.template = template

    def can_produce(self, output_spec: Spec) -> bool:
        """Constant can produce any type of output."""
        return True

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Constant requires no inputs for fuzzing."""
        return []

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for constant creation."""
        # Create the constant by calling the fuzzing functions during codegen
        # with a deterministic seed. The seed is a hash of the variable name,
        # which keeps the embedded value reproducible across processes.
        import hashlib

        var_seed = int(hashlib.md5(output_name.encode()).hexdigest()[:8], 16) % (2**31)  # noqa: S324
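        # For illustration: a given output_name (say "t0") always hashes to
        # the same 31-bit seed, so every process generating code for "t0"
        # embeds the identical constant value.
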
        if isinstance(output_spec, ScalarSpec):
            # Call fuzz_scalar during codegen and embed the result
            actual_value = fuzz_scalar(output_spec, seed=var_seed)

            # Format the value for embedding in code
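            # For example: True embeds as "True", 3.5 as "3.5", and (1+2j) as
            # "complex(1.0, 2.0)", so the generated line parses as plain Python.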
            if isinstance(actual_value, bool):
                value_str = str(actual_value)
            elif isinstance(actual_value, (int, float)):
                value_str = repr(actual_value)
            elif isinstance(actual_value, complex):
                value_str = f"complex({actual_value.real}, {actual_value.imag})"
            else:
                value_str = repr(actual_value)

            return f"{output_name} = {value_str}"

        elif isinstance(output_spec, TensorSpec):
            # Call fuzz_tensor_simple during codegen and embed the result
            actual_tensor = fuzz_tensor_simple(
                output_spec.size, output_spec.stride, output_spec.dtype, seed=var_seed
            )

            # Convert tensor to code representation
            size_str = str(output_spec.size)
            # str(dtype) already carries a "torch." prefix, so the replace
            # guards against emitting a doubled "torch.torch." in generated code
            dtype_str = f"torch.{output_spec.dtype}".replace("torch.torch.", "torch.")

            # Handle empty tensors (with 0 elements)
            if actual_tensor.numel() == 0:
                # For empty tensors, use a default fill value based on dtype
                import torch

                default_values = {
                    torch.float16: 1.0,
                    torch.float32: 1.0,
                    torch.float64: 1.0,
                    torch.bfloat16: 1.0,
                    torch.int8: 1,
                    torch.int16: 1,
                    torch.int32: 1,
                    torch.int64: 1,
                    torch.bool: True,
                    torch.complex64: 1.0,
                    torch.complex128: 1.0,
                }

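                # torch.full with a zero-element size allocates no storage, so
                # the fill value only has to be type-compatible with the dtype.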
                fill_value = default_values.get(output_spec.dtype, 1)
                tensor_creation = (
                    f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                )
            else:
                # For non-empty tensors, use the first element as fill value
                fill_value = actual_tensor.flatten()[0].item()

                # For integer types, clamp the value to a smaller range to avoid
                # issues when used in arithmetic with embedding indices
                import torch

                if output_spec.dtype in [
                    torch.int8,
                    torch.int16,
                    torch.int32,
                    torch.int64,
                ]:
                    # Clamp integer values to [0, 3] to avoid index overflow in
                    # multiplication; even after multiplication, indices should
                    # stay in a reasonable range
                    # pyrefly: ignore # bad-argument-type
                    fill_value = max(0, min(3, abs(fill_value)))
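                    # Worked example: a fuzzed first element of -7 becomes
                    # abs(-7) = 7, then min(3, 7) = 3, then max(0, 3) = 3;
                    # a value of 2 passes through unchanged.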

                tensor_creation = (
                    f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
                )

            # For the DTensor template, convert to DTensor
            if self.template == "dtensor":
                return (
                    f"{output_name}_local = {tensor_creation}.to('cuda')\n"
                    f"    {output_name} = DTensor.from_local({output_name}_local, mesh, placements)"
                )
            else:
                return f"{output_name} = {tensor_creation}"

        else:
            return f"# Unknown output spec type for constant: {type(output_spec)}"
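

# Minimal usage sketch (an illustrative addition, not part of the operator
# itself). It assumes TensorSpec is constructed from the same size/stride/dtype
# fields that codegen reads above; the variable name "t0" is arbitrary.
if __name__ == "__main__":
    import torch

    op = ConstantOperator()
    op.set_template("default")

    # Ask the operator to emit a line of Python that materializes a constant
    # 2x3 float32 tensor; the fill value comes from the seeded fuzzer.
    spec = TensorSpec(size=(2, 3), stride=(3, 1), dtype=torch.float32)
    print(op.codegen("t0", [], spec))
    # Expected shape of the output (the fill value varies with the seed):
    # t0 = torch.full((2, 3), <fill_value>, dtype=torch.float32)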