Rename singleton int to nested int (#119661)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119661
Approved by: https://github.com/ezyang
This commit is contained in:
soulitzer
2024-02-16 11:16:12 -05:00
committed by PyTorch MergeBot
parent b97fa6ac30
commit 312ce35c1f
21 changed files with 99 additions and 99 deletions

View File

@ -15,7 +15,7 @@ def get_tensor_symint(tensor, *, coeff=1):
global _tensor_id_counter
tensor_symint = _tensor_symint_registry.get(tensor)
if tensor_symint is None:
tensor_symint = torch._C._get_singleton_int(_tensor_id_counter, coeff)
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
_tensor_id_counter += 1
_tensor_symint_registry[tensor] = tensor_symint
return tensor_symint
@ -30,18 +30,18 @@ class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_lengths: Optional[torch.Tensor]
# NOTE [ Singleton ints for ragged sizes and strides ]
# NOTE [ Nested ints for ragged sizes and strides ]
#
# Jagged layout tensors are tensors that represent a n-dim tensor with a
# ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
# a jagged tensor with outer shape [B, x, D] is represented internally by a
# tensor with shape [sum(x), D] where we introduce what we call a singleton
# (or skolem) denoted as "x" here (but sometimes denoted with "*" to
# tensor with shape [sum(x), D] where we introduce what we call a nested int
# denoted as "x" here (but sometimes denoted with "*" to
# represent the ragged dimension, and sum(x) represents the dim of the inner
# tensor or equivalently the sum of all the sizes of the constituent
# tensors' varying lengths.
#
# We also use singleton ints to represent the strides of this tensor.
# We also use nested ints to represent the strides of this tensor.
# For example, a jagged tensor with shape [B, x, D] can be strided in two
# ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
_size: Tuple[int, ...]