Fix reductions for NJTs with ragged_idx != 1 (#142173)

**Background:** the outer dim -> inner dim conversion logic makes the (previously valid) assumption that the ragged dim immediately follows the batch dim, i.e. that `ragged_idx == 1`. After #137125, this is no longer guaranteed.
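
To make the failure mode concrete, here is a minimal sketch of how `ragged_idx != 1` arises (the shapes are arbitrary, and the private `_ragged_idx` field is read purely for illustration):

```python
import torch

# Standard NJT: batch dim 0, ragged dim 1 -> outer shape (B=2, j1, 3)
nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(5, 3)], layout=torch.jagged
)
print(nt._ragged_idx)  # 1: ragged dim immediately follows the batch dim

# After #137125, the ragged dim can be transposed with a dense dim,
# so it no longer has to sit right next to the batch dim.
nt_t = nt.transpose(1, 2)  # outer shape (B=2, 3, j1)
print(nt_t._ragged_idx)  # 2: the old "ragged dim is dim 1" assumption breaks
```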

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual `ragged_idx`. Since `ragged_idx` tells us which inner dim holds the packed ragged / batch data, both the ragged and batch outer dims should map to this inner dim (see the sketch after this list). The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass it.
* Fixes output shapes across `keepdim` settings when reducing over the ragged / batch dims.
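
For illustration, a standalone sketch of the updated mapping; `outer_to_inner_dim` here is a hypothetical mirror of the `_outer_to_inner_dim` change in the diff below, not the verbatim source:

```python
# Hypothetical mirror of _outer_to_inner_dim: the inner (values) dims pack the
# batch and ragged dims into a single dim at index ragged_idx - 1.
def outer_to_inner_dim(ndim: int, dim: int, ragged_idx: int) -> int:
    assert 0 <= dim < ndim
    # batch (0) and ragged (ragged_idx) outer dims both land on the packed dim
    return ragged_idx - 1 if dim == 0 else dim - 1

# NJT with outer shape (B, D0, j1, D1): ndim=4, ragged_idx=2
print([outer_to_inner_dim(4, d, 2) for d in range(4)])  # [1, 0, 1, 2]
# The old logic, `0 if dim < 2 else dim - 1`, gave [0, 0, 1, 2] here:
# it collapsed batch and D0 together instead of batch and j1.
```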
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
Author: Joel Schlosser
Date: 2024-12-05 16:48:34 -05:00
Committed by: PyTorch MergeBot
Commit: e803a3d83a (parent: 6b0df2f720)
2 changed files with 35 additions and 28 deletions

torch/nested/_internal/ops.py

@@ -17,20 +17,22 @@ __all__: List[Any] = []

 JAGGED_OPS_TABLE: Dict[Any, Any] = {}


 # Simplifying assumption: we assume that the batch dim is always the left-most
 # dim, and the ragged dim is always the second dim.
-def _outer_to_inner_dim(ndim, dim, canonicalize=False):
+def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
     from torch._prims_common import canonicalize_dims

     if isinstance(dim, (tuple, list)):
-        output = type(dim)(_outer_to_inner_dim(ndim, d) for d in dim)
+        output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
         # ensure no duplicates, which can result from both batch and ragged mapping to 0
         return type(output)(dict.fromkeys(output))

     if canonicalize:
         dim = canonicalize_dims(ndim, dim)

     assert dim >= 0 and dim < ndim
-    return 0 if dim < 2 else dim - 1
+    # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
+    # For other dims, subtract 1 to convert to inner space.
+    return ragged_dim - 1 if dim == 0 else dim - 1


 def _wrap_jagged_dim(
@@ -49,7 +51,11 @@ def _wrap_jagged_dim(
         raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
     elif wrapped == 0 and not allow_batch_dim:
         raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
-    ret = _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
+    ret = (
+        _outer_to_inner_dim(ndim, wrapped, ragged_dim)
+        if convert_to_inner_dim
+        else wrapped
+    )
     if allow_batch_dim:
         # Need to disambiguate whether we're operating on the batch dim or not.
         # Operating on dim=1 -> dim=0 after the inner dim conversion.
@@ -80,7 +86,7 @@ def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):

     # ensure no duplicates, which can result from both batch and ragged mapping to 0
     outer_to_inner_dim = tuple(
-        dict.fromkeys(_outer_to_inner_dim(ndim, d) for d in wrapped_dims)
+        dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
     )

     return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
@@ -1313,12 +1319,18 @@ def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
             "for non-contiguous nested tensors with holes"
         )
+    from torch.utils._pytree import tree_map
+
     # raggedness reduced away --> return dense tensor
     if reduce_on_ragged:
         # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
         if reduce_on_batch:
             # no need to read offsets --> apply sum directly on values
             out = func(inp._values, **new_kwargs)
+            if new_kwargs.get("keepdim", False):
+                # some ops return multiple things; unsqueeze all of them
+                out = tree_map(lambda o: o.unsqueeze(0), out)
+
             return out
         else:
             # invalid reduction cases: (ragged, non-batch), etc.
             if reduce_on_non_batch:
@@ -1329,17 +1341,11 @@ def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
             # reduction cases: (ragged)
             # convert to padded dense and reduce
             new_kwargs.pop("dim")
             dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx
-            out = func(inp.to_padded_tensor(identity_element), dim=dim_to_pass)
-
-            if new_kwargs.get("keepdim", False):
-                if isinstance(out, (tuple, list)):
-                    # some ops return multiple things; unsqueeze all of them
-                    out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out)
-                else:
-                    out = out.unsqueeze(inp._ragged_idx)
-            return out
+            return func(
+                inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs
+            )

     # raggedness preserved --> return nested tensor
     else:
         # invalid reduction cases: (batch), (batch, non-batch), etc.
@@ -1365,10 +1371,8 @@ def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
                 if d < inp._ragged_idx - 1:
                     out_kwargs["_ragged_idx"] -= 1

-        if isinstance(out, (tuple, list)):
-            # some ops return multiple things; wrap each of them as an NJT
-            return type(out)(NestedTensor(o, **out_kwargs) for o in out)
-        return NestedTensor(out, **out_kwargs)
+        # some ops return multiple things; wrap each of them as an NJT
+        return tree_map(lambda o: NestedTensor(o, **out_kwargs), out)


 @register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?")
@@ -1419,8 +1423,8 @@ def transpose_int(func, *args, **kwargs):
         inp_kwargs["_ragged_idx"] = to_dim
     return NestedTensor(
         inp.values().transpose(
-            _outer_to_inner_dim(len(inp._size), dim0),
-            _outer_to_inner_dim(len(inp._size), dim1),
+            _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
+            _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
         ),
         **inp_kwargs,
     )
@@ -1468,7 +1472,10 @@ def permute_default(func, *args, **kwargs):
             "Permute is not supported on the batch dimension for jagged NT"
         )
     inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
-    inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]]
+    inner_dims = [
+        _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
+        for dim in canonicalized_dims[1:]
+    ]
     new_kwargs["dims"] = inner_dims

     return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
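
Finally, a usage sketch of the reduction behavior this PR fixes; the printed shapes assume a build that includes this change:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(5, 3)], layout=torch.jagged
)
nt_t = nt.transpose(1, 2)  # outer shape (B=2, 3, j1), ragged_idx == 2

# Reducing over the ragged dim collapses raggedness -> dense output,
# with correct shapes for both keepdim settings.
print(nt_t.sum(dim=2).shape)                # torch.Size([2, 3])
print(nt_t.sum(dim=2, keepdim=True).shape)  # torch.Size([2, 3, 1])
```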