mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rename singleton int to nested int (#119661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119661 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
b97fa6ac30
commit
312ce35c1f
@ -1,4 +1,4 @@
|
||||
#include <ATen/core/SingletonSymNodeImpl.h>
|
||||
#include <ATen/core/NestedIntSymNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
@ -6,73 +6,73 @@ namespace c10 {
|
||||
|
||||
namespace {
|
||||
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
||||
TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
|
||||
c10::optional<int64_t> c = rhs->singleton_int();
|
||||
TORCH_INTERNAL_ASSERT(lhs->nested_int().has_value());
|
||||
c10::optional<int64_t> c = rhs->nested_int();
|
||||
return (
|
||||
c.has_value() && lhs->singleton_int() == *c &&
|
||||
lhs->singleton_coeff() == rhs->singleton_coeff());
|
||||
c.has_value() && lhs->nested_int() == *c &&
|
||||
lhs->nested_int_coeff() == rhs->nested_int_coeff());
|
||||
}
|
||||
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
||||
if (auto mb_si = lhs->singleton_int()) {
|
||||
if (auto mb_si2 = rhs->singleton_int()) {
|
||||
if (auto mb_si = lhs->nested_int()) {
|
||||
if (auto mb_si2 = rhs->nested_int()) {
|
||||
if (*mb_si == *mb_si2) {
|
||||
return lhs->singleton_coeff() >= rhs->singleton_coeff();
|
||||
return lhs->nested_int_coeff() >= rhs->nested_int_coeff();
|
||||
}
|
||||
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
|
||||
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
|
||||
}
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
if (rhs->constant_int() && *rhs->constant_int() <= 2) {
|
||||
return true;
|
||||
}
|
||||
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
|
||||
} else if (rhs->singleton_int()) {
|
||||
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
|
||||
} else if (rhs->nested_int()) {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
if (lhs->constant_int() && *lhs->constant_int() < 2) {
|
||||
return false;
|
||||
}
|
||||
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
|
||||
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(false, "expect at least one singleton");
|
||||
TORCH_INTERNAL_ASSERT(false, "expect at least one nested int");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::eq(const c10::SymNode& other) {
|
||||
c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
_eq("eq", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::ne(const c10::SymNode& other) {
|
||||
c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
!_eq("ne", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::ge(const c10::SymNode& other) {
|
||||
c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
_ge("ge", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::gt(const c10::SymNode& other) {
|
||||
c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
!_ge("gt", other.get(), this)));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::lt(const c10::SymNode& other) {
|
||||
c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
!_ge("lt", this, other.get())));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
|
||||
c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
||||
_ge("le", other.get(), this)));
|
||||
}
|
||||
|
||||
c10::SymNode SingletonSymNodeImpl::mul(const c10::SymNode& other) {
|
||||
if (auto mb_si = other->singleton_int()) {
|
||||
TORCH_CHECK(false, "Singleton int cannot be multiplied by singleton int");
|
||||
c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) {
|
||||
if (auto mb_si = other->nested_int()) {
|
||||
TORCH_CHECK(false, "nested int cannot be multiplied by nested int");
|
||||
}
|
||||
c10::optional<int64_t> c = other->constant_int();
|
||||
TORCH_CHECK(c.has_value());
|
||||
return SymNode(c10::make_intrusive<SingletonSymNodeImpl>(val_, coeff_ * *c));
|
||||
return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_ * *c));
|
||||
}
|
||||
|
||||
} // namespace c10
|
@ -16,9 +16,9 @@ namespace c10 {
|
||||
// allows us to simply return [B, j0, D] if someone queries for the size of our
|
||||
// tensor.
|
||||
//
|
||||
// Morally we define comparison between two singleton ints to return true if
|
||||
// Morally we define comparison between two nested ints to return true if
|
||||
// that comparison holds for all corresponding elements of the arrays they
|
||||
// represent. Comparison between a singleton int and a plain int is defined
|
||||
// represent. Comparison between a nested int and a plain int is defined
|
||||
// similarly.
|
||||
//
|
||||
// To simulate this desired behavior but also avoid the O(N) cost of checking,
|
||||
@ -32,13 +32,13 @@ namespace c10 {
|
||||
// differentiate the two cases.
|
||||
//
|
||||
// During tracing the strides of the outputs need to be a function of the size
|
||||
// and strides of the inputs so it is important that SingletonSymNode itself is
|
||||
// and strides of the inputs so it is important that NestedIntSymNode itself is
|
||||
// able to express this.
|
||||
class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
|
||||
public:
|
||||
// CAUTION: you should probably not be constructing these directly; please
|
||||
// the higher-level API in python instead (TODO: actually introduce that).
|
||||
explicit SingletonSymNodeImpl(int64_t val, int64_t coeff)
|
||||
explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
|
||||
: val_(val), coeff_(coeff) {}
|
||||
|
||||
bool bool_() override {
|
||||
@ -88,9 +88,9 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
return std::to_string(coeff_) + "*j" + std::to_string(val_);
|
||||
}
|
||||
|
||||
// NOTE [ Inequalities with SingletonInt ]
|
||||
// NOTE [ Inequalities with nested int ]
|
||||
//
|
||||
// The semantics of SingletonInt when it comes to relations is that it is
|
||||
// The semantics of nested int when it comes to relations is that it is
|
||||
// treated as integer known to be within a certain range,
|
||||
//
|
||||
// j0 \in [2, int64_t::max]
|
||||
@ -117,7 +117,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
// [ Coefficient are assumed positive ]
|
||||
//
|
||||
// For the purpose of computing inequalities, we consider the coefficient of
|
||||
// the SingletonInt to be a positive integer.
|
||||
// the nested int to be a positive integer.
|
||||
//
|
||||
// Thus, no modifications are needed to the logic since
|
||||
// j0 >= k implies coeff * j0 >= k
|
||||
@ -130,11 +130,11 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
c10::SymNode le(const c10::SymNode& other) override;
|
||||
c10::SymNode mul(const c10::SymNode& other) override;
|
||||
|
||||
c10::optional<int64_t> singleton_int() override {
|
||||
c10::optional<int64_t> nested_int() override {
|
||||
return val_;
|
||||
}
|
||||
|
||||
c10::optional<int64_t> singleton_coeff() override {
|
||||
c10::optional<int64_t> nested_int_coeff() override {
|
||||
return coeff_;
|
||||
}
|
||||
|
||||
@ -144,7 +144,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
|
||||
#define DEFINE_BINARY_NOT_SUPPORTED(name) \
|
||||
c10::SymNode name(const c10::SymNode& other) override { \
|
||||
TORCH_CHECK(false, #name " not supported by SingletonSymNode"); \
|
||||
TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
|
||||
}
|
||||
|
||||
DEFINE_BINARY_NOT_SUPPORTED(add)
|
||||
@ -162,7 +162,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
|
||||
|
||||
#define DEFINE_NOT_SUPPORTED(name) \
|
||||
c10::SymNode name() override { \
|
||||
TORCH_CHECK(false, #name " is not supported by SingletonSymNode"); \
|
||||
TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
|
||||
}
|
||||
|
||||
DEFINE_NOT_SUPPORTED(sym_not)
|
@ -1037,7 +1037,7 @@ aten_cpu_source_non_codegen_list = [
|
||||
"aten/src/ATen/core/operator_name.cpp",
|
||||
"aten/src/ATen/core/TorchDispatchUtils.cpp",
|
||||
"aten/src/ATen/core/register_symbols.cpp",
|
||||
"aten/src/ATen/core/SingletonSymNodeImpl.cpp",
|
||||
"aten/src/ATen/core/NestedIntSymNodeImpl.cpp",
|
||||
"aten/src/ATen/core/class_type.cpp",
|
||||
"aten/src/ATen/core/type.cpp",
|
||||
"aten/src/ATen/core/type_factory.cpp",
|
||||
|
@ -4,14 +4,14 @@ namespace c10 {
|
||||
|
||||
// This is used to support the case where the lhs is a constant symnode
|
||||
// and the rhs is a singleton symnode. This situation occurs today when we
|
||||
// perform a binary op between singleton int and plain int and the
|
||||
// perform a binary op between nested int and plain int and the
|
||||
// singleton promotes the int into a constant symnode. If we'd like to
|
||||
// support more combinations in the future, we may need to implement some
|
||||
// kind of multiple dispatch.
|
||||
#define DEFINE_BINARY_OP(OP, ROP) \
|
||||
template <typename T> \
|
||||
c10::SymNode ConstantSymNodeImpl<T>::OP(const c10::SymNode& other) { \
|
||||
TORCH_INTERNAL_ASSERT(other->singleton_int().has_value()); \
|
||||
TORCH_INTERNAL_ASSERT(other->nested_int().has_value()); \
|
||||
return other->ROP( \
|
||||
c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim_copy(this)); \
|
||||
}
|
||||
|
@ -185,10 +185,10 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
virtual std::string str() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual c10::optional<int64_t> singleton_int() {
|
||||
virtual c10::optional<int64_t> nested_int() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
virtual c10::optional<int64_t> singleton_coeff() {
|
||||
virtual c10::optional<int64_t> nested_int_coeff() {
|
||||
return c10::nullopt;
|
||||
}
|
||||
virtual c10::optional<int64_t> constant_int() {
|
||||
|
@ -944,7 +944,7 @@ coverage_ignore_functions = [
|
||||
"is_channels_last_strides_3d",
|
||||
"is_contiguous",
|
||||
"is_non_overlapping_and_dense_indicator",
|
||||
"is_singleton",
|
||||
"is_nested_int",
|
||||
"is_symbol_binding_fx_node",
|
||||
"is_symbolic",
|
||||
# torch.fx.experimental.unification.core
|
||||
|
@ -41,7 +41,7 @@ set(TORCH_API_TEST_SOURCES
|
||||
${TORCH_API_TEST_DIR}/inference_mode.cpp
|
||||
${TORCH_API_TEST_DIR}/grad_mode.cpp
|
||||
${TORCH_API_TEST_DIR}/operations.cpp
|
||||
${TORCH_API_TEST_DIR}/singleton_int.cpp
|
||||
${TORCH_API_TEST_DIR}/nested_int.cpp
|
||||
)
|
||||
if(USE_CUDA OR USE_ROCM)
|
||||
list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp)
|
||||
|
@ -1,19 +1,19 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/core/SingletonSymNodeImpl.h>
|
||||
#include <ATen/core/NestedIntSymNodeImpl.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
TEST(SingletonIntTest, Comparisons) {
|
||||
TEST(NestedIntTest, Comparisons) {
|
||||
auto a = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
|
||||
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
|
||||
auto b = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
|
||||
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
|
||||
auto c = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
|
||||
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(2, 1)));
|
||||
auto d = c10::SymInt(3);
|
||||
|
||||
ASSERT_TRUE(a == a);
|
||||
@ -85,11 +85,11 @@ TEST(SingletonIntTest, Comparisons) {
|
||||
ASSERT_TRUE(a > 1);
|
||||
}
|
||||
|
||||
TEST(SingletonIntTest, WiithFactor) {
|
||||
TEST(NestedIntTest, WithFactor) {
|
||||
auto a = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
|
||||
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 5)));
|
||||
auto b = c10::SymInt(
|
||||
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
|
||||
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 10)));
|
||||
// eq
|
||||
ASSERT_FALSE(a == b);
|
||||
ASSERT_FALSE(a >= b);
|
@ -849,10 +849,10 @@ class TestSymNumberMagicMethods(TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "unhashable"):
|
||||
hash(x)
|
||||
|
||||
# Singleton SymInt, constant SymBool, SymNode are hashable
|
||||
j1 = torch._C._get_singleton_int(1, 1)
|
||||
j1_copy = torch._C._get_singleton_int(1, 1)
|
||||
j2 = torch._C._get_singleton_int(2, 1)
|
||||
# NestedInt (SymInt), constant SymBool, SymNode are hashable
|
||||
j1 = torch._C._get_nested_int(1, 1)
|
||||
j1_copy = torch._C._get_nested_int(1, 1)
|
||||
j2 = torch._C._get_nested_int(2, 1)
|
||||
t = self.get_constant_bool(True)
|
||||
t_copy = self.get_constant_bool(True)
|
||||
f = self.get_constant_bool(False)
|
||||
@ -872,14 +872,14 @@ class TestSymNumberMagicMethods(TestCase):
|
||||
hash(m)
|
||||
|
||||
def test_non_symbolic_symnode(self):
|
||||
j1 = torch._C._get_singleton_int(1, 1)
|
||||
j2 = torch._C._get_singleton_int(1, 1)
|
||||
j3 = torch._C._get_singleton_int(3, 1)
|
||||
j1 = torch._C._get_nested_int(1, 1)
|
||||
j2 = torch._C._get_nested_int(1, 1)
|
||||
j3 = torch._C._get_nested_int(3, 1)
|
||||
|
||||
self.assertIsInstance(j1, torch.SymInt)
|
||||
self.assertNotIsInstance(j1, int)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "add not supported by SingletonSymNode"):
|
||||
with self.assertRaisesRegex(RuntimeError, "add not supported by NestedIntSymNode"):
|
||||
j1 + 3
|
||||
|
||||
self.assertFalse(j1 == 3)
|
||||
|
@ -3028,9 +3028,9 @@ class TestNestedTensorSubclass(TestCase):
|
||||
"directly calling torch.ops.aten.size"):
|
||||
torch.ops.aten.size.default(nt)
|
||||
|
||||
singleton_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1)
|
||||
self.assertEqual(nt.size(), (3, singleton_int, 3))
|
||||
self.assertEqual(nt.shape, (3, singleton_int, 3))
|
||||
nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1)
|
||||
self.assertEqual(nt.size(), (3, nested_int, 3))
|
||||
self.assertEqual(nt.shape, (3, nested_int, 3))
|
||||
self.assertEqual(nt.dim(), 3)
|
||||
self.assertEqual(nt.numel(), 27)
|
||||
|
||||
|
@ -1533,7 +1533,7 @@ def _are_functorch_transforms_active() -> _bool: ...
|
||||
# Define in torch/csrc/autograd/init.cpp
|
||||
def _set_python_dispatcher(dispatcher: object) -> None: ...
|
||||
|
||||
def _get_singleton_int(id: _int, coeff: _int) -> SymInt: ...
|
||||
def _get_nested_int(id: _int, coeff: _int) -> SymInt: ...
|
||||
|
||||
def _get_constant_bool_symnode(val: _bool) -> Any: ...
|
||||
|
||||
|
@ -301,12 +301,12 @@ class SymInt:
|
||||
return str(self.node)
|
||||
|
||||
def __hash__(self) -> builtins.int:
|
||||
ret = self.node.singleton_int()
|
||||
ret = self.node.nested_int()
|
||||
if ret is not None:
|
||||
return hash(ret)
|
||||
else:
|
||||
# We could support constant SymInts as well, but not doing it for now
|
||||
raise TypeError("unhashable type: non-singleton SymInt")
|
||||
raise TypeError("unhashable type: non-nested SymInt")
|
||||
|
||||
class SymFloat:
|
||||
"""
|
||||
|
@ -491,7 +491,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._C._get_privateuse1_backend_name",
|
||||
"torch._C._get_qengine",
|
||||
"torch._C._get_schema",
|
||||
"torch._C._get_singleton_int",
|
||||
"torch._C._get_nested_int",
|
||||
"torch._C._get_tensor_metadata",
|
||||
"torch._C._get_tracing_state",
|
||||
"torch._C._get_upgrader_ranges",
|
||||
|
@ -1598,9 +1598,9 @@ def _automatic_dynamic(
|
||||
|
||||
# We preserve the dynamism of inputs. For example, when users call
|
||||
# make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
|
||||
from torch.fx.experimental.symbolic_shapes import is_singleton
|
||||
from torch.fx.experimental.symbolic_shapes import is_nested_int
|
||||
|
||||
if any(isinstance(s, SymInt) and not is_singleton(s) for s in e.size()):
|
||||
if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
|
||||
return StatefulSymbolicContext(
|
||||
dynamic_sizes=[
|
||||
DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
|
||||
@ -1729,7 +1729,7 @@ def _automatic_dynamic(
|
||||
constraint_dim is not None
|
||||
or marked_dynamic
|
||||
or marked_weak_dynamic
|
||||
or is_singleton(e.shape[i])
|
||||
or is_nested_int(e.shape[i])
|
||||
):
|
||||
# NB: We could assert static_shapes is False here, but it
|
||||
# seems better to allow the user to override symbolic_context in this
|
||||
|
@ -846,7 +846,7 @@ def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
|
||||
(sizeA == sizeB) or (sizeA == 1) or (sizeB == 1),
|
||||
lambda: (
|
||||
f"The size of tensor a ({sizeA}) must match the size of "
|
||||
f"tensor b ({sizeB}) at non-singleton dimension {i}"
|
||||
f"tensor b ({sizeB}) at non-jagged dimension {i}"
|
||||
),
|
||||
)
|
||||
|
||||
@ -1566,13 +1566,13 @@ def make_contiguous_strides_for(
|
||||
if not shape:
|
||||
return ()
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import is_singleton
|
||||
from torch.fx.experimental.symbolic_shapes import is_nested_int
|
||||
|
||||
multiplier = 1
|
||||
strides = []
|
||||
for l in reversed(shape):
|
||||
strides.append(multiplier)
|
||||
multiplier *= l if is_singleton(l) else sym_max(l, 1)
|
||||
multiplier *= l if is_nested_int(l) else sym_max(l, 1)
|
||||
|
||||
result = tuple(reversed(strides))
|
||||
|
||||
|
@ -1290,14 +1290,14 @@ void initJITBindings(PyObject* module) {
|
||||
return node->is_symbolic();
|
||||
})
|
||||
.def(
|
||||
"singleton_int",
|
||||
"nested_int",
|
||||
[](const c10::SymNode& node) {
|
||||
return node->singleton_int();
|
||||
return node->nested_int();
|
||||
})
|
||||
.def(
|
||||
"singleton_coeff",
|
||||
"nested_int_coeff",
|
||||
[](const c10::SymNode& node) {
|
||||
return node->singleton_coeff();
|
||||
return node->nested_int_coeff();
|
||||
});
|
||||
|
||||
// clang-format on
|
||||
|
@ -5,8 +5,8 @@
|
||||
#include <ATen/FuncTorchTLS.h>
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <ATen/TensorSubclassLikeUtils.h>
|
||||
#include <ATen/core/NestedIntSymNodeImpl.h>
|
||||
#include <ATen/core/PythonOpRegistrationTrampoline.h>
|
||||
#include <ATen/core/SingletonSymNodeImpl.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
#include <ATen/functorch/BatchedTensorImpl.h>
|
||||
@ -823,9 +823,9 @@ void initDispatchBindings(PyObject* module) {
|
||||
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
|
||||
});
|
||||
|
||||
m.def("_get_singleton_int", [](int64_t data, int64_t coeff) {
|
||||
m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
|
||||
return c10::SymInt(c10::SymNode(
|
||||
c10::make_intrusive<c10::SingletonSymNodeImpl>(data, coeff)));
|
||||
c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
|
||||
});
|
||||
|
||||
m.def("_get_constant_bool_symnode", [](int64_t data) {
|
||||
|
@ -426,7 +426,7 @@ class SymNode:
|
||||
def is_symbolic(self):
|
||||
return True
|
||||
|
||||
def singleton_int(self):
|
||||
def nested_int(self):
|
||||
return None
|
||||
|
||||
def is_constant(self):
|
||||
|
@ -90,7 +90,7 @@ __all__ = [
|
||||
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
|
||||
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
|
||||
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
|
||||
"is_concrete_bool", "is_singleton", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
|
||||
"is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
|
||||
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
|
||||
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
|
||||
"guard_size_oblivious",
|
||||
@ -262,25 +262,25 @@ def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
|
||||
|
||||
return False
|
||||
|
||||
def is_singleton(s: Int) -> bool:
|
||||
# check for SingletonSymNode
|
||||
def is_nested_int(s):
|
||||
# check for NestedIntSymNode
|
||||
if not isinstance(s, torch.SymInt):
|
||||
return False
|
||||
if s.node.singleton_int() is not None:
|
||||
if s.node.nested_int() is not None:
|
||||
return True
|
||||
|
||||
# check for symbolic variable wrapping a SingletonSymNode (fake-ifying causes this)
|
||||
# check for symbolic variable wrapping a NestedIntSymNode (fake-ifying causes this)
|
||||
return (
|
||||
s.node.is_symbolic()
|
||||
and s.node.hint is not None
|
||||
and isinstance(s.node.hint, torch.SymInt)
|
||||
and s.node.hint.node.singleton_int() is not None
|
||||
and s.node.hint.node.nested_int() is not None
|
||||
)
|
||||
|
||||
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
|
||||
if isinstance(val, SymTypes):
|
||||
# This allow applies to the jagged layout NestedTensor case as
|
||||
# singleton ints are not symbolic
|
||||
# nested ints are not symbolic
|
||||
if is_symbolic(val):
|
||||
yield val.node.expr
|
||||
elif isinstance(val, sympy.Basic):
|
||||
@ -2306,9 +2306,9 @@ class ShapeEnv:
|
||||
val_list = sorted(
|
||||
[(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
|
||||
key=lambda tup: (
|
||||
# Order singletons by their coefficients.
|
||||
# 1 here to order singletons after non-singletons.
|
||||
(1, tup[0].node.singleton_coeff(), tup[1]) if is_singleton(tup[0])
|
||||
# Order nested int by their coefficients.
|
||||
# 1 here to order nested int after non-nested int.
|
||||
(1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0])
|
||||
else (0, *tup)
|
||||
)
|
||||
)
|
||||
@ -2572,7 +2572,7 @@ class ShapeEnv:
|
||||
self.var_to_val[sympy_expr] = sympy.Integer(val)
|
||||
else:
|
||||
# Only used for jagged layout nested tensors
|
||||
self.var_to_val[sympy_expr] = SingletonInt(val.node.singleton_int(), coeff=val.node.singleton_coeff())
|
||||
self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff())
|
||||
|
||||
# Do the appending later, because we always want to populate this
|
||||
self.var_to_sources[sympy_expr] = []
|
||||
|
@ -15,7 +15,7 @@ def get_tensor_symint(tensor, *, coeff=1):
|
||||
global _tensor_id_counter
|
||||
tensor_symint = _tensor_symint_registry.get(tensor)
|
||||
if tensor_symint is None:
|
||||
tensor_symint = torch._C._get_singleton_int(_tensor_id_counter, coeff)
|
||||
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
|
||||
_tensor_id_counter += 1
|
||||
_tensor_symint_registry[tensor] = tensor_symint
|
||||
return tensor_symint
|
||||
@ -30,18 +30,18 @@ class NestedTensor(torch.Tensor):
|
||||
_values: torch.Tensor # type: ignore[assignment]
|
||||
_offsets: torch.Tensor
|
||||
_lengths: Optional[torch.Tensor]
|
||||
# NOTE [ Singleton ints for ragged sizes and strides ]
|
||||
# NOTE [ Nested ints for ragged sizes and strides ]
|
||||
#
|
||||
# Jagged layout tensors are tensors that represent a n-dim tensor with a
|
||||
# ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
|
||||
# a jagged tensor with outer shape [B, x, D] is represented internally by a
|
||||
# tensor with shape [sum(x), D] where we introduce what we call a singleton
|
||||
# (or skolem) denoted as "x" here (but sometimes denoted with "*" to
|
||||
# tensor with shape [sum(x), D] where we introduce what we call a nested int
|
||||
# denoted as "x" here (but sometimes denoted with "*" to
|
||||
# represent the ragged dimension, and sum(x) represents the dim of the inner
|
||||
# tensor or equivalently the sum of all the sizes of the constituent
|
||||
# tensors' varying lengths.
|
||||
#
|
||||
# We also use singleton ints to represent the strides of this tensor.
|
||||
# We also use nested ints to represent the strides of this tensor.
|
||||
# For example, a jagged tensor with shape [B, x, D] can be strided in two
|
||||
# ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
|
||||
_size: Tuple[int, ...]
|
||||
|
@ -13,14 +13,14 @@ class SingletonInt(sympy.AtomicExpr):
|
||||
instance = super().__new__(cls, *args, **kwargs)
|
||||
return instance
|
||||
|
||||
# The semantics of this class should match that of SingletonSymNodeImpl in
|
||||
# c10/core/SingletonSymNodeImpl.h
|
||||
# The semantics of this class should match that of NestedIntSymNodeImpl in
|
||||
# c10/core/NestedIntSymNodeImpl.h
|
||||
def __init__(self, val, *, coeff=1):
|
||||
self._val = val
|
||||
self._coeff = coeff
|
||||
super().__init__()
|
||||
|
||||
# See NOTE [ Inequalities with SingletonInt ]
|
||||
# See NOTE [ Inequalities with nested int ]
|
||||
def _eval_Eq(self, other):
|
||||
if (
|
||||
isinstance(other, SingletonInt)
|
||||
@ -69,7 +69,7 @@ class SingletonInt(sympy.AtomicExpr):
|
||||
raise NotImplementedError("NYI")
|
||||
|
||||
|
||||
# See NOTE [ Inequalities with SingletonInt ]
|
||||
# See NOTE [ Inequalities with nested int ]
|
||||
@dispatch(sympy.Integer, SingletonInt)
|
||||
def _eval_is_ge(a, b):
|
||||
if a < 2:
|
||||
|
Reference in New Issue
Block a user