[Fix] Adding missing f prefixes to formatted strings [1/N] (#164065)

As stated in the title: these error-message strings use `{...}` placeholders but are missing the `f` prefix, so the placeholders were emitted verbatim instead of being interpolated.
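For context, a minimal sketch of what the missing prefix does (the values here are made up for illustration):

```python
param_name, param = "kernel_size", (0, 3)

# Without the f prefix, the braces are kept literally in the message:
print("{param_name} should be greater than zero, but got {param}")
# -> {param_name} should be greater than zero, but got {param}

# With the f prefix, the values are interpolated as intended:
print(f"{param_name} should be greater than zero, but got {param}")
# -> kernel_size should be greater than zero, but got (0, 3)
```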

* #164068
* #164067
* #164066
* __->__ #164065

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164065
Approved by: https://github.com/Skylion007
can-gaa-hou, 2025-09-29 04:53:00 +00:00; committed by PyTorch MergeBot
parent d131f213ac
commit eb4361a801
9 changed files with 26 additions and 25 deletions

View File

@@ -924,7 +924,7 @@ def im2col(
def check_positive(param, param_name, strict=True):
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
torch._check(
-cond, lambda: "{param_name} should be greater than zero, but got {param}"
+cond, lambda: f"{param_name} should be greater than zero, but got {param}"
)
check_positive(kernel_size, "kernel_size")
@@ -1009,7 +1009,7 @@ def col2im(
def check_positive(param, param_name, strict=True):
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
torch._check(
-cond, lambda: "{param_name} should be greater than zero, but got {param}"
+cond, lambda: f"{param_name} should be greater than zero, but got {param}"
)
check_positive(kernel_size, "kernel_size")
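As an aside on the pattern in these two hunks: `torch._check` takes the message as a zero-argument callable, so the f-string is only built when the check actually fails. A minimal standalone sketch mirroring the fixed helper above:

```python
import torch

def check_positive(param, param_name, strict=True):
    # The lambda defers message construction until the condition is False.
    cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
    torch._check(
        cond, lambda: f"{param_name} should be greater than zero, but got {param}"
    )

check_positive((2, 2), "kernel_size")   # passes silently
# check_positive((0, 2), "kernel_size") would now raise with
# "kernel_size should be greater than zero, but got (0, 2)" instead of the
# literal "{param_name} should be greater than zero, but got {param}".
```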

View File

@@ -166,7 +166,7 @@ class NaNChecker:
if grad is not None:
assert not torch.isnan(grad).any(), (
f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
-"having NaN gradient. This is not supported. {TURN_OFF_MSG}"
+f"having NaN gradient. This is not supported. {TURN_OFF_MSG}"
)
self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
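This hunk also shows why these bugs are easy to miss: adjacent string literals are concatenated at compile time, but each piece needs its own `f` prefix. A quick illustration (the `TURN_OFF_MSG` value below is made up for the example):

```python
idx = 3
TURN_OFF_MSG = "See the compiled autograd docs to turn anomaly mode off."  # placeholder text

msg = (
    f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
    "having NaN gradient. This is not supported. {TURN_OFF_MSG}"  # missing f: left verbatim
)
print(msg)  # ...This is not supported. {TURN_OFF_MSG}

msg = (
    f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
    f"having NaN gradient. This is not supported. {TURN_OFF_MSG}"  # both pieces interpolated
)
print(msg)  # ...This is not supported. See the compiled autograd docs to turn anomaly mode off.
```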

View File

@@ -263,7 +263,7 @@ class FrameStateSizeEntry:
return f"tensor size={render_tuple(self.size)} stride={render_tuple(self.stride)}"
# Fallback
-return "unusual {repr(self)}"
+return f"unusual {repr(self)}"
def __post_init__(self) -> None:
assert not isinstance(self.scalar, torch.SymInt), self.scalar

View File

@@ -370,7 +370,7 @@ class MemoryPlanningState:
class WrapperLine:
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-raise NotImplementedError("FX codegen not yet supported for type {type(self)}")
+raise NotImplementedError(f"FX codegen not yet supported for type {type(self)}")
@dataclasses.dataclass

View File

@@ -2885,7 +2885,7 @@ class ExpandView(BaseView):
# guarded because the meta formula was expected to have taught
# us this equality.
assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
-"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
+f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
)
return new_size

View File

@@ -664,7 +664,7 @@ def meta__cslt_sparse_mm(
torch.int32,
torch.float8_e4m3fn,
}, (
-"out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
+f"out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
)
output_shape = (n, m) if transpose_result else (m, n)
return dense_B.new_empty(output_shape, dtype=out_dtype)
@@ -3884,11 +3884,11 @@ def meta_cdist_forward(x1, x2, p, compute_mode):
)
torch._check(
utils.is_float_dtype(x1.dtype),
-lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
+lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
)
torch._check(
utils.is_float_dtype(x2.dtype),
-lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
+lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
)
torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
torch._check(
@@ -4466,15 +4466,15 @@ def pool2d_shape_check(
torch._check(
kW > 0 and kH > 0,
-lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
+lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
)
torch._check(
dW > 0 and dH > 0,
-lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
+lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
)
torch._check(
dilationH > 0 and dilationW > 0,
-lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
+lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
)
valid_dims = input.size(1) != 0 and input.size(2) != 0
@@ -4483,7 +4483,7 @@ def pool2d_shape_check(
torch._check(
ndim == 4 and valid_dims and input.size(3) != 0,
lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
-" with optional 0 dim batch size for input, but got: {input.size()}",
+f" with optional 0 dim batch size for input, but got: {input.size()}",
)
else:
torch._check(
@@ -7015,7 +7015,7 @@ def meta_histc(input, bins=100, min=0, max=0):
isinstance(max, Number),
lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}",
)
-torch._check(max >= min, lambda: "{fn_name}: max must be larger than min")
+torch._check(max >= min, lambda: f"{fn_name}: max must be larger than min")
return torch.empty(bins, device=input.device, dtype=input.dtype)

View File

@@ -5928,7 +5928,8 @@ def norm(
@out_wrapper()
def trace(self: TensorLikeType) -> TensorLikeType:
torch._check(
-self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
+self.ndim == 2,
+lambda: f"expected a matrix, but got tensor with dim {self.ndim}",
)
return torch.sum(torch.diag(self, 0))

View File

@@ -56,7 +56,7 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
torch._check(
utils.get_higher_dtype(dtype, x_dtype) == dtype,
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
-"without narrowing to the specified dtype ({dtype})",
+f"without narrowing to the specified dtype ({dtype})",
)
@@ -110,7 +110,7 @@ def _check_vector_norm_args(
x.numel() != 0,
not isinstance(dim, IntLike) and dim is not None and len(dim) != 0,
),
-"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
+f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
"because the operation does not have an identity",
)
@@ -119,7 +119,7 @@ def _check_vector_norm_args(
for d in dim:
torch._check(
sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0),
-"linalg.vector_norm cannot compute the {ord} norm on the "
+f"linalg.vector_norm cannot compute the {ord} norm on the "
f"dimension {d} because this dimension is empty and the "
"operation does not have an identity",
)
@@ -220,11 +220,11 @@ def matrix_norm(
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
torch._check(
-len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
+len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
)
torch._check(
dim[0] != dim[1],
-lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
+lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
)
# dtype arg
_check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")
@@ -233,7 +233,7 @@ def matrix_norm(
# ord
torch._check(
ord in ("fro", "nuc"),
-lambda: "linalg.matrix_norm: Order {ord} not supported.",
+lambda: f"linalg.matrix_norm: Order {ord} not supported.",
)
# dtype
check_fp_or_complex(
@@ -256,7 +256,7 @@ def matrix_norm(
abs_ord = abs(ord)
torch._check(
abs_ord in (2, 1, float("inf")),
-lambda: "linalg.matrix_norm: Order {ord} not supported.",
+lambda: f"linalg.matrix_norm: Order {ord} not supported.",
)
# dtype
check_fp_or_complex(
@@ -300,12 +300,12 @@ def norm(
dim = (dim,) # type: ignore[assignment]
torch._check(
len(dim) in (1, 2),
-lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
+lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
)
elif ord is not None:
torch._check(
A.ndim in (1, 2),
-lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
+lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
)
if ord is not None and (

View File

@@ -334,7 +334,7 @@ def _check_repo_is_trusted(
if not is_trusted:
warnings.warn(
"You are about to download and run code from an untrusted repository. In a future release, this won't "
-"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
+f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "