Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Reland PySymInt (#79617)"
This reverts commit 8ef6356f267c75276ea23b51163274cd5fffc0ce. Reverted https://github.com/pytorch/pytorch/pull/79617 on behalf of https://github.com/zengk95 due to this is breaking periodic jobs (and maybe pull) on trunk
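For context, the reverted change let Python objects back c10::SymInt through torch._C.SymbolicIntNode; the deleted test/test_pysymint.py hunk below shows the intended usage. A rough sketch of that pattern, assuming a build that still carries the reverted bindings plus the ShapeEnv and create_symbolic_tensor helpers defined in the deleted test file (and sympy installed):

    import sympy  # required by the PySymInt helpers from the deleted test file
    import torch

    shape_env = ShapeEnv()                      # helper from the deleted test file
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    s0 = x.sym_size(0)                # torch._C.SymbolicIntNode wrapping a sympy symbol
    print(int(s0))                    # evaluates the symbol against shape_env -> 5
    print(int(s0 + x.sym_size(1)))    # arithmetic stays symbolic until evaluated -> 9
    print(shape_env.guards)           # each evaluation records an (expr, concrete_val) guard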
@@ -79,13 +79,6 @@ bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
IntArrayRef NestedTensorImpl::sizes_custom() const {
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}
c10::SymIntArrayRef NestedTensorImpl::sym_sizes_custom() const {
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}

c10::SymIntArrayRef NestedTensorImpl::sym_sizes() const {
return sym_sizes_custom();
}

IntArrayRef NestedTensorImpl::strides_custom() const {
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
@@ -42,8 +42,6 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
int64_t numel_custom() const override;
bool is_contiguous_custom(MemoryFormat) const override;
IntArrayRef sizes_custom() const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymIntArrayRef sym_sizes() const override;
IntArrayRef strides_custom() const override;

// this one is real
@@ -179,7 +179,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("exp.out", CppFunction::makeFallthrough());
m.impl("exp_", CppFunction::makeFallthrough());
m.impl("expand", CppFunction::makeFallthrough());
m.impl("expand.SymInt", CppFunction::makeFallthrough());
m.impl("expm1", CppFunction::makeFallthrough());
m.impl("expm1.out", CppFunction::makeFallthrough());
m.impl("expm1_", CppFunction::makeFallthrough());
@@ -156,14 +156,6 @@ class TORCH_API TensorBase {
return at::isSignedType(this->scalar_type());
}

c10::SymInt sym_size(int64_t dim) const {
const auto sizes = this->sym_sizes();
const auto ndim = static_cast<int64_t>(sizes.size());
// false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];

}

int64_t size(int64_t dim) const {
const auto sizes = this->sizes();
const auto ndim = static_cast<int64_t>(sizes.size());
@@ -132,7 +132,6 @@ class TORCH_API Tensor: public TensorBase {

// Aliased by Dimname overloads, so need explicit using
using TensorBase::size;
using TensorBase::sym_size;
using TensorBase::stride;

/// Should be used if *this can reasonably be expected to be contiguous and
@@ -25,5 +25,4 @@ SymIntTable& getSymIntTable() {
static SymIntTable sit;
return sit;
}

} // namespace c10
@@ -13,53 +13,7 @@ class C10_API SymbolicIntNode
public:
c10::SymInt toSymInt();
virtual ~SymbolicIntNode(){};
// these could be pure virtual when we implement LTC versions
virtual std::shared_ptr<SymbolicIntNode> add(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> sub(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> mul(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> div(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> mod(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> eq(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> gt(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> lt(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> wrap(int64_t num) {
TORCH_CHECK(false, "NYI");
};
virtual bool bool_() {
TORCH_CHECK(false, "NYI");
};
virtual int64_t int_() {
TORCH_CHECK(false, "NYI");
}
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
std::ostream& operator<<(std::ostream& os) {
os << str();
virtual std::ostream& operator<<(std::ostream& os) {
return os;
};
};
@@ -806,15 +806,6 @@ void TensorImpl::ShareExternalPointer(
}
}

void TensorImpl::set_sym_sizes_and_strides(
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides) {
has_symbolic_sizes_strides_ = true;
sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
sizes_and_strides_.set_sizes(sizes);
sizes_and_strides_.set_strides(strides);
}

namespace impl {

namespace {
@@ -552,7 +552,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return sizes_default();
}

virtual c10::SymIntArrayRef sym_sizes() const {
c10::SymIntArrayRef sym_sizes() const {
if (C10_UNLIKELY(
sizes_strides_policy_ >=
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
return sym_sizes_custom();
}
return sym_sizes_default();
}
@@ -1307,12 +1312,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return numel() == 0;
}

// if we are going to use sym sizes, we should be setting sym strides at the
// same time, otherwise it's very easy to misuse this API
void set_sym_sizes_and_strides(
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides);

/**
* Change the size at some dimension. This DOES NOT update strides;
* thus, most changes to size will not preserve contiguity. You probably
@@ -2327,7 +2326,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// Customizable sizes behavior, e.g., nested tensor
//
// Can override: strides(), is_contiguous(), sizes(), dim(), numel()
CustomSizes = 2
CustomSizes = 2,
};

void set_sizes_strides_policy(SizesStridesPolicy policy) {

@@ -2338,7 +2337,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
custom_device_ = custom_device;
}

protected:
Storage storage_;

private:
@@ -170,11 +170,6 @@ class C10_API SizesAndStrides {
std::copy(newSizes.begin(), newSizes.end(), sizes_begin());
}

void set_strides(SymIntArrayRef strides) {
TORCH_INTERNAL_ASSERT(strides.size() == size());
std::copy(strides.begin(), strides.end(), strides_begin());
}

void set_sizes(IntArrayRef newSizes) {
set_sizes(SymIntArrayRef::fromIntArrayRef(newSizes));
}
@@ -320,7 +320,6 @@ coverage_ignore_classes = [
    "Quantize",
    # torch.utils.backcompat
    "Warning",
    "SymbolicIntNode"
]

# The suffix(es) of source filenames.
@@ -10,4 +10,3 @@ setuptools
six
types-dataclasses
typing_extensions
sympy
@@ -104,20 +104,15 @@ class TestLazyReuseIr(TestCase):
    def testBatchNorm(self):
        device = get_test_device()
        x = torch.randn(16, 3, 224, 224, device=device)
        weight = torch.randn(3, device=device)
        bias = torch.randn(3, device=device)

        bn = torch.nn.BatchNorm2d(3).to(device=device)
        for i in range(10):
            # BatchNorm2d does extra checks on dimensions which SymInts don't support yet
            # so we call `torch.ops.aten.native_batch_norm` to bypass the checks.
            z, _, _ = torch.ops.aten.native_batch_norm(x, weight, bias, None, None, True, 0.1, 1e-5)
            z = bn(x)

        device = "lazy"
        x_lazy = x.detach().clone().to(device=device)
        weight_lazy = weight.detach().clone().to(device=device)
        bias_lazy = bias.detach().clone().to(device=device)
        bn = bn.to(device=device)
        for i in range(10):
            z_lazy, _, _ = torch.ops.aten.native_batch_norm(x_lazy, weight_lazy, bias_lazy, None, None, True, 0.1, 1e-5)
            z_lazy = bn(x_lazy)
            torch._lazy.mark_step()

        torch.testing.assert_close(z.cpu(), z_lazy.cpu())
@@ -17,7 +17,6 @@ import itertools
import yaml
import os
import pathlib
from unittest import skip

torch._lazy.ts_backend.init()
@@ -67,9 +66,6 @@ def clone_move(t):
    return copy_t

class TestLazyTensor(JitTestCase):

    @skip("Disable until autograd supports symints")
    def testConvolutionBackward(self):
        test_device = get_test_device()
        inp = torch.rand(1, 3, 128, 128, device=test_device, requires_grad=True)
@@ -224,9 +220,8 @@ class TestLazyDynamicOps(TestCase):
        x1 = torch.tensor([[0, 1.0, 2.0], [3.0, 0, 0]], device=test_device, requires_grad=True)
        x1_lazy = clone_move(x1)
        x2_lazy = torch.nonzero(x1_lazy)

        # FIXME: Add bindings to get upper bounds
        # self.assertEqual(tuple(x2_lazy.size()), (6, 2))
        print(x2_lazy.size())
        self.assertEqual(tuple(x2_lazy.size()), (6, 2))

        # We should still be able to instantiate it and get the actual result
        x2_eager = x2_lazy.cpu()
@@ -1,309 +0,0 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]

from torch._C import _disabled_torch_function_impl
import torch.fx
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase
import unittest
import torch
from torch.utils._pytree import tree_map
aten = torch.ops.aten

try:
    import sympy
    HAS_SYMPY = True
except ImportError:
    HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")


meta_funcs = {}


def register_meta(op):
    def decorator(f):
        def add_func(op):
            meta_funcs[op] = f
        tree_map(add_func, op)
        return f
    return decorator


@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    return a.new_empty(a.sym_size())


@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    concat_length = 0
    shape = tensors[0].shape
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)


@register_meta([aten.narrow_copy.SymInt])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    shape = []
    for i, x in enumerate(a.sym_size()):
        if i == dim:
            shape.append(length)
        else:
            shape.append(x)
    return a.new_empty(tuple(shape))


@register_meta([aten.expand.SymInt])
def expand_symint_meta(a, size, implicit=False):
    return a.new_empty(size)
class PySymInt(object):
    def __init__(self, expr, shape_env):
        self.expr = expr
        self.shape_env = shape_env

    def wrap(self, num):
        return PySymInt(sympy.Integer(num), self.shape_env)

    def __str__(self):
        return f"PySymInt({self.expr})"

    def __int__(self):
        return self.shape_env.evaluate_expr(self.expr)

    def __bool__(self):
        return bool(self.shape_env.evaluate_expr(self.expr))


magic_methods = {
    'add': lambda a, b: a + b,
    'radd': lambda a, b: a + b,
    'sub': lambda a, b: a - b,
    'mul': lambda a, b: a * b,
    'div': lambda a, b: a / b,
    'mod': lambda a, b: a % b,
    'eq': lambda a, b: sympy.Eq(a, b),
    'gt': lambda a, b: sympy.Gt(a, b),
    'lt': lambda a, b: sympy.Lt(a, b),
}

for method, func in magic_methods.items():
    method_name = f'{method}'

    def create_magic_impl(func):
        def magic_impl(self, other):
            if isinstance(other, PySymInt):
                other = other.expr
            return PySymInt(func(self.expr, other), self.shape_env)
        return magic_impl

    # this should be wrapped transparently into torch._C.SymbolicIntNode
    setattr(PySymInt, method_name, create_magic_impl(func))


class ShapeEnv(object):
    def __init__(self):
        self.guards = []
        self.shape_env = {}

    def create_symint(self, name, val):
        sympy_expr = sympy.Symbol(name)
        py_sym_int = PySymInt(sympy_expr, self)
        cpp_sym_int = torch._C.SymbolicIntNode.new_symint(py_sym_int)
        self.shape_env[sympy_expr] = val
        return cpp_sym_int

    def evaluate_expr(self, expr):
        concrete_val = expr.subs(self.shape_env)
        self.guards.append((expr, concrete_val))
        return concrete_val


def create_contiguous(shape):
    strides = [1]
    for dim in reversed(shape[:-1]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))


class FakeSymbolicTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device):
        # sym_strides doesn't work yet
        # TODO: this is wrong in general
        offset = 0
        r = torch.Tensor._make_wrapper_subclass(
            cls, sym_shape,
            create_contiguous(sym_shape), offset,
            dtype=dtype, layout=layout, requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)

        raise RuntimeError(f"operator {func_overload} not supported")


def create_symbolic_tensor(name, arg, shape_env):
    sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.sym_size())])
    sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
    return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device)


CPP_SYMINT_CLASS = type(torch._C.SymbolicIntNode.new_symint(1))
class TestPySymInt(TestCase):

    @skipIfNoSympy
    def test_roundtrip(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        self.assertTrue(not isinstance(x.sym_size(0), PySymInt))
        self.assertTrue(isinstance(x.sym_size(0), CPP_SYMINT_CLASS))

        self.assertEqual(int(x.sym_size(0)), 5)
        self.assertEqual(int(x.sym_size(1)), 4)
        self.assertEqual(int(x.sym_size(2)), 3)

        self.assertEqual(int(x.sym_size()[0]), 5)
        self.assertEqual(int(x.sym_size()[1]), 4)
        self.assertTrue(isinstance(x.sym_size()[1], CPP_SYMINT_CLASS))
        self.assertEqual(int(x.sym_size()[2]), 3)

        self.assertEqual(int(x.sym_size(0)), 5)
        self.assertEqual(int(x.sym_size(1)), 4)
        self.assertEqual(int(x.sym_size(2)), 3)
        self.assertTrue(isinstance(x.sym_size(2), CPP_SYMINT_CLASS))

    @skipIfNoSympy
    def test_binary(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

        # broadcasting
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

    @skipIfNoSympy
    def test_symint_args(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2
        z = x.narrow_copy(LAST_DIM, 0, y.sym_size(LAST_DIM))
        self.assertEqual(int(z.sym_size(2)), int(y.sym_size(2)))

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - y.sym_size(LAST_DIM))
        self.assertEqual(int(z.sym_size(2)), 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - 1)
        self.assertEqual(int(z.sym_size(2)), 2)

    @skipIfNoSympy
    def test_symint_vargs(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.sym_size(0), y.sym_size(1), x.sym_size(2))
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

        # shape list
        z = y.expand((x.sym_size(0), y.sym_size(1), x.sym_size(2)))
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

        # mixed python symints and ints
        z = y.expand(x.sym_size(0), y.sym_size(1), 3)
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

        # mixed python symints and ints in a list
        z = y.expand((x.sym_size(0), y.sym_size(1), 3))
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

        # mixed python symints and ints
        z = y.expand(5, y.sym_size(1), x.sym_size(2))
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.sym_size(1), x.sym_size(2)))
        self.assertEqual(int(z.sym_size(0)), 5)
        self.assertEqual(int(z.sym_size(1)), 4)
        self.assertEqual(int(z.sym_size(2)), 3)

    @skipIfNoSympy
    def test_size_expressions(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.sym_size(0), x.sym_size(0))
        if expand_x.sym_size(0) > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op = shape_env.guards[0][0]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        self.assertTrue(str(x.sym_size(0)), str(gt_op.args[0]))
        self.assertTrue(str(expand_x.sym_size(1)), str(x.sym_size(0)))
        self.assertTrue(str(expand_x.sym_size(1)), str(result.sym_size(0)))

    def test_fx_trace_intlist(self):
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        x = torch.rand(1, 3, 4, 4)
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        torch.fx.symbolic_trace(m)


if __name__ == '__main__':
    run_tests()
@@ -146,7 +146,7 @@ class TestNestedTensor(TestCase):
        a1 = constructor([])
        self.assertRaisesRegex(
            RuntimeError,
            "Tensors of type NestedTensorImpl do not have sym sizes"
            "Tensors of type NestedTensorImpl do not have sizes"
            if IS_FBCODE
            else "NestedTensorImpl doesn't support sizes",
            lambda: a1.size(),
@@ -188,7 +188,6 @@ class TestPublicBindings(TestCase):
            "StreamObjType",
            "StringType",
            "SUM",
            "SymbolicIntNode",
            "TensorType",
            "ThroughputBenchmark",
            "TracingState",
@@ -1113,8 +1113,6 @@ def group_overloads(
def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup],
) -> Sequence[PythonSignatureGroup]:
    # NB: Smaller here means lower priority

    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        return (
            str(t1) == "Scalar"

@@ -1133,10 +1131,6 @@ def sort_overloads(
            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
            str(t1) == "Tensor[]"
            and str(t2).find("[]") != -1
            or
            # Prioritize SymIntArrayRef overload over IntArrayRef
            str(t1) == "int[]"
            and str(t2) == "SymInt[]"
        )

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
@@ -95,43 +95,6 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
END_HANDLE_TH_ERRORS
}

// TODO: FIXME This should be super temprorary until we fix the XLA issue.
static PyObject * THPVariable_sym_size(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"sym_size(int64_t dim)",
"sym_size()",
"sym_size(Dimname dim)",
});
auto& self_ = THPVariable_Unpack(self);
ParsedArgs<3> parsed_args;
auto r = parser.parse(self, args, kwargs, parsed_args);

if(r.has_torch_function()){
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
}
if (r.idx == 0) {
if (jit::tracer::isTracing()) {
// will error out if a tensor has symints
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
} else {
return torch::toPyObject(self_.sym_size(r.toInt64(0)));
}
} else if (r.idx == 1) {
return THPSize_NewFromSymSizes(self_);
}
else if (r.idx == 2) {
if (jit::tracer::isTracing()) {
TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
}
return wrap(self_.size(r.dimname(0)));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}


static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@@ -147,19 +110,17 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
if(r.has_torch_function()){
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
}

if (r.idx == 0) {
if (jit::tracer::isTracing()) {
// will error out if a tensor has symints
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
} else {
return wrap(self_.size(r.toInt64(0)));
//return torch::toPyObject(self_.sym_size(r.toInt64(0)));
}
} else if (r.idx == 1) {
// we can't do the normal wrapping here because IntArrayRef maps to both
// torch.Size and tuple in python.
return THPSize_New(self_);
//return THPSize_NewFromSymSizes(self_);
}
else if (r.idx == 2) {
if (jit::tracer::isTracing()) {
@@ -1322,7 +1283,6 @@ PyMethodDef variable_methods[] = {
{"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
{"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
{"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
{"sym_size", castPyCFunctionWithKeywords(THPVariable_sym_size), METH_VARARGS | METH_KEYWORDS, NULL},
{"_storage", THPVariable_storage, METH_NOARGS, NULL},
{"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
{"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},
@@ -1,9 +1,7 @@
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Size.h>

#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_tuples.h>
@@ -44,36 +42,6 @@ PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes) {
return self.release();
}

PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
auto sym_sizes = self_.sym_sizes();

auto ret = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, sym_sizes.size()));
if (!ret)
throw python_error();

for (auto i : c10::irange(sym_sizes.size())) {
auto si = sym_sizes[i];
if (si.is_symbolic()) {
TORCH_CHECK(
!torch::jit::tracer::isTracing(),
"JIT Tracing of SymInts isn't supported");
auto py_symint = py::cast(si.toSymbolicIntNode()).release().ptr();
PyTuple_SET_ITEM(ret.get(), i, py_symint);
} else {
if (torch::jit::tracer::isTracing()) {
PyObject* py_size_tensor =
THPVariable_Wrap(torch::jit::tracer::getSizeOf(self_, i));
if (!py_size_tensor)
throw python_error();
PyTuple_SET_ITEM(ret.get(), i, py_size_tensor);
} else {
PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(si.data()));
}
}
}
return ret.release();
}

static bool isTracedZeroDimVar(PyObject* item) {
if (!THPVariable_Check(item))
return false;
@@ -93,9 +61,6 @@ static PyObject* THPSize_pynew(
if (THPUtils_checkLong(item)) {
continue;
}
if (torch::is_symint_node(item)) {
continue;
}
if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
continue;
}
@@ -127,12 +92,7 @@ static PyObject* THPSize_repr(THPSize* self) {
if (i != 0) {
repr += ", ";
}
auto item = PyTuple_GET_ITEM(self, i);
auto ih = py::handle(item);

repr += torch::is_symint_node(ih)
? std::string(py::str(ih))
: std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
repr += std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
}
repr += "])";
return THPUtils_packString(repr);
@@ -10,6 +10,5 @@ extern PyTypeObject THPSizeType;

PyObject* THPSize_New(const torch::autograd::Variable& t);
PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes);
PyObject* THPSize_NewFromSymSizes(const at::Tensor& t);

void THPSize_init(PyObject* module);
@@ -656,10 +656,6 @@ static PyObject* THPVariable_make_wrapper_subclass(
"int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
"Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
"c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)",
"_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, "
"int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
"Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
"c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)",
});
ParsedArgs<12> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
@@ -699,64 +695,29 @@ static PyObject* THPVariable_make_wrapper_subclass(
// data
// TODO: for_blob produces non-resizable tensors, we might want this to be
// resizable (have to define a custom allocator in that case)
Tensor tensor;
if (r.idx == 0) {
tensor = at::for_blob(nullptr, r.intlist(1))
.strides(r.intlistOptional(2))
.storage_offset(r.toInt64Optional(3))
.context(nullptr, [](void* ctx) {})
.target_device(
options.device()) // TODO: this shouldn't be necessary if
// it came from options
.options(options)
.make_tensor();
auto data =
at::for_blob(nullptr, r.intlist(1))
.strides(r.intlistOptional(2))
.storage_offset(r.toInt64Optional(3))
.context(nullptr, [](void* ctx) {})
.target_device(options.device()) // TODO: this shouldn't be necessary
// if it came from options
.options(options)
.make_tensor();
data.set_requires_grad(r.toBool(9));

const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
tensor.unsafeGetTensorImpl()->set_sizes_strides_policy(
parseSizesStridesPolicyArgument(*sizes_strides_policy));
}
} else {
AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
tracer::impl::NoTracerDispatchMode tracer_guard{};

// We shouldn't need storage
Storage storage{Storage::use_byte_size_t{}, 0, at::DataPtr{}};

tensor = at::detail::make_tensor<TensorImpl>(
std::move(storage), options.computeDispatchKey(), options.dtype());

auto sym_sizes = r.symintlist(1);
auto sym_strides = r.symintlist(2);

TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();

// TODO: this should probably be sym_sizes, sym_strides AND offset
tensor_impl->set_sym_sizes_and_strides(sym_sizes, sym_strides);

// TODO: this may need to be symbolic as well
auto storage_offset = r.toInt64Optional(3);
if (storage_offset) {
tensor_impl->set_storage_offset(*storage_offset);
}

const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
TORCH_CHECK(
false,
"Setting sizes_strides_policy isn't suppored for this overload")
}
const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
data.unsafeGetTensorImpl()->set_sizes_strides_policy(
parseSizesStridesPolicyArgument(*sizes_strides_policy));
}

tensor.set_requires_grad(r.toBool(9));

if (r.toBool(11)) {
tensor.unsafeGetTensorImpl()->set_custom_device(true);
data.unsafeGetTensorImpl()->set_custom_device(true);
}

return THPVariable_NewWithVar(
(PyTypeObject*)cls,
std::move(tensor),
std::move(data),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
END_HANDLE_TH_ERRORS
}

@@ -1203,7 +1164,6 @@ PyObject* THPVariable_get_shape(THPVariable* self, void* unused) {
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_getter(self, "shape");
}
// return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
return THPSize_New(THPVariable_Unpack(self));
END_HANDLE_TH_ERRORS
}
@@ -606,7 +606,13 @@ void addInputs(Node* n, const char* name, int64_t value) {
}

void addInputs(Node* n, const char* name, c10::SymInt value) {
addInputs(n, name, value.expect_int());
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
Value* v = ArgumentStash::popValue(name);
n->addInput(v);
} else {
detail::genericAddInput(n, value);
}
}

void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {

@@ -801,7 +807,7 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
}

void addInputs(Node* n, const char* name, c10::SymIntArrayRef value) {
addInputs(n, name, asIntArrayRefSlow(value));
TORCH_CHECK(false, "Tracing operations taking symbolic ints isn't supported");
}

void addInputs(
@@ -1,4 +1,3 @@
#include <pybind11/pytypes.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>

@@ -12,7 +11,6 @@
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <c10/core/SymbolicIntNode.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>

@@ -105,7 +103,6 @@

#include <ATen/core/function_schema.h>

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/iostream.h>
#include <pybind11/operators.h>
@@ -127,98 +124,6 @@ using ::c10::FunctionSchema;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;

static std::shared_ptr<c10::SymbolicIntNode> toSymIntNode(
std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) {
return torch::is_symint_node(b)
? b.cast<std::shared_ptr<c10::SymbolicIntNode>>()
: a->wrap(b.cast<int64_t>());
}

class PythonSymbolicIntNode : public c10::SymbolicIntNode {
public:
PythonSymbolicIntNode(py::object pyobj) : c10::SymbolicIntNode() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};

virtual std::shared_ptr<SymbolicIntNode> wrap(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return std::make_shared<PythonSymbolicIntNode>(r);
}

virtual bool bool_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__bool__")().is(py::handle(Py_True));
}

virtual int64_t int_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__int__")().cast<int64_t>();
}

virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
}

virtual std::shared_ptr<SymbolicIntNode> dispatch_common_(
const char* fname,
const std::shared_ptr<SymbolicIntNode>& other) {
auto pother = std::dynamic_pointer_cast<PythonSymbolicIntNode>(other);
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return std::make_shared<PythonSymbolicIntNode>(r);
}

virtual std::shared_ptr<SymbolicIntNode> add(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

virtual std::shared_ptr<SymbolicIntNode> sub(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

virtual std::shared_ptr<SymbolicIntNode> mul(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

virtual std::shared_ptr<SymbolicIntNode> div(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

virtual std::shared_ptr<SymbolicIntNode> mod(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

virtual std::shared_ptr<SymbolicIntNode> eq(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

virtual std::shared_ptr<SymbolicIntNode> gt(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

virtual std::shared_ptr<SymbolicIntNode> lt(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}

py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};

namespace {

using autograd::variable_list;
@@ -1171,101 +1076,6 @@ void initJITBindings(PyObject* module) {
}
});

py::class_<c10::SymbolicIntNode, std::shared_ptr<c10::SymbolicIntNode>>(
m, "SymbolicIntNode")
.def_static(
"new_symint",
[](py::object obj) -> std::shared_ptr<c10::SymbolicIntNode> {
return std::make_shared<PythonSymbolicIntNode>(obj);
})
.def(
"get_pyobj",
[](std::shared_ptr<c10::SymbolicIntNode> a) -> py::object {
if (auto psn =
std::dynamic_pointer_cast<PythonSymbolicIntNode>(a)) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();
})
.def(
"__add__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__radd__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__sub__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->sub(snb);
})
.def(
"__mul__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__div__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->div(snb);
})
.def(
"__mod__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->mod(snb);
})
.def(
"__eq__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](std::shared_ptr<c10::SymbolicIntNode> a, py::object b) {
auto snb = toSymIntNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->lt(snb);
})
.def(
"__bool__",
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->bool_(); })
.def(
"__int__",
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->int_(); })
.def("__str__", [](std::shared_ptr<c10::SymbolicIntNode> a) {
return a->str();
});

// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
.def("__repr__", [](CompleteArgumentSpec& self) {
@@ -837,10 +837,6 @@ inline py::object toPyObject(IValue ivalue) {
#else
TORCH_CHECK(false, "RRef is only supported with the distributed package");
#endif
} else if (ivalue.isSymInt()) {
auto si = ivalue.toSymInt();
return si.is_symbolic() ? py::cast(si.toSymbolicIntNode())
: py::cast(si.expect_int());
} else {
AT_ERROR(
"Missing cases in 'toPyObject'! Can't convert ",
@@ -14,10 +14,6 @@ namespace lazy {
class TORCH_API SymbolicIntNode : public c10::SymbolicIntNode {
public:
SymbolicIntNode(NodePtr ptr) : node_(std::move(ptr)){};
std::shared_ptr<c10::SymbolicIntNode> add(
const std::shared_ptr<c10::SymbolicIntNode>& other) override {
TORCH_CHECK(false, "NYI");
}
NodePtr node_;
};
@@ -146,10 +146,6 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
}

c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const {
return sym_sizes_custom();
}

void LTCTensorImpl::setup_size_properties() {
size_t generation = tensor_->generation();
if (generation != generation_) {
@@ -44,7 +44,6 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;

virtual c10::SymIntArrayRef sym_sizes_custom() const override;
virtual c10::SymIntArrayRef sym_sizes() const override;

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
const at::Storage& storage() const override {
@@ -647,17 +647,15 @@ bool is_float_or_complex_list(PyObject* obj) {
return true;
}

static bool is_int_list(PyObject* obj, int broadcast_size) {
static bool is_int_list_(PyObject* obj, int broadcast_size) {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
if (PySequence_Size(obj) == 0) {
return true;
}

auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
if (THPUtils_checkIndex(item.ptr())) {
return true;
}

// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
return (
|
||||
return broadcast_size > 0 && THPUtils_checkLong(obj);
|
||||
}
|
||||
|
||||
static bool is_int_list(PyObject* obj, int broadcast_size) {
|
||||
return is_int_list_(obj, broadcast_size);
|
||||
}
|
||||
|
||||
static bool is_int_or_symint(PyObject* obj) {
|
||||
// THPUtils_checkIndex may call __index__ or __int__
|
||||
// which may have side effects if obj is a symint node
|
||||
// so we do `is_symint_node` check first
|
||||
// TODO: maybe we should be using checkLong here?
|
||||
return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj);
|
||||
if (THPUtils_checkLong(obj)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO: test if it's the Python binding for SymbolicIntNode
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool is_int_or_symint_list(PyObject* obj, int broadcast_size) {
|
||||
if (PyTuple_Check(obj) || PyList_Check(obj)) {
|
||||
if (PySequence_Size(obj) == 0) {
|
||||
return true;
|
||||
}
|
||||
auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
|
||||
|
||||
if (is_int_or_symint(item.ptr())) {
|
||||
return true;
|
||||
}
|
||||
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
|
||||
// in an intlist argument. Even float or complex scalar tensors.
|
||||
return (
|
||||
jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
|
||||
THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
|
||||
}
|
||||
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
|
||||
// int
|
||||
return broadcast_size > 0 && THPUtils_checkLong(obj);
|
||||
// TODO: add a check for SymbolicIntNode
|
||||
return is_int_list_(obj, broadcast_size);
|
||||
}
|
||||
|
||||
// argnum is needed for raising the TypeError, it's used in the error message.
|
||||
@ -1227,9 +1214,7 @@ bool FunctionSignature::parse(
|
||||
// if there is a single positional IntArrayRef argument, i.e. expand(..),
|
||||
// view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
|
||||
// expand((5,3))
|
||||
if (max_pos_args == 1 &&
|
||||
(params[0].type_ == ParameterType::INT_LIST ||
|
||||
params[0].type_ == ParameterType::SYM_INT_LIST)) {
|
||||
if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
|
||||
allow_varargs_intlist = true;
|
||||
}
|
||||
|
||||
@ -1287,7 +1272,7 @@ bool FunctionSignature::parse(
|
||||
// should avoid having complex signatures that make use of it...
|
||||
} else if (
|
||||
allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
|
||||
is_int_or_symint(obj)) {
|
||||
THPUtils_checkIndex(obj)) {
|
||||
// take all positional arguments as this parameter
|
||||
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
|
||||
dst[i++] = args;
|
||||
|
@@ -39,7 +39,6 @@
// Scalar and Tensor, UNLESS they require grad (in which case
// they only bind to Tensor).

#include <pybind11/pytypes.h>
#include <torch/csrc/python_headers.h>

#include <torch/csrc/Device.h>

@@ -68,7 +67,6 @@
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <c10/core/SymbolicIntNode.h>
#include <array>
#include <cstddef>
#include <memory>
@@ -472,80 +470,9 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
return intlistWithDefault(i, signature.params[i].default_intlist);
}

inline bool is_symint_node(py::handle obj) {
auto static tp_symn = py::type::of<c10::SymbolicIntNode>();
// TODO: switch this to `isinstance`
if (obj.get_type().equal(tp_symn)) {
TORCH_CHECK(
!jit::tracer::isTracing(), "JIT tracing of SymInts isn't supported!");
return true;
}
return false;
}

inline PyObject* toPyObject(c10::SymInt symint) {
if (symint.is_symbolic()) {
return py::cast(symint.toSymbolicIntNode()).release().ptr();
} else {
return THPUtils_packInt64(symint.data());
}
}

inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
if (!args[i]) {
return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
return c10::SymInt(di);
});
}

const auto size1 = signature.params[i].size;
if (size1 > 0 && THPUtils_checkLong(args[i])) {
return std::vector<c10::SymInt>(
size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
}

if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) {
auto si = py::handle(args[i]).cast<c10::SymbolicIntNode*>()->toSymInt();
return std::vector<c10::SymInt>(size1, si);
}

PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<c10::SymInt> res;
res.reserve(size2);
for (const auto idx : c10::irange(size2)) {
PyObject* obj =
tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
try {
if (is_symint_node(py::handle(obj))) {
res.push_back(
py::handle(obj).cast<c10::SymbolicIntNode*>()->toSymInt());
} else {
// Elements of torch.Size are tensors during tracing, and we need to
// record extra information before they are turned into an IntArrayRef
if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
auto& var = THPVariable_Unpack(obj);
jit::tracer::ArgumentStash::stashIntArrayRefElem(
signature.params[i].name, size2, idx, var);
res.push_back(var.item<int64_t>());
continue;
}
res.push_back(c10::SymInt(THPUtils_unpackIndex(obj)));
}
} catch (const std::exception& e) {
auto te = TypeError(
"%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
signature.name.c_str(),
signature.params[i].name.c_str(),
signature.params[i].type_name().c_str(),
Py_TYPE(obj)->tp_name,
idx + 1);
}
}

return res;
auto intlist = intlistWithDefault(i, signature.params[i].default_intlist);
return c10::fmap(intlist, [](int64_t n) { return c10::SymInt(n); });
}

inline std::vector<int64_t> PythonArgs::intlistWithDefault(

@@ -847,9 +774,6 @@ inline c10::SymInt PythonArgs::toSymInt(int i) {
jit::tracer::ArgumentStash::stashValue(
signature.params[i].name, idx, var, c10::IntType::get());
}
if (torch::is_symint_node(py::handle(args[i]))) {
return py::handle(args[i]).cast<c10::SymbolicIntNode*>()->toSymInt();
}
return c10::SymInt(THPUtils_unpackLong(args[i]));
}
@@ -280,7 +280,6 @@ def get_ignored_functions() -> Set[Callable]:
        Tensor._addmm_activation,
        Tensor._nested_tensor_layer_norm,
        Tensor.to_padded_tensor,
        Tensor.sym_size
    }