mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
migrate more simple gso checks (#160253)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160253 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
16ce2c15fa
commit
f782c790df
@ -1780,9 +1780,9 @@ def _fused_rms_norm_backward(
|
||||
|
||||
N = prod(inner_dims) # type: ignore[arg-type]
|
||||
M = prod(outer_dims) # type: ignore[arg-type]
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
|
||||
if guard_or_false(M == 0) or guard_or_false(N == 0):
|
||||
return (
|
||||
input.new_zeros(input_shape) if output_mask[0] else None,
|
||||
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
|
||||
@ -3987,9 +3987,9 @@ def _unsafe_masked_index(x, mask, indices, fill):
|
||||
lambda: "tensors used as masks must be bool tensors",
|
||||
)
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if guard_size_oblivious(x.numel() == 0):
|
||||
if guard_or_false(x.numel() == 0):
|
||||
meta_result = torch._meta_registrations.meta_index_Tensor(x, indices)
|
||||
return x.new_full(meta_result.shape, fill)
|
||||
|
||||
|
Reference in New Issue
Block a user