mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164688 Approved by: https://github.com/pianpwk ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687
31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
# mypy: ignore-errors
|
|
"""Utility functions for generating tensor descriptors in code comments."""
|
|
|
|
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec
|
|
|
|
|
|
def _dtype_name(dtype) -> str:
    """Return ``str(dtype)`` with every ``"torch."`` substring removed (e.g. ``"bfloat16"``)."""
    return str(dtype).replace("torch.", "")


def format_tensor_descriptor(spec: Spec, *, device: str = "cuda") -> str:
    """
    Format a tensor or scalar spec as a descriptor comment.

    Args:
        spec: TensorSpec or ScalarSpec to format
        device: Device name reported for tensor specs. Defaults to "cuda"
            (most fuzzing is done on GPU), which preserves the previous
            hard-coded behavior; pass another value (e.g. "cpu") when the
            generated program targets a different device. Not used for
            scalar specs.

    Returns:
        For a TensorSpec, a descriptor string like
        "size=(64, 176, 96), stride=(16896, 96, 1), dtype=bfloat16, device=cuda";
        for a ScalarSpec, "dtype=<dtype>"; for any other spec type,
        the literal string "unknown_spec".
    """
    if isinstance(spec, ScalarSpec):
        # Scalars have no size/stride, so only the dtype is shown.
        return f"dtype={_dtype_name(spec.dtype)}"

    if isinstance(spec, TensorSpec):
        # tuple(...) normalizes list-like sizes/strides so they render as
        # Python tuples, e.g. "(64, 176, 96)" or "(64,)".
        size_str = str(tuple(spec.size))
        stride_str = str(tuple(spec.stride))
        return (
            f"size={size_str}, stride={stride_str}, "
            f"dtype={_dtype_name(spec.dtype)}, device={device}"
        )

    # Unrecognized spec kinds get a fixed placeholder instead of raising.
    return "unknown_spec"
|