Fix reductions for NJTs with ragged_idx != 1 (#142173)
**Background:** conversion from outer dim -> inner dim makes the (previously valid) assumption that the ragged dim is immediately next to the batch dim. This is no longer the case after #137125.

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual `ragged_idx`. Since `ragged_idx` tells us where the packed ragged / batch dim is, both ragged and batch outer dims should map to this inner dim. The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass it.
* Fixes outputs across keepdim settings when reducing over ragged / batch dims.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
Committed by: PyTorch MergeBot · parent 6b0df2f720 · commit e803a3d83a
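To make the new mapping concrete, here is a minimal standalone sketch of the outer dim -> inner dim conversion described above (an illustrative re-implementation, not the actual PyTorch helper; `outer_to_inner_dim` is a made-up name):

```python
# Standalone sketch of the conversion; not the actual PyTorch helper.
def outer_to_inner_dim(ndim: int, dim: int, ragged_idx: int) -> int:
    # Outer (NJT) space has ndim dims; inner (values) space has ndim - 1,
    # because the batch and ragged dims are packed into a single dim that
    # sits at inner index ragged_idx - 1.
    assert 0 <= dim < ndim
    # Both the batch dim (0) and the ragged dim map to the packed dim.
    return ragged_idx - 1 if dim == 0 else dim - 1


# With ragged_idx == 2 (e.g. after a transpose), the old logic mapped the
# batch dim to inner dim 0, but the packed dim actually lives at index 1:
assert outer_to_inner_dim(ndim=4, dim=0, ragged_idx=2) == 1  # batch -> packed
assert outer_to_inner_dim(ndim=4, dim=2, ragged_idx=2) == 1  # ragged -> packed
assert outer_to_inner_dim(ndim=4, dim=3, ragged_idx=2) == 2  # shift left by one
```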
```diff
@@ -17,20 +17,22 @@ __all__: List[Any] = []
 JAGGED_OPS_TABLE: Dict[Any, Any] = {}
 
 
 # Simplifying assumption: we assume that the batch dim is always the left-most
 # dim, and the ragged dim is always the second dim.
-def _outer_to_inner_dim(ndim, dim, canonicalize=False):
+def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
     from torch._prims_common import canonicalize_dims
 
     if isinstance(dim, (tuple, list)):
-        output = type(dim)(_outer_to_inner_dim(ndim, d) for d in dim)
+        output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
         # ensure no duplicates, which can result from both batch and ragged mapping to 0
         return type(output)(dict.fromkeys(output))
 
     if canonicalize:
         dim = canonicalize_dims(ndim, dim)
 
     assert dim >= 0 and dim < ndim
-    return 0 if dim < 2 else dim - 1
+    # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
+    # For other dims, subtract 1 to convert to inner space.
+    return ragged_dim - 1 if dim == 0 else dim - 1
 
 
 def _wrap_jagged_dim(
```
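When `dim` is a tuple, both the batch and ragged outer dims can map to the same packed inner dim; the `dict.fromkeys` trick above de-duplicates while preserving order. A sketch of that branch, reusing the hypothetical `outer_to_inner_dim` from the earlier snippet:

```python
# Sketch of the tuple branch; assumes outer_to_inner_dim from the snippet above.
def outer_to_inner_dims(ndim, dims, ragged_idx):
    inner = (outer_to_inner_dim(ndim, d, ragged_idx) for d in dims)
    # dict.fromkeys keeps insertion order and drops the duplicate that arises
    # when both the batch and ragged dims map to the packed dim.
    return tuple(dict.fromkeys(inner))


# Reducing over (batch, ragged) on a 3-dim NJT with ragged_idx == 1:
# outer dims 0 and 1 both map to inner dim 0, so only one dim remains.
assert outer_to_inner_dims(3, (0, 1), ragged_idx=1) == (0,)
```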
```diff
@@ -49,7 +51,11 @@ def _wrap_jagged_dim(
         raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
     elif wrapped == 0 and not allow_batch_dim:
         raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
-    ret = _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
+    ret = (
+        _outer_to_inner_dim(ndim, wrapped, ragged_dim)
+        if convert_to_inner_dim
+        else wrapped
+    )
     if allow_batch_dim:
         # Need to disambiguate whether we're operating on the batch dim or not.
         # Operating on dim=1 -> dim=0 after the inner dim conversion.
@@ -80,7 +86,7 @@ def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
 
     # ensure no duplicates, which can result from both batch and ragged mapping to 0
     outer_to_inner_dim = tuple(
-        dict.fromkeys(_outer_to_inner_dim(ndim, d) for d in wrapped_dims)
+        dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
     )
 
     return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
```
```diff
@@ -1313,12 +1319,18 @@ def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
             "for non-contiguous nested tensors with holes"
         )
 
+    from torch.utils._pytree import tree_map
+
     # raggedness reduced away --> return dense tensor
     if reduce_on_ragged:
         # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
         if reduce_on_batch:
             # no need to read offsets --> apply sum directly on values
-            return func(inp._values, **new_kwargs)
+            out = func(inp._values, **new_kwargs)
+            if new_kwargs.get("keepdim", False):
+                # some ops return multiple things; unsqueeze all of them
+                out = tree_map(lambda o: o.unsqueeze(0), out)
+            return out
         else:
             # invalid reduction cases: (ragged, non-batch), etc.
             if reduce_on_non_batch:
```
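The new `keepdim` handling leans on `tree_map` so that multi-output ops (e.g. `max`, which returns values and indices) are unsqueezed uniformly. A small self-contained sketch of that pattern (the reducer here is a stand-in, not the actual `_apply_reduction` code path):

```python
import torch
from torch.utils._pytree import tree_map


def reduce_packed_dim(values: torch.Tensor, keepdim: bool = False):
    # Stand-in for reducing over (batch, ragged): the packed values dim 0
    # is reduced; max returns two tensors, so keepdim must unsqueeze both.
    out = tuple(values.max(dim=0))  # plain tuple of (values, indices)
    if keepdim:
        out = tree_map(lambda o: o.unsqueeze(0), out)
    return out


vals = torch.randn(10, 3)  # stand-in for an NJT's packed values
mx, idx = reduce_packed_dim(vals, keepdim=True)
assert mx.shape == (1, 3) and idx.shape == (1, 3)
```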
```diff
@@ -1329,17 +1341,11 @@ def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
 
             # reduction cases: (ragged)
             # convert to padded dense and reduce
             new_kwargs.pop("dim")
             dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx
-            out = func(inp.to_padded_tensor(identity_element), dim=dim_to_pass)
-
-            if new_kwargs.get("keepdim", False):
-                if isinstance(out, (tuple, list)):
-                    # some ops return multiple things; unsqueeze all of them
-                    out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out)
-                else:
-                    out = out.unsqueeze(inp._ragged_idx)
-
-            return out
+            return func(
+                inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs
+            )
     # raggedness preserved --> return nested tensor
     else:
         # invalid reduction cases: (batch), (batch, non-batch), etc.
```
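The simplification above forwards `keepdim` (via `**new_kwargs`) to the dense op rather than manually re-unsqueezing, since the ragged-dim reduction goes through a padded dense intermediate anyway. A sketch of that path on a concrete jagged tensor:

```python
import torch

# Build a small jagged NJT and reduce over its ragged dim via padding,
# using 0 as the identity element for sum (mirroring the code path above).
nt = torch.nested.nested_tensor(
    [torch.ones(2, 3), torch.ones(5, 3)], layout=torch.jagged
)
padded = nt.to_padded_tensor(0.0)      # dense shape (2, 5, 3)
out = padded.sum(dim=1, keepdim=True)  # keepdim forwarded to the dense op
assert out.shape == (2, 1, 3)
```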
```diff
@@ -1365,10 +1371,8 @@ def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
                 if d < inp._ragged_idx - 1:
                     out_kwargs["_ragged_idx"] -= 1
 
-        if isinstance(out, (tuple, list)):
-            # some ops return multiple things; wrap each of them as an NJT
-            return type(out)(NestedTensor(o, **out_kwargs) for o in out)
-        return NestedTensor(out, **out_kwargs)
+        # some ops return multiple things; wrap each of them as an NJT
+        return tree_map(lambda o: NestedTensor(o, **out_kwargs), out)
 
 
 @register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?")
```
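`tree_map` likewise replaces the removed `isinstance` branching when re-wrapping outputs as NJTs: a bare tensor is a pytree leaf while a tuple is a container, so one call covers both output shapes. A toy demonstration, where `wrap` stands in for `NestedTensor(o, **out_kwargs)`:

```python
import torch
from torch.utils._pytree import tree_map


def wrap(t: torch.Tensor):
    # Stand-in for NestedTensor(t, **out_kwargs).
    return ("NJT", t.shape)


single = torch.randn(4, 3)
multi = (torch.randn(4, 3), torch.randn(2, 3))

# A bare tensor is treated as a leaf; a tuple is mapped elementwise.
assert tree_map(wrap, single) == ("NJT", single.shape)
assert tree_map(wrap, multi) == (wrap(multi[0]), wrap(multi[1]))
```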
```diff
@@ -1419,8 +1423,8 @@ def transpose_int(func, *args, **kwargs):
         inp_kwargs["_ragged_idx"] = to_dim
         return NestedTensor(
             inp.values().transpose(
-                _outer_to_inner_dim(len(inp._size), dim0),
-                _outer_to_inner_dim(len(inp._size), dim1),
+                _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
+                _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
             ),
             **inp_kwargs,
         )
@@ -1468,7 +1472,10 @@ def permute_default(func, *args, **kwargs):
             "Permute is not supported on the batch dimension for jagged NT"
         )
     inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
-    inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]]
+    inner_dims = [
+        _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
+        for dim in canonicalized_dims[1:]
+    ]
     new_kwargs["dims"] = inner_dims
     return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
```
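Putting it together, a usage sketch of the scenario this PR fixes (shapes in the comments are the expected ones; exact op support can vary by version):

```python
import torch

# An NJT starts with ragged_idx == 1; transposing dims 1 and 2 moves the
# ragged dim so that ragged_idx == 2, the case the old conversion mishandled.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(5, 4)], layout=torch.jagged
)                          # outer shape (2, j1, 4)
ntt = nt.transpose(1, 2)   # outer shape (2, 4, j1), ragged_idx == 2

# Reducing over the ragged dim now maps to the correct inner dim and
# respects keepdim; before the fix, dims could be mis-mapped.
out = ntt.sum(dim=2, keepdim=True)
print(out.shape)  # expected: (2, 4, 1)
```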