Files
pytorch/tools/experimental/torchfuzz/tensor_descriptor.py

31 lines
1.1 KiB
Python

# mypy: ignore-errors
"""Utility functions for generating tensor descriptors in code comments."""
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec
def format_tensor_descriptor(spec: Spec) -> str:
    """
    Render a spec as a short, single-line descriptor for use in a comment.

    Args:
        spec: The spec to describe; ScalarSpec and TensorSpec are recognized.

    Returns:
        For a ScalarSpec, "dtype=<dtype>". For a TensorSpec,
        "size=<tuple>, stride=<tuple>, dtype=<dtype>, device=cuda", e.g.
        "size=(64, 176, 96), stride=(16896, 96, 1), dtype=bfloat16, device=cuda".
        For anything else, the literal string "unknown_spec".
    """
    # Scalars carry no layout information, so the dtype alone describes them.
    if isinstance(spec, ScalarSpec):
        return "dtype=" + str(spec.dtype).replace("torch.", "")

    # Anything that is neither a scalar nor a tensor spec gets a placeholder.
    if not isinstance(spec, TensorSpec):
        return "unknown_spec"

    # Tensor case: emit the fields in a fixed order. The device is not read
    # from the spec; it is hard-coded to "cuda" (most fuzzing is done on GPU).
    fields = (
        ("size", str(tuple(spec.size))),
        ("stride", str(tuple(spec.stride))),
        ("dtype", str(spec.dtype).replace("torch.", "")),
        ("device", "cuda"),
    )
    return ", ".join(f"{label}={text}" for label, text in fields)