Rename singleton int to nested int (#119661)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119661
Approved by: https://github.com/ezyang
Author: soulitzer
Date: 2024-02-16 11:16:12 -05:00
Committed by: PyTorch MergeBot
parent b97fa6ac30
commit 312ce35c1f
21 changed files with 99 additions and 99 deletions
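At the Python level the rename is mechanical: the private binding torch._C._get_singleton_int becomes torch._C._get_nested_int with the same (id, coeff) signature, as the type-stub change below shows. A minimal before/after sketch (this is a private binding, not a stable API):

import torch

# before this commit: j0 = torch._C._get_singleton_int(0, 1)
j0 = torch._C._get_nested_int(0, 1)  # after: same semantics, new name
assert isinstance(j0, torch.SymInt)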

View File

@@ -1,4 +1,4 @@
#include <ATen/core/SingletonSymNodeImpl.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/util/Exception.h>
@@ -6,73 +6,73 @@ namespace c10 {
namespace {
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
c10::optional<int64_t> c = rhs->singleton_int();
TORCH_INTERNAL_ASSERT(lhs->nested_int().has_value());
c10::optional<int64_t> c = rhs->nested_int();
return (
c.has_value() && lhs->singleton_int() == *c &&
lhs->singleton_coeff() == rhs->singleton_coeff());
c.has_value() && lhs->nested_int() == *c &&
lhs->nested_int_coeff() == rhs->nested_int_coeff());
}
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
if (auto mb_si = lhs->singleton_int()) {
if (auto mb_si2 = rhs->singleton_int()) {
if (auto mb_si = lhs->nested_int()) {
if (auto mb_si2 = rhs->nested_int()) {
if (*mb_si == *mb_si2) {
return lhs->singleton_coeff() >= rhs->singleton_coeff();
return lhs->nested_int_coeff() >= rhs->nested_int_coeff();
}
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
}
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
if (rhs->constant_int() && *rhs->constant_int() <= 2) {
return true;
}
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
} else if (rhs->singleton_int()) {
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
} else if (rhs->nested_int()) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
if (lhs->constant_int() && *lhs->constant_int() < 2) {
return false;
}
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
}
TORCH_INTERNAL_ASSERT(false, "expect at least one singleton");
TORCH_INTERNAL_ASSERT(false, "expect at least one nested int");
}
} // namespace
c10::SymNode SingletonSymNodeImpl::eq(const c10::SymNode& other) {
c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
_eq("eq", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::ne(const c10::SymNode& other) {
c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
!_eq("ne", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::ge(const c10::SymNode& other) {
c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
_ge("ge", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::gt(const c10::SymNode& other) {
c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
!_ge("gt", other.get(), this)));
}
c10::SymNode SingletonSymNodeImpl::lt(const c10::SymNode& other) {
c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
!_ge("lt", this, other.get())));
}
c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) {
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
_ge("le", other.get(), this)));
}
c10::SymNode SingletonSymNodeImpl::mul(const c10::SymNode& other) {
if (auto mb_si = other->singleton_int()) {
TORCH_CHECK(false, "Singleton int cannot be multiplied by singleton int");
c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) {
if (auto mb_si = other->nested_int()) {
TORCH_CHECK(false, "nested int cannot be multiplied by nested int");
}
c10::optional<int64_t> c = other->constant_int();
TORCH_CHECK(c.has_value());
return SymNode(c10::make_intrusive<SingletonSymNodeImpl>(val_, coeff_ * *c));
return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_ * *c));
}
} // namespace c10
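To illustrate the _eq helper above: two nested ints compare equal only when both the id (nested_int()) and the coefficient (nested_int_coeff()) match, and comparison against a plain int is always unequal. A sketch of the Python-visible behavior, using the private binding touched later in this commit:

import torch

j1 = torch._C._get_nested_int(1, 1)       # id=1, coeff=1
j1_copy = torch._C._get_nested_int(1, 1)  # same id and coeff
j2 = torch._C._get_nested_int(2, 1)       # different id

assert j1 == j1_copy   # ids and coefficients both match
assert j1 != j2        # ids differ, so ne holds
assert not (j1 == 3)   # a plain int has no nested_int(), so eq is false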

View File

@@ -16,9 +16,9 @@ namespace c10 {
// allows us to simply return [B, j0, D] if someone queries for the size of our
// tensor.
//
// Morally we define comparison between two singleton ints to return true if
// Morally we define comparison between two nested ints to return true if
// that comparison holds for all corresponding elements of the arrays they
// represent. Comparison between a singleton int and a plain int is defined
// represent. Comparison between a nested int and a plain int is defined
// similarly.
//
// To simulate this desired behavior but also avoid the O(N) cost of checking,
@@ -32,13 +32,13 @@ namespace c10 {
// differentiate the two cases.
//
// During tracing the strides of the outputs need to be a function of the size
// and strides of the inputs so it is important that SingletonSymNode itself is
// and strides of the inputs so it is important that NestedIntSymNode itself is
// able to express this.
class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
public:
// CAUTION: you should probably not be constructing these directly; please
// use the higher-level API in python instead (TODO: actually introduce that).
explicit SingletonSymNodeImpl(int64_t val, int64_t coeff)
explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
: val_(val), coeff_(coeff) {}
bool bool_() override {
@@ -88,9 +88,9 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
return std::to_string(coeff_) + "*j" + std::to_string(val_);
}
// NOTE [ Inequalities with SingletonInt ]
// NOTE [ Inequalities with nested int ]
//
// The semantics of SingletonInt when it comes to relations is that it is
// The semantics of nested int when it comes to relations is that it is
// treated as an integer known to be within a certain range,
//
// j0 \in [2, int64_t::max]
@@ -117,7 +117,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
// [ Coefficients are assumed positive ]
//
// For the purpose of computing inequalities, we consider the coefficient of
// the SingletonInt to be a positive integer.
// the nested int to be a positive integer.
//
// Thus, no modifications are needed to the logic since
// j0 >= k implies coeff * j0 >= k
@@ -130,11 +130,11 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
c10::SymNode le(const c10::SymNode& other) override;
c10::SymNode mul(const c10::SymNode& other) override;
c10::optional<int64_t> singleton_int() override {
c10::optional<int64_t> nested_int() override {
return val_;
}
c10::optional<int64_t> singleton_coeff() override {
c10::optional<int64_t> nested_int_coeff() override {
return coeff_;
}
@@ -144,7 +144,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
#define DEFINE_BINARY_NOT_SUPPORTED(name) \
c10::SymNode name(const c10::SymNode& other) override { \
TORCH_CHECK(false, #name " not supported by SingletonSymNode"); \
TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
}
DEFINE_BINARY_NOT_SUPPORTED(add)
@@ -162,7 +162,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
#define DEFINE_NOT_SUPPORTED(name) \
c10::SymNode name() override { \
TORCH_CHECK(false, #name " is not supported by SingletonSymNode"); \
TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
}
DEFINE_NOT_SUPPORTED(sym_not)
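The NOTE above pins a nested int j0 to the range [2, int64_t::max], which is exactly what the _ge helper in the .cpp file enforces: relations that hold for every value in that range get an answer, anything else throws. A sketch of the resulting behavior (same private binding as above; TORCH_CHECK failures surface as RuntimeError in Python):

import torch

j1 = torch._C._get_nested_int(1, 1)

assert j1 > 1    # every value in [2, int64_t::max] exceeds 1
assert j1 >= 2   # the lower bound of the assumed range
try:
    j1 >= 3      # true for some values in the range, false for others
    raise AssertionError("expected an indeterminate relation to throw")
except RuntimeError as e:
    assert "Relation is indeterminate" in str(e)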

View File

@@ -1037,7 +1037,7 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/core/operator_name.cpp",
"aten/src/ATen/core/TorchDispatchUtils.cpp",
"aten/src/ATen/core/register_symbols.cpp",
"aten/src/ATen/core/SingletonSymNodeImpl.cpp",
"aten/src/ATen/core/NestedIntSymNodeImpl.cpp",
"aten/src/ATen/core/class_type.cpp",
"aten/src/ATen/core/type.cpp",
"aten/src/ATen/core/type_factory.cpp",

View File

@@ -4,14 +4,14 @@ namespace c10 {
// This is used to support the case where the lhs is a constant symnode
// and the rhs is a singleton symnode. This situation occurs today when we
// perform a binary op between singleton int and plain int and the
// perform a binary op between nested int and plain int and the
// singleton promotes the int into a constant symnode. If we'd like to
// support more combinations in the future, we may need to implement some
// kind of multiple dispatch.
#define DEFINE_BINARY_OP(OP, ROP) \
template <typename T> \
c10::SymNode ConstantSymNodeImpl<T>::OP(const c10::SymNode& other) { \
TORCH_INTERNAL_ASSERT(other->singleton_int().has_value()); \
TORCH_INTERNAL_ASSERT(other->nested_int().has_value()); \
return other->ROP( \
c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim_copy(this)); \
}
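In other words, a constant symnode never implements comparison logic of its own against a nested int: DEFINE_BINARY_OP forwards each op to the nested operand with the operands swapped and the relation reversed (lt becomes gt, ge becomes le, and so on), so all the range reasoning stays in NestedIntSymNodeImpl. A quick sketch showing that mixed-order comparisons therefore agree:

import torch

j1 = torch._C._get_nested_int(1, 1)
assert 1 < j1          # reversed and answered by the nested int as j1 > 1
assert not (j1 <= 1)   # the same relation asked from the other side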

View File

@@ -185,10 +185,10 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
virtual c10::optional<int64_t> singleton_int() {
virtual c10::optional<int64_t> nested_int() {
return c10::nullopt;
}
virtual c10::optional<int64_t> singleton_coeff() {
virtual c10::optional<int64_t> nested_int_coeff() {
return c10::nullopt;
}
virtual c10::optional<int64_t> constant_int() {

View File

@@ -944,7 +944,7 @@ coverage_ignore_functions = [
"is_channels_last_strides_3d",
"is_contiguous",
"is_non_overlapping_and_dense_indicator",
"is_singleton",
"is_nested_int",
"is_symbol_binding_fx_node",
"is_symbolic",
# torch.fx.experimental.unification.core

View File

@@ -41,7 +41,7 @@ set(TORCH_API_TEST_SOURCES
${TORCH_API_TEST_DIR}/inference_mode.cpp
${TORCH_API_TEST_DIR}/grad_mode.cpp
${TORCH_API_TEST_DIR}/operations.cpp
${TORCH_API_TEST_DIR}/singleton_int.cpp
${TORCH_API_TEST_DIR}/nested_int.cpp
)
if(USE_CUDA OR USE_ROCM)
list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp)

View File

@@ -1,19 +1,19 @@
#include <gtest/gtest.h>
#include <ATen/core/SingletonSymNodeImpl.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
TEST(SingletonIntTest, Comparisons) {
TEST(NestedIntTest, Comparisons) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
auto c = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(2, 1)));
auto d = c10::SymInt(3);
ASSERT_TRUE(a == a);
@@ -85,11 +85,11 @@ TEST(SingletonIntTest, Comparisons) {
ASSERT_TRUE(a > 1);
}
TEST(SingletonIntTest, WiithFactor) {
TEST(NestedIntTest, WithFactor) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 5)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 10)));
// eq
ASSERT_FALSE(a == b);
ASSERT_FALSE(a >= b);

View File

@@ -849,10 +849,10 @@ class TestSymNumberMagicMethods(TestCase):
with self.assertRaisesRegex(TypeError, "unhashable"):
hash(x)
# Singleton SymInt, constant SymBool, SymNode are hashable
j1 = torch._C._get_singleton_int(1, 1)
j1_copy = torch._C._get_singleton_int(1, 1)
j2 = torch._C._get_singleton_int(2, 1)
# NestedInt (SymInt), constant SymBool, SymNode are hashable
j1 = torch._C._get_nested_int(1, 1)
j1_copy = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(2, 1)
t = self.get_constant_bool(True)
t_copy = self.get_constant_bool(True)
f = self.get_constant_bool(False)
@@ -872,14 +872,14 @@ class TestSymNumberMagicMethods(TestCase):
hash(m)
def test_non_symbolic_symnode(self):
j1 = torch._C._get_singleton_int(1, 1)
j2 = torch._C._get_singleton_int(1, 1)
j3 = torch._C._get_singleton_int(3, 1)
j1 = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(1, 1)
j3 = torch._C._get_nested_int(3, 1)
self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)
with self.assertRaisesRegex(RuntimeError, "add not supported by SingletonSymNode"):
with self.assertRaisesRegex(RuntimeError, "add not supported by NestedIntSymNode"):
j1 + 3
self.assertFalse(j1 == 3)

View File

@@ -3028,9 +3028,9 @@ class TestNestedTensorSubclass(TestCase):
"directly calling torch.ops.aten.size"):
torch.ops.aten.size.default(nt)
singleton_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1)
self.assertEqual(nt.size(), (3, singleton_int, 3))
self.assertEqual(nt.shape, (3, singleton_int, 3))
nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1)
self.assertEqual(nt.size(), (3, nested_int, 3))
self.assertEqual(nt.shape, (3, nested_int, 3))
self.assertEqual(nt.dim(), 3)
self.assertEqual(nt.numel(), 27)

View File

@@ -1533,7 +1533,7 @@ def _are_functorch_transforms_active() -> _bool: ...
# Define in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...
def _get_singleton_int(id: _int, coeff: _int) -> SymInt: ...
def _get_nested_int(id: _int, coeff: _int) -> SymInt: ...
def _get_constant_bool_symnode(val: _bool) -> Any: ...

View File

@@ -301,12 +301,12 @@ class SymInt:
return str(self.node)
def __hash__(self) -> builtins.int:
ret = self.node.singleton_int()
ret = self.node.nested_int()
if ret is not None:
return hash(ret)
else:
# We could support constant SymInts as well, but not doing it for now
raise TypeError("unhashable type: non-singleton SymInt")
raise TypeError("unhashable type: non-nested SymInt")
class SymFloat:
"""

View File

@@ -491,7 +491,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._get_privateuse1_backend_name",
"torch._C._get_qengine",
"torch._C._get_schema",
"torch._C._get_singleton_int",
"torch._C._get_nested_int",
"torch._C._get_tensor_metadata",
"torch._C._get_tracing_state",
"torch._C._get_upgrader_ranges",

View File

@@ -1598,9 +1598,9 @@ def _automatic_dynamic(
# We preserve the dynamism of inputs. For example, when users call
# make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
from torch.fx.experimental.symbolic_shapes import is_singleton
from torch.fx.experimental.symbolic_shapes import is_nested_int
if any(isinstance(s, SymInt) and not is_singleton(s) for s in e.size()):
if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
return StatefulSymbolicContext(
dynamic_sizes=[
DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
@@ -1729,7 +1729,7 @@
constraint_dim is not None
or marked_dynamic
or marked_weak_dynamic
or is_singleton(e.shape[i])
or is_nested_int(e.shape[i])
):
# NB: We could assert static_shapes is False here, but it
# seems better to allow the user to override symbolic_context in this

View File

@@ -846,7 +846,7 @@ def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
(sizeA == sizeB) or (sizeA == 1) or (sizeB == 1),
lambda: (
f"The size of tensor a ({sizeA}) must match the size of "
f"tensor b ({sizeB}) at non-singleton dimension {i}"
f"tensor b ({sizeB}) at non-jagged dimension {i}"
),
)
@@ -1566,13 +1566,13 @@ def make_contiguous_strides_for(
if not shape:
return ()
from torch.fx.experimental.symbolic_shapes import is_singleton
from torch.fx.experimental.symbolic_shapes import is_nested_int
multiplier = 1
strides = []
for l in reversed(shape):
strides.append(multiplier)
multiplier *= l if is_singleton(l) else sym_max(l, 1)
multiplier *= l if is_nested_int(l) else sym_max(l, 1)
result = tuple(reversed(strides))
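A rough worked example of the loop above, using a hypothetical nested int j0 as the ragged size: for shape (3, j0, 5) the strides come out as (5*j0, 5, 1), because the multiplier keeps the nested int itself (no sym_max clamp) when crossing the ragged dimension. Sketch (torch._prims_common is an internal module):

import torch
from torch._prims_common import make_contiguous_strides_for

j0 = torch._C._get_nested_int(0, 1)            # ragged size "x"
strides = make_contiguous_strides_for((3, j0, 5))
assert strides[2] == 1 and strides[1] == 5     # strides[0] is the SymInt 5*j0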

View File

@@ -1290,14 +1290,14 @@ void initJITBindings(PyObject* module) {
return node->is_symbolic();
})
.def(
"singleton_int",
"nested_int",
[](const c10::SymNode& node) {
return node->singleton_int();
return node->nested_int();
})
.def(
"singleton_coeff",
"nested_int_coeff",
[](const c10::SymNode& node) {
return node->singleton_coeff();
return node->nested_int_coeff();
});
// clang-format on

View File

@@ -5,8 +5,8 @@
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/SingletonSymNodeImpl.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/functorch/BatchedTensorImpl.h>
@@ -823,9 +823,9 @@ void initDispatchBindings(PyObject* module) {
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
});
m.def("_get_singleton_int", [](int64_t data, int64_t coeff) {
m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
return c10::SymInt(c10::SymNode(
c10::make_intrusive<c10::SingletonSymNodeImpl>(data, coeff)));
c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
});
m.def("_get_constant_bool_symnode", [](int64_t data) {

View File

@@ -426,7 +426,7 @@ class SymNode:
def is_symbolic(self):
return True
def singleton_int(self):
def nested_int(self):
return None
def is_constant(self):

View File

@@ -90,7 +90,7 @@ __all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
"is_concrete_bool", "is_singleton", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
"is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
"guard_size_oblivious",
@@ -262,25 +262,25 @@ def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
return False
def is_singleton(s: Int) -> bool:
# check for SingletonSymNode
def is_nested_int(s):
# check for NestedIntSymNode
if not isinstance(s, torch.SymInt):
return False
if s.node.singleton_int() is not None:
if s.node.nested_int() is not None:
return True
# check for symbolic variable wrapping a SingletonSymNode (fake-ifying causes this)
# check for symbolic variable wrapping a NestedIntSymNode (fake-ifying causes this)
return (
s.node.is_symbolic()
and s.node.hint is not None
and isinstance(s.node.hint, torch.SymInt)
and s.node.hint.node.singleton_int() is not None
and s.node.hint.node.nested_int() is not None
)
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
if isinstance(val, SymTypes):
# This allowance applies to the jagged layout NestedTensor case as
# singleton ints are not symbolic
# nested ints are not symbolic
if is_symbolic(val):
yield val.node.expr
elif isinstance(val, sympy.Basic):
@@ -2306,9 +2306,9 @@ class ShapeEnv:
val_list = sorted(
[(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
key=lambda tup: (
# Order singletons by their coefficients.
# 1 here to order singletons after non-singletons.
(1, tup[0].node.singleton_coeff(), tup[1]) if is_singleton(tup[0])
# Order nested int by their coefficients.
# 1 here to order nested int after non-nested int.
(1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0])
else (0, *tup)
)
)
@@ -2572,7 +2572,7 @@ class ShapeEnv:
self.var_to_val[sympy_expr] = sympy.Integer(val)
else:
# Only used for jagged layout nested tensors
self.var_to_val[sympy_expr] = SingletonInt(val.node.singleton_int(), coeff=val.node.singleton_coeff())
self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff())
# Do the appending later, because we always want to populate this
self.var_to_sources[sympy_expr] = []
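The renamed is_nested_int helper from this file in action. The second branch (a symbolic SymInt whose hint wraps a nested int) only appears under fake tensors, so this sketch exercises just the direct case, using the private binding again plus the SymNode methods bound in the initJITBindings change above:

import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int

j1 = torch._C._get_nested_int(1, 1)
assert is_nested_int(j1)       # SymInt backed by a NestedIntSymNodeImpl
assert not is_nested_int(3)    # plain ints are not nested ints
assert j1.node.nested_int_coeff() == 1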

View File

@@ -15,7 +15,7 @@ def get_tensor_symint(tensor, *, coeff=1):
global _tensor_id_counter
tensor_symint = _tensor_symint_registry.get(tensor)
if tensor_symint is None:
tensor_symint = torch._C._get_singleton_int(_tensor_id_counter, coeff)
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
_tensor_id_counter += 1
_tensor_symint_registry[tensor] = tensor_symint
return tensor_symint
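The registry above memoizes one nested int per offsets tensor, so two jagged tensors built from the same offsets report equal ragged sizes. A rough illustration (internal helper, not public API; the offsets values are made up):

import torch
from torch.nested._internal.nested_tensor import get_tensor_symint

offsets = torch.tensor([0, 2, 5])
s1 = get_tensor_symint(offsets, coeff=1)
s2 = get_tensor_symint(offsets, coeff=1)
assert s1 is s2   # cached in _tensor_symint_registry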
@@ -30,18 +30,18 @@ class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_lengths: Optional[torch.Tensor]
# NOTE [ Singleton ints for ragged sizes and strides ]
# NOTE [ Nested ints for ragged sizes and strides ]
#
# Jagged layout tensors are tensors that represent an n-dim tensor with a
# ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
# a jagged tensor with outer shape [B, x, D] is represented internally by a
# tensor with shape [sum(x), D] where we introduce what we call a singleton
# (or skolem) denoted as "x" here (but sometimes denoted with "*" to
# tensor with shape [sum(x), D] where we introduce what we call a nested int
# denoted as "x" here (but sometimes denoted with "*") to
# represent the ragged dimension, and sum(x) represents the dim of the inner
# tensor or equivalently the sum of all the sizes of the constituent
# tensors' varying lengths.
#
# We also use singleton ints to represent the strides of this tensor.
# We also use nested ints to represent the strides of this tensor.
# For example, a jagged tensor with shape [B, x, D] can be strided in two
# ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
_size: Tuple[int, ...]
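Concretely, for the [B, x, D] example in this note, constructing a jagged NestedTensor surfaces the nested int in the middle of the reported size, which is also what the test_nestedtensor change above asserts. A sketch (assuming the jagged layout constructor; the printed name of the nested int, e.g. j1, depends on the id counter):

import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5)],
    layout=torch.jagged,
)
print(nt.shape)                    # e.g. torch.Size([3, j1, 5])
assert is_nested_int(nt.shape[1])  # the ragged dimension "x"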

View File

@@ -13,14 +13,14 @@ class SingletonInt(sympy.AtomicExpr):
instance = super().__new__(cls, *args, **kwargs)
return instance
# The semantics of this class should match that of SingletonSymNodeImpl in
# c10/core/SingletonSymNodeImpl.h
# The semantics of this class should match that of NestedIntSymNodeImpl in
# c10/core/NestedIntSymNodeImpl.h
def __init__(self, val, *, coeff=1):
self._val = val
self._coeff = coeff
super().__init__()
# See NOTE [ Inequalities with SingletonInt ]
# See NOTE [ Inequalities with nested int ]
def _eval_Eq(self, other):
if (
isinstance(other, SingletonInt)
@@ -69,7 +69,7 @@
raise NotImplementedError("NYI")
# See NOTE [ Inequalities with SingletonInt ]
# See NOTE [ Inequalities with nested int ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
if a < 2: