mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Fix] Adding missing f
prefixes to formatted strings [1/N] (#164065)
As stated in the title. * #164068 * #164067 * #164066 * __->__ #164065 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164065 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
d131f213ac
commit
eb4361a801
@ -924,7 +924,7 @@ def im2col(
|
||||
def check_positive(param, param_name, strict=True):
|
||||
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
|
||||
torch._check(
|
||||
cond, lambda: "{param_name} should be greater {'than' zero, but got {param}"
|
||||
cond, lambda: f"{param_name} should be greater than zero, but got {param}"
|
||||
)
|
||||
|
||||
check_positive(kernel_size, "kernel_size")
|
||||
@ -1009,7 +1009,7 @@ def col2im(
|
||||
def check_positive(param, param_name, strict=True):
|
||||
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
|
||||
torch._check(
|
||||
cond, lambda: "{param_name} should be greater than zero, but got {param}"
|
||||
cond, lambda: f"{param_name} should be greater than zero, but got {param}"
|
||||
)
|
||||
|
||||
check_positive(kernel_size, "kernel_size")
|
||||
|
@ -166,7 +166,7 @@ class NaNChecker:
|
||||
if grad is not None:
|
||||
assert not torch.isnan(grad).any(), (
|
||||
f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
|
||||
"having NaN gradient. This is not supported. {TURN_OFF_MSG}"
|
||||
f"having NaN gradient. This is not supported. {TURN_OFF_MSG}"
|
||||
)
|
||||
|
||||
self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
|
||||
|
@ -263,7 +263,7 @@ class FrameStateSizeEntry:
|
||||
return f"tensor size={render_tuple(self.size)} stride={render_tuple(self.stride)}"
|
||||
|
||||
# Fallback
|
||||
return "unusual {repr(self)}"
|
||||
return f"unusual {repr(self)}"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert not isinstance(self.scalar, torch.SymInt), self.scalar
|
||||
|
@ -370,7 +370,7 @@ class MemoryPlanningState:
|
||||
|
||||
class WrapperLine:
|
||||
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
|
||||
raise NotImplementedError("FX codegen not yet supported for type {type(self)}")
|
||||
raise NotImplementedError(f"FX codegen not yet supported for type {type(self)}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -2885,7 +2885,7 @@ class ExpandView(BaseView):
|
||||
# guarded because the meta formula was expected to have taught
|
||||
# us this equality.
|
||||
assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
|
||||
"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
|
||||
f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
|
||||
)
|
||||
return new_size
|
||||
|
||||
|
@ -664,7 +664,7 @@ def meta__cslt_sparse_mm(
|
||||
torch.int32,
|
||||
torch.float8_e4m3fn,
|
||||
}, (
|
||||
"out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
|
||||
f"out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
|
||||
)
|
||||
output_shape = (n, m) if transpose_result else (m, n)
|
||||
return dense_B.new_empty(output_shape, dtype=out_dtype)
|
||||
@ -3884,11 +3884,11 @@ def meta_cdist_forward(x1, x2, p, compute_mode):
|
||||
)
|
||||
torch._check(
|
||||
utils.is_float_dtype(x1.dtype),
|
||||
lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
|
||||
lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
|
||||
)
|
||||
torch._check(
|
||||
utils.is_float_dtype(x2.dtype),
|
||||
lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
|
||||
lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
|
||||
)
|
||||
torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
|
||||
torch._check(
|
||||
@ -4466,15 +4466,15 @@ def pool2d_shape_check(
|
||||
|
||||
torch._check(
|
||||
kW > 0 and kH > 0,
|
||||
lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
|
||||
lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
|
||||
)
|
||||
torch._check(
|
||||
dW > 0 and dH > 0,
|
||||
lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
|
||||
lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
|
||||
)
|
||||
torch._check(
|
||||
dilationH > 0 and dilationW > 0,
|
||||
lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
|
||||
lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
|
||||
)
|
||||
|
||||
valid_dims = input.size(1) != 0 and input.size(2) != 0
|
||||
@ -4483,7 +4483,7 @@ def pool2d_shape_check(
|
||||
torch._check(
|
||||
ndim == 4 and valid_dims and input.size(3) != 0,
|
||||
lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
|
||||
" with optional 0 dim batch size for input, but got: {input.size()}",
|
||||
f" with optional 0 dim batch size for input, but got: {input.size()}",
|
||||
)
|
||||
else:
|
||||
torch._check(
|
||||
@ -7015,7 +7015,7 @@ def meta_histc(input, bins=100, min=0, max=0):
|
||||
isinstance(max, Number),
|
||||
lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}",
|
||||
)
|
||||
torch._check(max >= min, lambda: "{fn_name}: max must be larger than min")
|
||||
torch._check(max >= min, lambda: f"{fn_name}: max must be larger than min")
|
||||
return torch.empty(bins, device=input.device, dtype=input.dtype)
|
||||
|
||||
|
||||
|
@ -5928,7 +5928,8 @@ def norm(
|
||||
@out_wrapper()
|
||||
def trace(self: TensorLikeType) -> TensorLikeType:
|
||||
torch._check(
|
||||
self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
|
||||
self.ndim == 2,
|
||||
lambda: f"expected a matrix, but got tensor with dim {self.ndim}",
|
||||
)
|
||||
return torch.sum(torch.diag(self, 0))
|
||||
|
||||
|
@ -56,7 +56,7 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
|
||||
torch._check(
|
||||
utils.get_higher_dtype(dtype, x_dtype) == dtype,
|
||||
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
|
||||
"without narrowing to the specified dtype ({dtype})",
|
||||
f"without narrowing to the specified dtype ({dtype})",
|
||||
)
|
||||
|
||||
|
||||
@ -110,7 +110,7 @@ def _check_vector_norm_args(
|
||||
x.numel() != 0,
|
||||
not isinstance(dim, IntLike) and dim is not None and len(dim) != 0,
|
||||
),
|
||||
"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
|
||||
f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
|
||||
"because the operation does not have an identity",
|
||||
)
|
||||
|
||||
@ -119,7 +119,7 @@ def _check_vector_norm_args(
|
||||
for d in dim:
|
||||
torch._check(
|
||||
sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0),
|
||||
"linalg.vector_norm cannot compute the {ord} norm on the "
|
||||
f"linalg.vector_norm cannot compute the {ord} norm on the "
|
||||
f"dimension {d} because this dimension is empty and the "
|
||||
"operation does not have an identity",
|
||||
)
|
||||
@ -220,11 +220,11 @@ def matrix_norm(
|
||||
if isinstance(dim, Dim):
|
||||
dim = (dim,) # type: ignore[assignment]
|
||||
torch._check(
|
||||
len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
|
||||
len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
|
||||
)
|
||||
torch._check(
|
||||
dim[0] != dim[1],
|
||||
lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
|
||||
lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
|
||||
)
|
||||
# dtype arg
|
||||
_check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")
|
||||
@ -233,7 +233,7 @@ def matrix_norm(
|
||||
# ord
|
||||
torch._check(
|
||||
ord in ("fro", "nuc"),
|
||||
lambda: "linalg.matrix_norm: Order {ord} not supported.",
|
||||
lambda: f"linalg.matrix_norm: Order {ord} not supported.",
|
||||
)
|
||||
# dtype
|
||||
check_fp_or_complex(
|
||||
@ -256,7 +256,7 @@ def matrix_norm(
|
||||
abs_ord = abs(ord)
|
||||
torch._check(
|
||||
abs_ord in (2, 1, float("inf")),
|
||||
lambda: "linalg.matrix_norm: Order {ord} not supported.",
|
||||
lambda: f"linalg.matrix_norm: Order {ord} not supported.",
|
||||
)
|
||||
# dtype
|
||||
check_fp_or_complex(
|
||||
@ -300,12 +300,12 @@ def norm(
|
||||
dim = (dim,) # type: ignore[assignment]
|
||||
torch._check(
|
||||
len(dim) in (1, 2),
|
||||
lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
|
||||
lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
|
||||
)
|
||||
elif ord is not None:
|
||||
torch._check(
|
||||
A.ndim in (1, 2),
|
||||
lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
|
||||
lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
|
||||
)
|
||||
|
||||
if ord is not None and (
|
||||
|
@ -334,7 +334,7 @@ def _check_repo_is_trusted(
|
||||
if not is_trusted:
|
||||
warnings.warn(
|
||||
"You are about to download and run code from an untrusted repository. In a future release, this won't "
|
||||
"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
|
||||
f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
|
||||
"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
|
||||
f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
|
||||
f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
|
||||
|
Reference in New Issue
Block a user