mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[torchfuzz] ones over zero (#164002)
reduces likelihood of divide by zero errors. long term we'll probably want to just fuzz these values entirely Pull Request resolved: https://github.com/pytorch/pytorch/pull/164002 Approved by: https://github.com/pianpwk ghstack dependencies: #163743, #163812, #163890
This commit is contained in:
committed by
PyTorch MergeBot
parent
b48a3d0a38
commit
3a115da3e6
@ -76,19 +76,20 @@ class ConstantOperator(Operator):
|
||||
import torch
|
||||
|
||||
default_values = {
|
||||
torch.float16: 0.0,
|
||||
torch.float32: 0.0,
|
||||
torch.float64: 0.0,
|
||||
torch.bfloat16: 0.0,
|
||||
torch.int8: 0,
|
||||
torch.int16: 0,
|
||||
torch.int32: 0,
|
||||
torch.int64: 0,
|
||||
torch.bool: False,
|
||||
torch.complex64: 0.0,
|
||||
torch.complex128: 0.0,
|
||||
torch.float16: 1.0,
|
||||
torch.float32: 1.0,
|
||||
torch.float64: 1.0,
|
||||
torch.bfloat16: 1.0,
|
||||
torch.int8: 1,
|
||||
torch.int16: 1,
|
||||
torch.int32: 1,
|
||||
torch.int64: 1,
|
||||
torch.bool: True,
|
||||
torch.complex64: 1.0,
|
||||
torch.complex128: 1.0,
|
||||
}
|
||||
fill_value = default_values.get(output_spec.dtype, 0)
|
||||
|
||||
fill_value = default_values.get(output_spec.dtype, 1)
|
||||
tensor_creation = (
|
||||
f"torch.full({size_str}, {fill_value}, dtype={dtype_str})"
|
||||
)
|
||||
|
@ -373,7 +373,7 @@ def fuzz_tensor(
|
||||
|
||||
# Handle empty tensor case
|
||||
if len(size) == 0:
|
||||
return torch.zeros((), dtype=dtype), seed
|
||||
return torch.ones((), dtype=dtype), seed
|
||||
|
||||
# Calculate required storage size for the custom stride
|
||||
required_storage = _compute_storage_size_needed(size, stride)
|
||||
@ -400,7 +400,7 @@ def fuzz_tensor(
|
||||
base_tensor = torch.randint(-100, 100, (required_storage,), dtype=dtype)
|
||||
else:
|
||||
# Use zeros (default behavior)
|
||||
base_tensor = torch.zeros(required_storage, dtype=dtype)
|
||||
base_tensor = torch.ones(required_storage, dtype=dtype)
|
||||
|
||||
# Create strided tensor view
|
||||
strided_tensor = torch.as_strided(base_tensor, size, stride)
|
||||
|
Reference in New Issue
Block a user