[Fix] missing lambda in torch._check (#165043)

Fixes more missing lambda in torch._check in the source code. Inspired by #164225.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165043
Approved by: https://github.com/FFFrog, https://github.com/Skylion007
This commit is contained in:
can-gaa-hou
2025-10-10 17:11:50 +00:00
committed by PyTorch MergeBot
parent 3ed90f5a09
commit 39161e73fc
3 changed files with 18 additions and 16 deletions

View File

@ -2719,20 +2719,22 @@ if torch._C._has_mkldnn:
@register_meta(torch.ops.quantized.int4mm_packed_weight_cpu)
def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros):
torch._check(x.dim() == 2, f"x must be a 2D tensor, got {x.dim()}D")
torch._check(w.dim() == 2, f"w must be a 2D tensor, got {w.dim()}D")
torch._check(x.dim() == 2, lambda: f"x must be a 2D tensor, got {x.dim()}D")
torch._check(w.dim() == 2, lambda: f"w must be a 2D tensor, got {w.dim()}D")
torch._check(
x.dtype in [torch.float32, torch.float16, torch.bfloat16],
f"expected x to be f32/f16/bf16, got {x.dtype}",
lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
)
torch._check(
w.dtype == torch.uint8, lambda: f"expected w to be uint8, got {w.dtype}"
)
torch._check(w.dtype == torch.uint8, f"expected w to be uint8, got {w.dtype}")
torch._check(
q_group_size.dtype == torch.int64,
f"q_group_size must be int64, got {q_group_size.dtype}",
lambda: f"q_group_size must be int64, got {q_group_size.dtype}",
)
torch._check(
q_scale_and_zeros.dtype == x.dtype,
f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}",
lambda: f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}",
)
return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
@ -4895,7 +4897,7 @@ def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
for d in range(ndim - 3, ndim):
torch._check(
self.size(d) > 0,
f"fractional_max_pool2d: Expected input to have non-zero "
lambda: f"fractional_max_pool2d: Expected input to have non-zero "
f" size for non-batch dimensions, but got {self.size()} with dimension {d} empty",
)
@ -4933,7 +4935,7 @@ def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
d = random_samples.size(2)
torch._check(
n >= input_batch,
"Expect _random_samples.size(0) no less then input batch size.",
lambda: "Expect _random_samples.size(0) no less then input batch size.",
)
torch._check(
c == input_channels,
@ -7192,7 +7194,7 @@ def meta_searchsorted(
# Per the docs, if side == "left" and right is True, we error.
torch._check(
side != "left" or not right,
"torch.searchsorted(): side and right can't be set to opposites, got side of "
lambda: "torch.searchsorted(): side and right can't be set to opposites, got side of "
"left while right was True",
)
@ -7303,7 +7305,7 @@ def meta_embedding_bag_per_sample_weights_backward(
embedding_features = grad.size(1)
torch._check(
mode == MODE_SUM,
"embedding_bag_backward: per_sample_weights only supported for mode='sum'",
lambda: "embedding_bag_backward: per_sample_weights only supported for mode='sum'",
)
torch._check(grad.dim() == 2)
torch._check(indices.dim() == 1)
@ -7450,7 +7452,7 @@ def _meta_grouped_mm_common(
if not mat_a_is_2d or not mat_b_is_2d:
torch._check(
mat_a.size(-1) == mat_b.size(-2),
"contraction dimension of mat_a and mat_b must match",
lambda: "contraction dimension of mat_a and mat_b must match",
)
if scaled:

View File

@ -3494,7 +3494,7 @@ def stft(
)
torch._check(
not center or align_to_window is None,
"stft only supports align_to_window for center = False.",
lambda: "stft only supports align_to_window for center = False.",
)
hop_length_ = hop_length if hop_length is not None else n_fft // 4
@ -3506,7 +3506,7 @@ def stft(
)
torch._check(
return_complex_,
(
lambda: (
"stft requires the return_complex parameter be given for real inputs, "
+ "and will further require that return_complex=True in a future PyTorch release."
),
@ -3951,7 +3951,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
shape_numel = reduce(operator.mul, shape, 1)
torch._check(
a.numel() == shape_numel,
f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
)
# Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

View File

@ -110,7 +110,7 @@ def _check_vector_norm_args(
x.numel() != 0,
not isinstance(dim, IntLike) and dim is not None and len(dim) != 0,
),
f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
"because the operation does not have an identity",
)
@ -119,7 +119,7 @@ def _check_vector_norm_args(
for d in dim:
torch._check(
sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0),
f"linalg.vector_norm cannot compute the {ord} norm on the "
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
f"dimension {d} because this dimension is empty and the "
"operation does not have an identity",
)