mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torchfuzz] Make scalar and tensor distribution configurable (#164034)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164034 Approved by: https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
a293206bd5
commit
c39357bab6
@ -29,6 +29,81 @@ class FuzzTemplate:
|
||||
torch.bool,
|
||||
]
|
||||
|
||||
def spec_distribution(self):
|
||||
"""
|
||||
Define the distribution for generating random Specs.
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
- 'tensor_prob': Probability of generating TensorSpec (0.0 to 1.0)
|
||||
- 'scalar_prob': Probability of generating ScalarSpec (0.0 to 1.0)
|
||||
- 'allow_tensors': Whether TensorSpec generation is allowed (boolean)
|
||||
- 'allow_scalars': Whether ScalarSpec generation is allowed (boolean)
|
||||
"""
|
||||
return {
|
||||
"tensor_prob": 0.8,
|
||||
"scalar_prob": 0.2,
|
||||
"allow_tensors": True,
|
||||
"allow_scalars": True,
|
||||
}
|
||||
|
||||
def fuzz_spec_custom(self):
|
||||
"""
|
||||
Generate a random Spec based on this template's distribution preferences.
|
||||
|
||||
Returns:
|
||||
Spec: Either a TensorSpec or ScalarSpec according to template's distribution
|
||||
"""
|
||||
import random
|
||||
|
||||
from torchfuzz.tensor_fuzzer import fuzz_torch_tensor_type
|
||||
|
||||
# Get template's distribution configuration
|
||||
distribution = self.spec_distribution()
|
||||
|
||||
# Get random dtype based on template
|
||||
dtype = fuzz_torch_tensor_type("default")
|
||||
|
||||
# Validate distribution configuration
|
||||
allow_tensors = distribution.get("allow_tensors", True)
|
||||
allow_scalars = distribution.get("allow_scalars", True)
|
||||
|
||||
if not allow_tensors and not allow_scalars:
|
||||
raise ValueError("Template must allow at least one of tensors or scalars")
|
||||
|
||||
# Determine which type to generate
|
||||
if not allow_scalars:
|
||||
# Only tensors allowed
|
||||
return self._generate_tensor_spec(dtype)
|
||||
elif not allow_tensors:
|
||||
# Only scalars allowed
|
||||
return self._generate_scalar_spec(dtype)
|
||||
else:
|
||||
# Both allowed, use probability distribution
|
||||
tensor_prob = distribution.get("tensor_prob", 0.8)
|
||||
if random.random() < tensor_prob:
|
||||
return self._generate_tensor_spec(dtype)
|
||||
else:
|
||||
return self._generate_scalar_spec(dtype)
|
||||
|
||||
def _generate_tensor_spec(self, dtype):
|
||||
"""Generate a TensorSpec with the given dtype."""
|
||||
from torchfuzz.tensor_fuzzer import (
|
||||
fuzz_tensor_size,
|
||||
fuzz_valid_stride,
|
||||
TensorSpec,
|
||||
)
|
||||
|
||||
size = fuzz_tensor_size()
|
||||
stride = fuzz_valid_stride(size)
|
||||
return TensorSpec(size=size, stride=stride, dtype=dtype)
|
||||
|
||||
def _generate_scalar_spec(self, dtype):
|
||||
"""Generate a ScalarSpec with the given dtype."""
|
||||
from torchfuzz.tensor_fuzzer import ScalarSpec
|
||||
|
||||
return ScalarSpec(dtype=dtype)
|
||||
|
||||
|
||||
class DefaultFuzzTemplate(FuzzTemplate):
|
||||
def __init__(self):
|
||||
@ -44,6 +119,15 @@ class DefaultFuzzTemplate(FuzzTemplate):
|
||||
check=EagerVsFullGraphDynamicCompileCheck(),
|
||||
)
|
||||
|
||||
def spec_distribution(self):
|
||||
"""Default template: tensor-only (no scalars)."""
|
||||
return {
|
||||
"tensor_prob": 1.0,
|
||||
"scalar_prob": 0.0,
|
||||
"allow_tensors": True,
|
||||
"allow_scalars": False,
|
||||
}
|
||||
|
||||
def imports_codegen(self):
|
||||
return [
|
||||
"import torch",
|
||||
@ -129,6 +213,15 @@ class DTensorFuzzTemplate(FuzzTemplate):
|
||||
torch.bool,
|
||||
]
|
||||
|
||||
def spec_distribution(self):
|
||||
"""DTensor template: tensor-only (no scalars)."""
|
||||
return {
|
||||
"tensor_prob": 1.0,
|
||||
"scalar_prob": 0.0,
|
||||
"allow_tensors": True,
|
||||
"allow_scalars": False,
|
||||
}
|
||||
|
||||
def imports_codegen(self):
|
||||
return [
|
||||
"import torch",
|
||||
@ -259,6 +352,15 @@ class UnbackedFuzzTemplate(FuzzTemplate):
|
||||
torch.int64,
|
||||
]
|
||||
|
||||
def spec_distribution(self):
|
||||
"""Unbacked template: 50% tensors, 50% scalars."""
|
||||
return {
|
||||
"tensor_prob": 0.5,
|
||||
"scalar_prob": 0.5,
|
||||
"allow_tensors": True,
|
||||
"allow_scalars": True,
|
||||
}
|
||||
|
||||
def imports_codegen(self):
|
||||
return [
|
||||
"import torch",
|
||||
|
@ -216,31 +216,47 @@ class OperationGraph:
|
||||
|
||||
def fuzz_spec(template: str = "default") -> Spec:
|
||||
"""
|
||||
Generate a random Spec (either TensorSpec or ScalarSpec) using tensor fuzzing functions.
|
||||
|
||||
Utilizes:
|
||||
- fuzz_torch_tensor_type() for random dtype
|
||||
- fuzz_tensor_size() for random tensor size
|
||||
- fuzz_valid_stride() for random valid strides
|
||||
Generate a random Spec (either TensorSpec or ScalarSpec) using template's distribution preferences.
|
||||
|
||||
Args:
|
||||
template: Template name to determine supported dtypes
|
||||
template: Template name to determine configuration and distribution
|
||||
|
||||
Returns:
|
||||
Spec: Either a TensorSpec (80% probability) or ScalarSpec (20% probability) with random properties
|
||||
Spec: Either a TensorSpec or ScalarSpec according to template's distribution
|
||||
"""
|
||||
# Get random dtype based on template
|
||||
dtype = fuzz_torch_tensor_type(template)
|
||||
# Try to use template's custom distribution if available
|
||||
try:
|
||||
# Instantiate template
|
||||
if template == "dtensor":
|
||||
from torchfuzz.codegen import DTensorFuzzTemplate
|
||||
|
||||
# 20% probability of returning ScalarSpec
|
||||
if random.random() < 0.2:
|
||||
return ScalarSpec(dtype=dtype)
|
||||
fuzz_template = DTensorFuzzTemplate()
|
||||
elif template == "unbacked":
|
||||
from torchfuzz.codegen import UnbackedFuzzTemplate
|
||||
|
||||
# 80% probability of returning TensorSpec
|
||||
# Get random size and corresponding stride
|
||||
size = fuzz_tensor_size()
|
||||
stride = fuzz_valid_stride(size)
|
||||
return TensorSpec(size=size, stride=stride, dtype=dtype)
|
||||
fuzz_template = UnbackedFuzzTemplate()
|
||||
else:
|
||||
from torchfuzz.codegen import DefaultFuzzTemplate
|
||||
|
||||
fuzz_template = DefaultFuzzTemplate()
|
||||
|
||||
# Use template's custom spec generation
|
||||
return fuzz_template.fuzz_spec_custom()
|
||||
|
||||
except Exception:
|
||||
# Fallback to original hardcoded behavior if template fails
|
||||
# Get random dtype based on template
|
||||
dtype = fuzz_torch_tensor_type(template)
|
||||
|
||||
# 20% probability of returning ScalarSpec
|
||||
if random.random() < 0.2:
|
||||
return ScalarSpec(dtype=dtype)
|
||||
|
||||
# 80% probability of returning TensorSpec
|
||||
# Get random size and corresponding stride
|
||||
size = fuzz_tensor_size()
|
||||
stride = fuzz_valid_stride(size)
|
||||
return TensorSpec(size=size, stride=stride, dtype=dtype)
|
||||
|
||||
|
||||
def fuzz_op(
|
||||
|
Reference in New Issue
Block a user