!565 bug fix Assertion Error

Merge pull request !565 from MeiFei/master
This commit is contained in:
MeiFei
2025-08-29 04:36:50 +00:00
committed by i-robot
parent a9faec2b83
commit be31ab76a9

View File

@ -47,7 +47,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
if not indices.ndim == 1 or values.ndim >= 2:
if indices.ndim != 1 or values.ndim < 2:
raise AssertionError("indices.ndim != 1 or values.ndim < 2")
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
output[indices] = values