Just make the full tensor instead of adding to a zeros tensor

This commit is contained in:
Matt
2025-10-16 17:14:14 +01:00
parent 0b2981b423
commit dc55cad067

View File

@ -99,7 +99,7 @@ def _pad(items, key, padding_value, padding_side):
# we can consistently pad since the size should be matching
return torch.cat([item[key] for item in items], dim=0)
else:
tensor = torch.zeros([batch_size, max_length] + list(shape[2:]), dtype=dtype) + padding_value
tensor = torch.full([batch_size, max_length] + list(shape[2:]), fill_value=padding_value, dtype=dtype)
for i, item in enumerate(items):
if padding_side == "left":