[torchfuzz] consolidate on a base implementation of args_codegen (#164693)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164693
Approved by: https://github.com/pianpwk
ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687, #164688
Committed by: PyTorch MergeBot
Parent: c965d6dbb2
Commit: ac901bf79a
@@ -104,6 +104,89 @@ class FuzzTemplate:
         return ScalarSpec(dtype=dtype)

+    def args_codegen(self, arg_operations):
+        """Generate argument creation code for default template."""
+        code_lines = []
+
+        # Add sentinel tensor that ensures gradient computation
+        code_lines.extend(
+            [
+                "# Sentinel tensor to ensure gradient computation",
+                "sentinel = torch.tensor(1.0, requires_grad=True)",
+                "",
+            ]
+        )
+
+        if arg_operations:
+            for i, (node_id, spec) in enumerate(arg_operations):
+                arg_name = f"arg_{i}"
+
+                if isinstance(spec, ScalarSpec):
+                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
+                    if spec.dtype in [
+                        torch.int8,
+                        torch.int16,
+                        torch.int32,
+                        torch.int64,
+                    ]:
+                        # For integer scalars, use randint to avoid always getting 0
+                        code_lines.append(
+                            f"{arg_name} = int(torch.randint(5, 30, ()).item())"
+                        )
+                    elif spec.dtype == torch.bool:
+                        # For boolean scalars, use randint and cast to bool
+                        code_lines.append(
+                            f"{arg_name} = bool(torch.randint(0, 2, ()).item())"
+                        )
+                    else:
+                        # For float scalars, use randn
+                        code_lines.append(
+                            f"{arg_name} = float(torch.randn((), dtype={dtype_str}).item())"
+                        )
+
+                elif isinstance(spec, TensorSpec):
+                    size_str = str(spec.size)
+                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
+
+                    # Calculate storage size needed for the strided tensor
+                    if spec.size:
+                        # Calculate the maximum index that will be accessed
+                        max_offset = 0
+                        for dim_size, stride in zip(spec.size, spec.stride):
+                            if dim_size > 1:
+                                max_offset += (dim_size - 1) * abs(stride)
+                        storage_size = max_offset + 1
+                    else:
+                        storage_size = 1
+
+                    stride_str = str(spec.stride)
+
+                    # Special handling for integer tensors which might be used as indices
+                    if spec.dtype in [
+                        torch.int8,
+                        torch.int16,
+                        torch.int32,
+                        torch.int64,
+                    ]:
+                        # For integer tensors, generate valid indices with headroom for arithmetic
+                        # Use smaller range [5, 30] to allow for multiplication and other operations
+                        # This prevents indices from becoming too large after arithmetic
+                        min_val = (
+                            5  # Minimum to avoid negative results after subtraction
+                        )
+                        max_val = (
+                            30  # Maximum to avoid out-of-bounds after multiplication
+                        )
+                        code_lines.append(
+                            f"{arg_name} = torch.as_strided(torch.randint({min_val}, {max_val}, ({storage_size},)).to({dtype_str}), {size_str}, {stride_str})"
+                        )
+                    else:
+                        code_lines.append(
+                            f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
+                        )
+
+        return code_lines
+
+
 class DefaultFuzzTemplate(FuzzTemplate):
     def __init__(self):
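To make the behavior of the new base implementation concrete, here is a small illustrative sketch of the lines it would emit. The ScalarSpec/TensorSpec keyword arguments and the `template` variable are assumptions based on the fields the method reads (dtype, size, stride); the actual torchfuzz constructors and import paths are not shown in this diff.

    # Hypothetical specs; the fields mirror what args_codegen accesses above.
    arg_operations = [
        ("node_0", ScalarSpec(dtype=torch.int64)),
        ("node_1", TensorSpec(size=(2, 3), stride=(3, 1), dtype=torch.float32)),
    ]
    # `template` is an instance of any FuzzTemplate subclass after this change.
    for line in template.args_codegen(arg_operations):
        print(line)
    # Expected output, following the branches above:
    #   # Sentinel tensor to ensure gradient computation
    #   sentinel = torch.tensor(1.0, requires_grad=True)
    #
    #   arg_0 = int(torch.randint(5, 30, ()).item())
    #   arg_1 = torch.as_strided(torch.randn(6).to(torch.float32), (2, 3), (3, 1))

The storage size of 6 for arg_1 comes from the max-offset calculation: (2 - 1) * 3 + (3 - 1) * 1 + 1.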
@@ -167,67 +250,6 @@ class DefaultFuzzTemplate(FuzzTemplate):
     def flags_codegen(self):
         return ["torch._dynamo.config.capture_scalar_outputs = True"]

-    def args_codegen(self, arg_operations):
-        """Generate argument creation code for default template."""
-        code_lines = []
-
-        # Add sentinel tensor that ensures gradient computation
-        code_lines.extend(
-            [
-                "# Sentinel tensor to ensure gradient computation",
-                "sentinel = torch.tensor(1.0, requires_grad=True)",
-                "",
-            ]
-        )
-
-        if arg_operations:
-            for i, (node_id, spec) in enumerate(arg_operations):
-                arg_name = f"arg_{i}"
-
-                if isinstance(spec, ScalarSpec):
-                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
-                    code_lines.append(
-                        f"{arg_name} = torch.tensor(torch.randn(()), dtype={dtype_str}).item()"
-                    )
-
-                elif isinstance(spec, TensorSpec):
-                    size_str = str(spec.size)
-                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
-
-                    # Calculate storage size needed for the strided tensor
-                    if spec.size:
-                        # Calculate the maximum index that will be accessed
-                        max_offset = 0
-                        for dim_size, stride in zip(spec.size, spec.stride):
-                            if dim_size > 1:
-                                max_offset += (dim_size - 1) * abs(stride)
-                        storage_size = max_offset + 1
-                    else:
-                        storage_size = 1
-
-                    stride_str = str(spec.stride)
-
-                    # Special handling for integer tensors which might be used as indices
-                    if spec.dtype in [torch.int32, torch.int64]:
-                        # For integer tensors, generate valid indices with headroom for arithmetic
-                        # Use smaller range [5, 30] to allow for multiplication and other operations
-                        # This prevents indices from becoming too large after arithmetic
-                        min_val = (
-                            5  # Minimum to avoid negative results after subtraction
-                        )
-                        max_val = (
-                            30  # Maximum to avoid out-of-bounds after multiplication
-                        )
-                        code_lines.append(
-                            f"{arg_name} = torch.as_strided(torch.randint({min_val}, {max_val}, ({storage_size},)).to({dtype_str}), {size_str}, {stride_str})"
-                        )
-                    else:
-                        code_lines.append(
-                            f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
-                        )
-
-        return code_lines
-
     def epilogue_codegen(self):
         return []

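One behavioral consequence visible in this hunk: the removed override generated every scalar with randn regardless of dtype, whereas the inherited base implementation branches on dtype. A runnable sketch of the two generated expressions for an int64 scalar (variable names here are illustrative, not from the source):

    import torch

    # Old DefaultFuzzTemplate.args_codegen emitted this form for all scalar dtypes;
    # randn values are mostly within (-1, 1), so the int64 cast truncates to 0 almost every time.
    old_scalar = torch.tensor(torch.randn(()), dtype=torch.int64).item()

    # The base FuzzTemplate.args_codegen now emits this for integer scalars instead,
    # matching its "avoid always getting 0" comment: uniform in [5, 30).
    new_scalar = int(torch.randint(5, 30, ()).item())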
@@ -453,56 +475,6 @@ class UnbackedFuzzTemplate(FuzzTemplate):
             "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
         ]

-    def args_codegen(self, arg_operations):
-        """Generate argument creation code for unbacked template."""
-        code_lines = []
-
-        # Add sentinel tensor that ensures gradient computation
-        code_lines.extend(
-            [
-                "# Sentinel tensor to ensure gradient computation",
-                "sentinel = torch.tensor(1.0, requires_grad=True)",
-                "",
-            ]
-        )
-
-        if arg_operations:
-            for i, (node_id, spec) in enumerate(arg_operations):
-                arg_name = f"arg_{i}"
-
-                if isinstance(spec, ScalarSpec):
-                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
-                    code_lines.append(
-                        f"{arg_name} = torch.tensor(torch.randn(()), dtype={dtype_str}).item()"
-                    )
-
-                elif isinstance(spec, TensorSpec):
-                    size_str = str(spec.size)
-                    dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
-
-                    # For unbacked operations, create tensors with specific patterns
-                    # that are likely to produce meaningful results
-                    if spec.dtype == torch.bool:
-                        # For boolean tensors, create a mix of True/False values
-                        code_lines.append(
-                            f"{arg_name} = torch.randint(0, 2, {size_str}, dtype={dtype_str}) > 0"
-                        )
-                    elif spec.dtype in [torch.int32, torch.int64]:
-                        # For integer tensors, create values that will have some duplicates
-                        # and some unique values for operations like unique()
-                        code_lines.append(
-                            f"{arg_name} = torch.randint(0, 3, {size_str}, dtype={dtype_str})"
-                        )
-                    else:
-                        # For float tensors, create values with some zeros and non-zeros
-                        code_lines.append(
-                            f"{arg_name} = (torch.randn({size_str}) * 2).to({dtype_str})"
-                        )
-                        # Zero out some values to make nonzero operations meaningful
-                        code_lines.append(f"{arg_name}[{arg_name}.abs() < 0.5] = 0")
-
-        return code_lines
-
     def epilogue_codegen(self):
         return []

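Net effect of the three hunks: the shared argument-generation logic now lives on the base class, and the templates differ mainly in the flags and other hooks they provide. A rough sketch of the resulting shape, with bodies elided and only the pieces visible in this diff filled in:

    class FuzzTemplate:
        def args_codegen(self, arg_operations):
            ...  # shared implementation from the first hunk

    class DefaultFuzzTemplate(FuzzTemplate):
        def flags_codegen(self):
            return ["torch._dynamo.config.capture_scalar_outputs = True"]
        # args_codegen is now inherited from FuzzTemplate

    class UnbackedFuzzTemplate(FuzzTemplate):
        def flags_codegen(self):
            return [
                # ... other flags not shown in this diff
                "torch._dynamo.config.capture_dynamic_output_shape_ops = True",
            ]
        # args_codegen is now inherited from FuzzTemplate

Whether the unbacked flag list belongs to flags_codegen is an inference from the surrounding context lines; the diff only shows the tail of that list.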