[torchfuzz] various edge case fixes (#164715)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164715
Approved by: https://github.com/pianpwk
ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687, #164688, #164693, #164694
Author: bobrenjc93
Date: 2025-10-06 10:03:51 -07:00
Committed by: PyTorch MergeBot
Parent: 53f6cc7529
Commit: 2a6cdba6e5
3 changed files with 26 additions and 1 deletion


@@ -180,6 +180,12 @@ class FuzzTemplate:
             code_lines.append(
                 f"{arg_name} = torch.as_strided(torch.randint({min_val}, {max_val}, ({storage_size},)).to({dtype_str}), {size_str}, {stride_str})"
             )
+        elif spec.dtype == torch.bool:
+            # For boolean tensors, use randint to generate True/False values
+            # Using randn().to(bool) would yield almost all True due to non-zero floats
+            code_lines.append(
+                f"{arg_name} = torch.as_strided(torch.randint(0, 2, ({storage_size},), dtype=torch.int8).bool(), {size_str}, {stride_str})"
+            )
         else:
             code_lines.append(
                 f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"


@@ -62,7 +62,7 @@ IGNORE_PATTERNS: list[re.Pattern] = [
     re.compile(
         r"TypeError\(\"unsupported operand type\(s\) for \*: 'SymBool' and 'FakeTensor'\"\)"
     ),  # https://github.com/pytorch/pytorch/issues/164684
-    re.compile(r"KeyError: u0"),  # https://github.com/pytorch/pytorch/issues/164685
+    re.compile(r"KeyError: u\d+"),  # https://github.com/pytorch/pytorch/issues/164685
     re.compile(
         r"torch\._inductor\.exc\.InductorError: CppCompileError: C\+\+ compile error"
     ),  # https://github.com/pytorch/pytorch/issues/164686
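The widened ignore pattern matters because unbacked symbol names are numbered dynamically, so the same known issue can surface as u0, u1, u17, and so on. A small sketch of the difference; the pattern strings are copied from the diff, the error messages are invented for illustration:

import re

old_pattern = re.compile(r"KeyError: u0")    # only matched the first unbacked symbol
new_pattern = re.compile(r"KeyError: u\d+")  # matches any unbacked symbol index

for msg in ("KeyError: u0", "KeyError: u3", "KeyError: u42"):
    print(msg, bool(old_pattern.search(msg)), bool(new_pattern.search(msg)))
# KeyError: u0  True  True
# KeyError: u3  False True
# KeyError: u42 False True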


@@ -78,3 +78,22 @@ class ScalarDivOperator(ScalarPointwiseOperator):
     def __init__(self):
         super().__init__("scalar_div", "/")
 
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for scalar division with zero-denominator guard."""
+        if len(input_names) != 2:
+            raise ValueError(f"{self.__class__.__name__} requires exactly two inputs")
+
+        # Prevent ZeroDivisionError at runtime by clamping the denominator.
+        # Clamp denominator to at least 1 (for ints) or 1e-6 (for floats).
+        if isinstance(output_spec, ScalarSpec) and output_spec.dtype in [
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
+            return f"{output_name} = {input_names[0]} / max({input_names[1]}, 1)"
+        else:
+            return f"{output_name} = {input_names[0]} / max({input_names[1]}, 1e-6)"