migrate more simple gso checks (#160253)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160253
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka
2025-08-15 14:02:23 -07:00
committed by PyTorch MergeBot
parent 16ce2c15fa
commit f782c790df
3 changed files with 11 additions and 8 deletions

View File

@ -1780,9 +1780,9 @@ def _fused_rms_norm_backward(
N = prod(inner_dims) # type: ignore[arg-type]
M = prod(outer_dims) # type: ignore[arg-type]
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_false
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
if guard_or_false(M == 0) or guard_or_false(N == 0):
return (
input.new_zeros(input_shape) if output_mask[0] else None,
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
@ -3987,9 +3987,9 @@ def _unsafe_masked_index(x, mask, indices, fill):
lambda: "tensors used as masks must be bool tensors",
)
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_false
if guard_size_oblivious(x.numel() == 0):
if guard_or_false(x.numel() == 0):
meta_result = torch._meta_registrations.meta_index_Tensor(x, indices)
return x.new_full(meta_result.shape, fill)