mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchfuzz] various edge case fixes (#164715)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164715 Approved by: https://github.com/pianpwk ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687, #164688, #164693, #164694
This commit is contained in:
committed by
PyTorch MergeBot
parent
53f6cc7529
commit
2a6cdba6e5
@ -180,6 +180,12 @@ class FuzzTemplate:
|
||||
code_lines.append(
|
||||
f"{arg_name} = torch.as_strided(torch.randint({min_val}, {max_val}, ({storage_size},)).to({dtype_str}), {size_str}, {stride_str})"
|
||||
)
|
||||
elif spec.dtype == torch.bool:
|
||||
# For boolean tensors, use randint to generate True/False values
|
||||
# Using randn().to(bool) would yield almost all True due to non-zero floats
|
||||
code_lines.append(
|
||||
f"{arg_name} = torch.as_strided(torch.randint(0, 2, ({storage_size},), dtype=torch.int8).bool(), {size_str}, {stride_str})"
|
||||
)
|
||||
else:
|
||||
code_lines.append(
|
||||
f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
|
||||
|
@ -62,7 +62,7 @@ IGNORE_PATTERNS: list[re.Pattern] = [
|
||||
re.compile(
|
||||
r"TypeError\(\"unsupported operand type\(s\) for \*: 'SymBool' and 'FakeTensor'\"\)"
|
||||
), # https://github.com/pytorch/pytorch/issues/164684
|
||||
re.compile(r"KeyError: u0"), # https://github.com/pytorch/pytorch/issues/164685
|
||||
re.compile(r"KeyError: u\d+"), # https://github.com/pytorch/pytorch/issues/164685
|
||||
re.compile(
|
||||
r"torch\._inductor\.exc\.InductorError: CppCompileError: C\+\+ compile error"
|
||||
), # https://github.com/pytorch/pytorch/issues/164686
|
||||
|
@ -78,3 +78,22 @@ class ScalarDivOperator(ScalarPointwiseOperator):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("scalar_div", "/")
|
||||
|
||||
def codegen(
|
||||
self, output_name: str, input_names: list[str], output_spec: Spec
|
||||
) -> str:
|
||||
"""Generate code for scalar division with zero-denominator guard."""
|
||||
if len(input_names) != 2:
|
||||
raise ValueError(f"{self.__class__.__name__} requires exactly two inputs")
|
||||
|
||||
# Prevent ZeroDivisionError at runtime by clamping the denominator.
|
||||
# Clamp denominator to at least 1 (for ints) or 1e-6 (for floats).
|
||||
if isinstance(output_spec, ScalarSpec) and output_spec.dtype in [
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
]:
|
||||
return f"{output_name} = {input_names[0]} / max({input_names[1]}, 1)"
|
||||
else:
|
||||
return f"{output_name} = {input_names[0]} / max({input_names[1]}, 1e-6)"
|
||||
|
Reference in New Issue
Block a user