mirror of
https://gitee.com/ascend/MindSpeed-RL.git
synced 2025-10-20 16:23:45 +08:00
!565 bug fix Assertion Error
Merge pull request !565 from MeiFei/master
This commit is contained in:
@ -47,7 +47,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, values, indices, first_axis_dim):
|
||||
ctx.save_for_backward(indices)
|
||||
if not indices.ndim == 1 or values.ndim >= 2:
|
||||
if indices.ndim != 1 or values.ndim < 2:
|
||||
raise AssertionError("indices.ndim != 1 or values.ndim < 2")
|
||||
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
|
||||
output[indices] = values
|
||||
|
Reference in New Issue
Block a user