Wconstab/reland pysymint (#79795)

rebased https://github.com/pytorch/pytorch/pull/79617/ to see if issues are reproducible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79795
Approved by: https://github.com/malfet
This commit is contained in:
Edward Z. Yang
2022-06-20 22:55:06 +00:00
committed by PyTorch MergeBot
parent a6b783e714
commit f7ee061638
31 changed files with 878 additions and 58 deletions

View File

@ -76,6 +76,13 @@ bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
IntArrayRef NestedTensorImpl::sizes_custom() const {
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}
c10::SymIntArrayRef NestedTensorImpl::sym_sizes_custom() const {
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
}
c10::SymIntArrayRef NestedTensorImpl::sym_sizes() const {
return sym_sizes_custom();
}
IntArrayRef NestedTensorImpl::strides_custom() const {
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");

View File

@ -42,6 +42,8 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
int64_t numel_custom() const override;
bool is_contiguous_custom(MemoryFormat) const override;
IntArrayRef sizes_custom() const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymIntArrayRef sym_sizes() const override;
IntArrayRef strides_custom() const override;
// this one is real

View File

@ -179,6 +179,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("exp.out", CppFunction::makeFallthrough());
m.impl("exp_", CppFunction::makeFallthrough());
m.impl("expand", CppFunction::makeFallthrough());
m.impl("expand.SymInt", CppFunction::makeFallthrough());
m.impl("expm1", CppFunction::makeFallthrough());
m.impl("expm1.out", CppFunction::makeFallthrough());
m.impl("expm1_", CppFunction::makeFallthrough());

View File

@ -156,6 +156,14 @@ class TORCH_API TensorBase {
return at::isSignedType(this->scalar_type());
}
c10::SymInt sym_size(int64_t dim) const {
const auto sizes = this->sym_sizes();
const auto ndim = static_cast<int64_t>(sizes.size());
// false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
}
int64_t size(int64_t dim) const {
const auto sizes = this->sizes();
const auto ndim = static_cast<int64_t>(sizes.size());

View File

@ -132,6 +132,7 @@ class TORCH_API Tensor: public TensorBase {
// Aliased by Dimname overloads, so need explicit using
using TensorBase::size;
using TensorBase::sym_size;
using TensorBase::stride;
/// Should be used if *this can reasonably be expected to be contiguous and

View File

@ -25,4 +25,5 @@ SymIntTable& getSymIntTable() {
static SymIntTable sit;
return sit;
}
} // namespace c10

View File

@ -13,7 +13,53 @@ class C10_API SymbolicIntNode
public:
c10::SymInt toSymInt();
virtual ~SymbolicIntNode(){};
virtual std::ostream& operator<<(std::ostream& os) {
// these could be pure virtual when we implement LTC versions
virtual std::shared_ptr<SymbolicIntNode> add(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> sub(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> mul(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> div(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> mod(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> eq(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> gt(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> lt(
const std::shared_ptr<SymbolicIntNode>& other) {
TORCH_CHECK(false, "NYI");
};
virtual std::shared_ptr<SymbolicIntNode> wrap(int64_t num) {
TORCH_CHECK(false, "NYI");
};
virtual bool bool_() {
TORCH_CHECK(false, "NYI");
};
virtual int64_t int_() {
TORCH_CHECK(false, "NYI");
}
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
std::ostream& operator<<(std::ostream& os) {
os << str();
return os;
};
};

View File

@ -811,6 +811,15 @@ void TensorImpl::ShareExternalPointer(
}
}
void TensorImpl::set_sym_sizes_and_strides(
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides) {
has_symbolic_sizes_strides_ = true;
sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
sizes_and_strides_.set_sizes(sizes);
sizes_and_strides_.set_strides(strides);
}
namespace impl {
namespace {

View File

@ -552,12 +552,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return sizes_default();
}
c10::SymIntArrayRef sym_sizes() const {
if (C10_UNLIKELY(
sizes_strides_policy_ >=
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
return sym_sizes_custom();
}
virtual c10::SymIntArrayRef sym_sizes() const {
return sym_sizes_default();
}
@ -1312,6 +1307,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return numel() == 0;
}
// if we are going to use sym sizes, we should be setting sym strides at the
// same time, otherwise it's very easy to misuse this API
void set_sym_sizes_and_strides(
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides);
/**
* Change the size at some dimension. This DOES NOT update strides;
* thus, most changes to size will not preserve contiguity. You probably
@ -2326,7 +2327,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// Customizable sizes behavior, e.g., nested tensor
//
// Can override: strides(), is_contiguous(), sizes(), dim(), numel()
CustomSizes = 2,
CustomSizes = 2
};
void set_sizes_strides_policy(SizesStridesPolicy policy) {
@ -2337,6 +2338,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
custom_device_ = custom_device;
}
protected:
Storage storage_;
private:

View File

@ -170,6 +170,11 @@ class C10_API SizesAndStrides {
std::copy(newSizes.begin(), newSizes.end(), sizes_begin());
}
void set_strides(SymIntArrayRef strides) {
TORCH_INTERNAL_ASSERT(strides.size() == size());
std::copy(strides.begin(), strides.end(), strides_begin());
}
void set_sizes(IntArrayRef newSizes) {
set_sizes(SymIntArrayRef::fromIntArrayRef(newSizes));
}

View File

@ -320,6 +320,7 @@ coverage_ignore_classes = [
"Quantize",
# torch.utils.backcompat
"Warning",
"SymbolicIntNode"
]
# The suffix(es) of source filenames.

View File

@ -10,3 +10,4 @@ setuptools
six
types-dataclasses
typing_extensions
sympy

View File

@ -104,15 +104,20 @@ class TestLazyReuseIr(TestCase):
def testBatchNorm(self):
device = get_test_device()
x = torch.randn(16, 3, 224, 224, device=device)
bn = torch.nn.BatchNorm2d(3).to(device=device)
weight = torch.randn(3, device=device)
bias = torch.randn(3, device=device)
for i in range(10):
z = bn(x)
# BatchNorm2d does extra checks on dimensions which SymInts don't support yet
# so we call `torch.ops.aten.native_batch_norm` to bypass the checks.
z, _, _ = torch.ops.aten.native_batch_norm(x, weight, bias, None, None, True, 0.1, 1e-5)
device = "lazy"
x_lazy = x.detach().clone().to(device=device)
bn = bn.to(device=device)
weight_lazy = weight.detach().clone().to(device=device)
bias_lazy = bias.detach().clone().to(device=device)
for i in range(10):
z_lazy = bn(x_lazy)
z_lazy, _, _ = torch.ops.aten.native_batch_norm(x_lazy, weight_lazy, bias_lazy, None, None, True, 0.1, 1e-5)
torch._lazy.mark_step()
torch.testing.assert_close(z.cpu(), z_lazy.cpu())

View File

@ -17,6 +17,7 @@ import itertools
import yaml
import os
import pathlib
from unittest import skip
torch._lazy.ts_backend.init()
@ -66,6 +67,9 @@ def clone_move(t):
return copy_t
class TestLazyTensor(JitTestCase):
@skip("Disable until autograd supports symints")
def testConvolutionBackward(self):
test_device = get_test_device()
inp = torch.rand(1, 3, 128, 128, device=test_device, requires_grad=True)
@ -220,8 +224,9 @@ class TestLazyDynamicOps(TestCase):
x1 = torch.tensor([[0, 1.0, 2.0], [3.0, 0, 0]], device=test_device, requires_grad=True)
x1_lazy = clone_move(x1)
x2_lazy = torch.nonzero(x1_lazy)
print(x2_lazy.size())
self.assertEqual(tuple(x2_lazy.size()), (6, 2))
# FIXME: Add bindings to get upper bounds
# self.assertEqual(tuple(x2_lazy.size()), (6, 2))
# We should still be able to instantiate it and get the actual result
x2_eager = x2_lazy.cpu()

309
test/test_dynamic_shapes.py Normal file
View File

@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]
from torch._C import _disabled_torch_function_impl
import torch.fx
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase
import unittest
import torch
from torch.utils._pytree import tree_map
aten = torch.ops.aten
try:
import sympy
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
meta_funcs = {}
def register_meta(op):
def decorator(f):
def add_func(op):
meta_funcs[op] = f
tree_map(add_func, op)
return f
return decorator
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
return a.new_empty(a.sym_size())
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
concat_length = 0
shape = tensors[0].shape
for tensor in tensors:
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
if idx == dim:
concat_length = concat_length + length
else:
assert length == common_length
new_shape = list(shape)
new_shape[dim] = concat_length
return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.SymInt])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.sym_size()):
if i == dim:
shape.append(length)
else:
shape.append(x)
return a.new_empty(tuple(shape))
@register_meta([aten.expand.SymInt])
def expand_symint_meta(a, size, implicit=False):
return a.new_empty(size)
class PySymInt(object):
def __init__(self, expr, shape_env):
self.expr = expr
self.shape_env = shape_env
def wrap(self, num):
return PySymInt(sympy.Integer(num), self.shape_env)
def __str__(self):
return f"PySymInt({self.expr})"
def __int__(self):
return self.shape_env.evaluate_expr(self.expr)
def __bool__(self):
return bool(self.shape_env.evaluate_expr(self.expr))
magic_methods = {
'add': lambda a, b: a + b,
'radd': lambda a, b: a + b,
'sub': lambda a, b: a - b,
'mul': lambda a, b: a * b,
'div': lambda a, b: a / b,
'mod': lambda a, b: a % b,
'eq': lambda a, b: sympy.Eq(a, b),
'gt': lambda a, b: sympy.Gt(a, b),
'lt': lambda a, b: sympy.Lt(a, b),
}
for method, func in magic_methods.items():
method_name = f'{method}'
def create_magic_impl(func):
def magic_impl(self, other):
if isinstance(other, PySymInt):
other = other.expr
return PySymInt(func(self.expr, other), self.shape_env)
return magic_impl
# this should be wrapped transparently into torch._C.SymbolicIntNode
setattr(PySymInt, method_name, create_magic_impl(func))
class ShapeEnv(object):
def __init__(self):
self.guards = []
self.shape_env = {}
def create_symint(self, name, val):
sympy_expr = sympy.Symbol(name)
py_sym_int = PySymInt(sympy_expr, self)
cpp_sym_int = torch._C.SymbolicIntNode.new_symint(py_sym_int)
self.shape_env[sympy_expr] = val
return cpp_sym_int
def evaluate_expr(self, expr):
concrete_val = expr.subs(self.shape_env)
self.guards.append((expr, concrete_val))
return concrete_val
def create_contiguous(shape):
strides = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
@staticmethod
def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device):
# sym_strides doesn't work yet
# TODO: this is wrong in general
offset = 0
r = torch.Tensor._make_wrapper_subclass(
cls, sym_shape,
create_contiguous(sym_shape), offset,
dtype=dtype, layout=layout, requires_grad=requires_grad,
device=device,
)
return r
__torch_function__ = _disabled_torch_function_impl
def new_empty(self, shape):
return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)
@classmethod
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
if func_overload in meta_funcs:
return meta_funcs[func_overload](*args, **kwargs)
if func_overload == torch.ops.aten.new_empty.default:
self = args[0]
shape = args[1]
return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)
raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env):
sym_shapes = tuple([shape_env.create_symint(f"{name}_{idx}", val) for idx, val in enumerate(arg.sym_size())])
sym_strides = tuple([shape_env.create_symint(f"{name}_{idx}_stride", val) for idx, val in enumerate(arg.stride())])
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device)
CPP_SYMINT_CLASS = type(torch._C.SymbolicIntNode.new_symint(1))
class TestPySymInt(TestCase):
@skipIfNoSympy
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.sym_size(0), PySymInt))
self.assertTrue(isinstance(x.sym_size(0), CPP_SYMINT_CLASS))
self.assertEqual(int(x.sym_size(0)), 5)
self.assertEqual(int(x.sym_size(1)), 4)
self.assertEqual(int(x.sym_size(2)), 3)
self.assertEqual(int(x.sym_size()[0]), 5)
self.assertEqual(int(x.sym_size()[1]), 4)
self.assertTrue(isinstance(x.sym_size()[1], CPP_SYMINT_CLASS))
self.assertEqual(int(x.sym_size()[2]), 3)
self.assertEqual(int(x.sym_size(0)), 5)
self.assertEqual(int(x.sym_size(1)), 4)
self.assertEqual(int(x.sym_size(2)), 3)
self.assertTrue(isinstance(x.sym_size(2), CPP_SYMINT_CLASS))
@skipIfNoSympy
def test_binary(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
z = x + y
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
# broadcasting
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
z = x + y
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
@skipIfNoSympy
def test_symint_args(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.sym_size(LAST_DIM))
self.assertEqual(int(z.sym_size(2)), int(y.sym_size(2)))
# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - y.sym_size(LAST_DIM))
self.assertEqual(int(z.sym_size(2)), 2)
# arithmetic expr with a symint and python int
z = x.narrow_copy(LAST_DIM, 0, x.sym_size(LAST_DIM) - 1)
self.assertEqual(int(z.sym_size(2)), 2)
@skipIfNoSympy
def test_symint_vargs(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
# varargs
z = y.expand(x.sym_size(0), y.sym_size(1), x.sym_size(2))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
# shape list
z = y.expand((x.sym_size(0), y.sym_size(1), x.sym_size(2)))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
# mixed python symints and ints
z = y.expand(x.sym_size(0), y.sym_size(1), 3)
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
# mixed python symints and ints in a list
z = y.expand((x.sym_size(0), y.sym_size(1), 3))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
# mixed python symints and ints
z = y.expand(5, y.sym_size(1), x.sym_size(2))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
# mixed python ints and symints in a list
z = y.expand((5, y.sym_size(1), x.sym_size(2)))
self.assertEqual(int(z.sym_size(0)), 5)
self.assertEqual(int(z.sym_size(1)), 4)
self.assertEqual(int(z.sym_size(2)), 3)
@skipIfNoSympy
def test_size_expressions(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
expand_x = x.expand(x.sym_size(0), x.sym_size(0))
if expand_x.sym_size(0) > 3:
result = expand_x + expand_x
else:
result = expand_x + expand_x
gt_op = shape_env.guards[0][0]
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
self.assertTrue(str(x.sym_size(0)), str(gt_op.args[0]))
self.assertTrue(str(expand_x.sym_size(1)), str(x.sym_size(0)))
self.assertTrue(str(expand_x.sym_size(1)), str(result.sym_size(0)))
def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
def forward(self, x):
bs, c, h, w = x.shape
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
m = CustomModule()
x = torch.rand(1, 3, 4, 4)
# should not TypeError: pad(): argument 'pad' (position 2) must be
# tuple of ints, not tuple
torch.fx.symbolic_trace(m)
if __name__ == '__main__':
run_tests()

View File

@ -146,7 +146,7 @@ class TestNestedTensor(TestCase):
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError,
"Tensors of type NestedTensorImpl do not have sizes"
"Tensors of type NestedTensorImpl do not have sym sizes"
if IS_FBCODE
else "NestedTensorImpl doesn't support sizes",
lambda: a1.size(),

View File

@ -188,6 +188,7 @@ class TestPublicBindings(TestCase):
"StreamObjType",
"StringType",
"SUM",
"SymbolicIntNode",
"TensorType",
"ThroughputBenchmark",
"TracingState",

View File

@ -1113,6 +1113,8 @@ def group_overloads(
def sort_overloads(
grouped_overloads: Sequence[PythonSignatureGroup],
) -> Sequence[PythonSignatureGroup]:
# NB: Smaller here means lower priority
def is_arg_smaller(t1: Type, t2: Type) -> bool:
return (
str(t1) == "Scalar"
@ -1131,6 +1133,10 @@ def sort_overloads(
# last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
str(t1) == "Tensor[]"
and str(t2).find("[]") != -1
or
# Prioritize SymIntArrayRef overload over IntArrayRef
str(t1) == "int[]"
and str(t2) == "SymInt[]"
)
def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:

View File

@ -95,6 +95,43 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
END_HANDLE_TH_ERRORS
}
// TODO: FIXME This should be super temprorary until we fix the XLA issue.
static PyObject * THPVariable_sym_size(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"sym_size(int64_t dim)",
"sym_size()",
"sym_size(Dimname dim)",
});
auto& self_ = THPVariable_Unpack(self);
ParsedArgs<3> parsed_args;
auto r = parser.parse(self, args, kwargs, parsed_args);
if(r.has_torch_function()){
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
}
if (r.idx == 0) {
if (jit::tracer::isTracing()) {
// will error out if a tensor has symints
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
} else {
return torch::toPyObject(self_.sym_size(r.toInt64(0)));
}
} else if (r.idx == 1) {
return THPSize_NewFromSymSizes(self_);
}
else if (r.idx == 2) {
if (jit::tracer::isTracing()) {
TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
}
return wrap(self_.size(r.dimname(0)));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@ -110,17 +147,19 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa
if(r.has_torch_function()){
return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
}
if (r.idx == 0) {
if (jit::tracer::isTracing()) {
// will error out if a tensor has symints
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
} else {
return wrap(self_.size(r.toInt64(0)));
//return torch::toPyObject(self_.sym_size(r.toInt64(0)));
}
} else if (r.idx == 1) {
// we can't do the normal wrapping here because IntArrayRef maps to both
// torch.Size and tuple in python.
return THPSize_New(self_);
//return THPSize_NewFromSymSizes(self_);
}
else if (r.idx == 2) {
if (jit::tracer::isTracing()) {
@ -1283,6 +1322,7 @@ PyMethodDef variable_methods[] = {
{"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
{"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
{"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
{"sym_size", castPyCFunctionWithKeywords(THPVariable_sym_size), METH_VARARGS | METH_KEYWORDS, NULL},
{"_storage", THPVariable_storage, METH_NOARGS, NULL},
{"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
{"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},

View File

@ -1,7 +1,9 @@
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_tuples.h>
@ -42,6 +44,36 @@ PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes) {
return self.release();
}
PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
auto sym_sizes = self_.sym_sizes();
auto ret = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, sym_sizes.size()));
if (!ret)
throw python_error();
for (auto i : c10::irange(sym_sizes.size())) {
auto si = sym_sizes[i];
if (si.is_symbolic()) {
TORCH_CHECK(
!torch::jit::tracer::isTracing(),
"JIT Tracing of SymInts isn't supported");
auto py_symint = py::cast(si.toSymbolicIntNode()).release().ptr();
PyTuple_SET_ITEM(ret.get(), i, py_symint);
} else {
if (torch::jit::tracer::isTracing()) {
PyObject* py_size_tensor =
THPVariable_Wrap(torch::jit::tracer::getSizeOf(self_, i));
if (!py_size_tensor)
throw python_error();
PyTuple_SET_ITEM(ret.get(), i, py_size_tensor);
} else {
PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(si.data()));
}
}
}
return ret.release();
}
static bool isTracedZeroDimVar(PyObject* item) {
if (!THPVariable_Check(item))
return false;
@ -61,6 +93,9 @@ static PyObject* THPSize_pynew(
if (THPUtils_checkLong(item)) {
continue;
}
if (torch::is_symint_node(item)) {
continue;
}
if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
continue;
}
@ -92,7 +127,12 @@ static PyObject* THPSize_repr(THPSize* self) {
if (i != 0) {
repr += ", ";
}
repr += std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
auto item = PyTuple_GET_ITEM(self, i);
auto ih = py::handle(item);
repr += torch::is_symint_node(ih)
? std::string(py::str(ih))
: std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
}
repr += "])";
return THPUtils_packString(repr);

View File

@ -10,5 +10,6 @@ extern PyTypeObject THPSizeType;
PyObject* THPSize_New(const torch::autograd::Variable& t);
PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes);
PyObject* THPSize_NewFromSymSizes(const at::Tensor& t);
void THPSize_init(PyObject* module);

View File

@ -658,6 +658,10 @@ static PyObject* THPVariable_make_wrapper_subclass(
"int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
"Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
"c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)",
"_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, "
"int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
"Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
"c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)",
});
ParsedArgs<12> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
@ -697,29 +701,64 @@ static PyObject* THPVariable_make_wrapper_subclass(
// data
// TODO: for_blob produces non-resizable tensors, we might want this to be
// resizable (have to define a custom allocator in that case)
auto data =
at::for_blob(nullptr, r.intlist(1))
.strides(r.intlistOptional(2))
.storage_offset(r.toInt64Optional(3))
.context(nullptr, [](void* ctx) {})
.target_device(options.device()) // TODO: this shouldn't be necessary
// if it came from options
.options(options)
.make_tensor();
data.set_requires_grad(r.toBool(9));
Tensor tensor;
if (r.idx == 0) {
tensor = at::for_blob(nullptr, r.intlist(1))
.strides(r.intlistOptional(2))
.storage_offset(r.toInt64Optional(3))
.context(nullptr, [](void* ctx) {})
.target_device(
options.device()) // TODO: this shouldn't be necessary if
// it came from options
.options(options)
.make_tensor();
const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
data.unsafeGetTensorImpl()->set_sizes_strides_policy(
parseSizesStridesPolicyArgument(*sizes_strides_policy));
const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
tensor.unsafeGetTensorImpl()->set_sizes_strides_policy(
parseSizesStridesPolicyArgument(*sizes_strides_policy));
}
} else {
AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
tracer::impl::NoTracerDispatchMode tracer_guard{};
// We shouldn't need storage
Storage storage{Storage::use_byte_size_t{}, 0, at::DataPtr{}};
tensor = at::detail::make_tensor<TensorImpl>(
std::move(storage), options.computeDispatchKey(), options.dtype());
auto sym_sizes = r.symintlist(1);
auto sym_strides = r.symintlist(2);
TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
// TODO: this should probably be sym_sizes, sym_strides AND offset
tensor_impl->set_sym_sizes_and_strides(sym_sizes, sym_strides);
// TODO: this may need to be symbolic as well
auto storage_offset = r.toInt64Optional(3);
if (storage_offset) {
tensor_impl->set_storage_offset(*storage_offset);
}
const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
TORCH_CHECK(
false,
"Setting sizes_strides_policy isn't suppored for this overload")
}
}
tensor.set_requires_grad(r.toBool(9));
if (r.toBool(11)) {
data.unsafeGetTensorImpl()->set_custom_device(true);
tensor.unsafeGetTensorImpl()->set_custom_device(true);
}
return THPVariable_NewWithVar(
(PyTypeObject*)cls,
std::move(data),
std::move(tensor),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
END_HANDLE_TH_ERRORS
}
@ -1166,6 +1205,7 @@ PyObject* THPVariable_get_shape(THPVariable* self, void* unused) {
if (check_has_torch_function((PyObject*)self)) {
return handle_torch_function_getter(self, "shape");
}
// return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
return THPSize_New(THPVariable_Unpack(self));
END_HANDLE_TH_ERRORS
}

View File

@ -606,13 +606,7 @@ void addInputs(Node* n, const char* name, int64_t value) {
}
void addInputs(Node* n, const char* name, c10::SymInt value) {
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
Value* v = ArgumentStash::popValue(name);
n->addInput(v);
} else {
detail::genericAddInput(n, value);
}
addInputs(n, name, value.expect_int());
}
void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
@ -807,7 +801,7 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
}
void addInputs(Node* n, const char* name, c10::SymIntArrayRef value) {
TORCH_CHECK(false, "Tracing operations taking symbolic ints isn't supported");
addInputs(n, name, asIntArrayRefSlow(value));
}
void addInputs(

View File

@ -1,3 +1,4 @@
#include <pybind11/pytypes.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
@ -11,6 +12,7 @@
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <c10/core/SymbolicIntNode.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
@ -103,6 +105,7 @@
#include <ATen/core/function_schema.h>
#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/iostream.h>
#include <pybind11/operators.h>
@ -124,6 +127,98 @@ using ::c10::FunctionSchema;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
static std::shared_ptr<c10::SymbolicIntNode> toSymIntNode(
std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) {
return torch::is_symint_node(b)
? b.cast<std::shared_ptr<c10::SymbolicIntNode>>()
: a->wrap(b.cast<int64_t>());
}
class PythonSymbolicIntNode : public c10::SymbolicIntNode {
public:
PythonSymbolicIntNode(py::object pyobj) : c10::SymbolicIntNode() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
virtual std::shared_ptr<SymbolicIntNode> wrap(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return std::make_shared<PythonSymbolicIntNode>(r);
}
virtual bool bool_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__bool__")().is(py::handle(Py_True));
}
virtual int64_t int_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__int__")().cast<int64_t>();
}
virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
}
virtual std::shared_ptr<SymbolicIntNode> dispatch_common_(
const char* fname,
const std::shared_ptr<SymbolicIntNode>& other) {
auto pother = std::dynamic_pointer_cast<PythonSymbolicIntNode>(other);
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return std::make_shared<PythonSymbolicIntNode>(r);
}
virtual std::shared_ptr<SymbolicIntNode> add(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> sub(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> mul(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> div(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> mod(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> eq(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> gt(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual std::shared_ptr<SymbolicIntNode> lt(
const std::shared_ptr<SymbolicIntNode>& other) override {
return dispatch_common_(__FUNCTION__, other);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
namespace {
using autograd::variable_list;
@ -1077,6 +1172,101 @@ void initJITBindings(PyObject* module) {
}
});
py::class_<c10::SymbolicIntNode, std::shared_ptr<c10::SymbolicIntNode>>(
m, "SymbolicIntNode")
.def_static(
"new_symint",
[](py::object obj) -> std::shared_ptr<c10::SymbolicIntNode> {
return std::make_shared<PythonSymbolicIntNode>(obj);
})
.def(
"get_pyobj",
[](std::shared_ptr<c10::SymbolicIntNode> a) -> py::object {
if (auto psn =
std::dynamic_pointer_cast<PythonSymbolicIntNode>(a)) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();
})
.def(
"__add__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__radd__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__sub__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->sub(snb);
})
.def(
"__mul__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__div__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->div(snb);
})
.def(
"__mod__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->mod(snb);
})
.def(
"__eq__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](std::shared_ptr<c10::SymbolicIntNode> a, py::object b) {
auto snb = toSymIntNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](std::shared_ptr<c10::SymbolicIntNode> a,
py::object b) -> std::shared_ptr<c10::SymbolicIntNode> {
auto snb = toSymIntNode(a, b);
return a->lt(snb);
})
.def(
"__bool__",
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->bool_(); })
.def(
"__int__",
[](std::shared_ptr<c10::SymbolicIntNode> a) { return a->int_(); })
.def("__str__", [](std::shared_ptr<c10::SymbolicIntNode> a) {
return a->str();
});
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
.def("__repr__", [](CompleteArgumentSpec& self) {

View File

@ -837,6 +837,10 @@ inline py::object toPyObject(IValue ivalue) {
#else
TORCH_CHECK(false, "RRef is only supported with the distributed package");
#endif
} else if (ivalue.isSymInt()) {
auto si = ivalue.toSymInt();
return si.is_symbolic() ? py::cast(si.toSymbolicIntNode())
: py::cast(si.expect_int());
} else {
AT_ERROR(
"Missing cases in 'toPyObject'! Can't convert ",

View File

@ -14,6 +14,10 @@ namespace lazy {
class TORCH_API SymbolicIntNode : public c10::SymbolicIntNode {
public:
SymbolicIntNode(NodePtr ptr) : node_(std::move(ptr)){};
std::shared_ptr<c10::SymbolicIntNode> add(
const std::shared_ptr<c10::SymbolicIntNode>& other) override {
TORCH_CHECK(false, "NYI");
}
NodePtr node_;
};

View File

@ -146,6 +146,10 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
}
c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const {
return sym_sizes_custom();
}
void LTCTensorImpl::setup_size_properties() {
size_t generation = tensor_->generation();
if (generation != generation_) {

View File

@ -44,6 +44,7 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
virtual c10::SymIntArrayRef sym_sizes_custom() const override;
virtual c10::SymIntArrayRef sym_sizes() const override;
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
const at::Storage& storage() const override {

View File

@ -647,15 +647,17 @@ bool is_float_or_complex_list(PyObject* obj) {
return true;
}
static bool is_int_list_(PyObject* obj, int broadcast_size) {
static bool is_int_list(PyObject* obj, int broadcast_size) {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
if (PySequence_Size(obj) == 0) {
return true;
}
auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
if (THPUtils_checkIndex(item.ptr())) {
return true;
}
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
return (
@ -667,22 +669,33 @@ static bool is_int_list_(PyObject* obj, int broadcast_size) {
return broadcast_size > 0 && THPUtils_checkLong(obj);
}
static bool is_int_list(PyObject* obj, int broadcast_size) {
return is_int_list_(obj, broadcast_size);
}
static bool is_int_or_symint(PyObject* obj) {
if (THPUtils_checkLong(obj)) {
return true;
}
// TODO: test if it's the Python binding for SymbolicIntNode
return false;
// THPUtils_checkIndex may call __index__ or __int__
// which may have side effects if obj is a symint node
// so we do `is_symint_node` check first
// TODO: maybe we should be using checkLong here?
return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj);
}
static bool is_int_or_symint_list(PyObject* obj, int broadcast_size) {
// TODO: add a check for SymbolicIntNode
return is_int_list_(obj, broadcast_size);
if (PyTuple_Check(obj) || PyList_Check(obj)) {
if (PySequence_Size(obj) == 0) {
return true;
}
auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
if (is_int_or_symint(item.ptr())) {
return true;
}
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
return (
jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
}
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
// int
return broadcast_size > 0 && THPUtils_checkLong(obj);
}
// argnum is needed for raising the TypeError, it's used in the error message.
@ -1214,7 +1227,9 @@ bool FunctionSignature::parse(
// if there is a single positional IntArrayRef argument, i.e. expand(..),
// view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
// expand((5,3))
if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
if (max_pos_args == 1 &&
(params[0].type_ == ParameterType::INT_LIST ||
params[0].type_ == ParameterType::SYM_INT_LIST)) {
allow_varargs_intlist = true;
}
@ -1272,7 +1287,7 @@ bool FunctionSignature::parse(
// should avoid having complex signatures that make use of it...
} else if (
allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
THPUtils_checkIndex(obj)) {
is_int_or_symint(obj)) {
// take all positional arguments as this parameter
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
dst[i++] = args;

View File

@ -39,6 +39,7 @@
// Scalar and Tensor, UNLESS they require grad (in which case
// they only bind to Tensor).
#include <pybind11/pytypes.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/Device.h>
@ -67,6 +68,7 @@
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <c10/core/SymbolicIntNode.h>
#include <array>
#include <cstddef>
#include <memory>
@ -470,9 +472,80 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
return intlistWithDefault(i, signature.params[i].default_intlist);
}
inline bool is_symint_node(py::handle obj) {
auto static tp_symn = py::type::of<c10::SymbolicIntNode>();
// TODO: switch this to `isinstance`
if (obj.get_type().equal(tp_symn)) {
TORCH_CHECK(
!jit::tracer::isTracing(), "JIT tracing of SymInts isn't supported!");
return true;
}
return false;
}
inline PyObject* toPyObject(c10::SymInt symint) {
if (symint.is_symbolic()) {
return py::cast(symint.toSymbolicIntNode()).release().ptr();
} else {
return THPUtils_packInt64(symint.data());
}
}
inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
auto intlist = intlistWithDefault(i, signature.params[i].default_intlist);
return c10::fmap(intlist, [](int64_t n) { return c10::SymInt(n); });
if (!args[i]) {
return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
return c10::SymInt(di);
});
}
const auto size1 = signature.params[i].size;
if (size1 > 0 && THPUtils_checkLong(args[i])) {
return std::vector<c10::SymInt>(
size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
}
if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) {
auto si = py::handle(args[i]).cast<c10::SymbolicIntNode*>()->toSymInt();
return std::vector<c10::SymInt>(size1, si);
}
PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
std::vector<c10::SymInt> res;
res.reserve(size2);
for (const auto idx : c10::irange(size2)) {
PyObject* obj =
tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
try {
if (is_symint_node(py::handle(obj))) {
res.push_back(
py::handle(obj).cast<c10::SymbolicIntNode*>()->toSymInt());
} else {
// Elements of torch.Size are tensors during tracing, and we need to
// record extra information before they are turned into an IntArrayRef
if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
auto& var = THPVariable_Unpack(obj);
jit::tracer::ArgumentStash::stashIntArrayRefElem(
signature.params[i].name, size2, idx, var);
res.push_back(var.item<int64_t>());
continue;
}
res.push_back(c10::SymInt(THPUtils_unpackIndex(obj)));
}
} catch (const std::exception& e) {
auto te = TypeError(
"%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
signature.name.c_str(),
signature.params[i].name.c_str(),
signature.params[i].type_name().c_str(),
Py_TYPE(obj)->tp_name,
idx + 1);
}
}
return res;
}
inline std::vector<int64_t> PythonArgs::intlistWithDefault(
@ -774,6 +847,9 @@ inline c10::SymInt PythonArgs::toSymInt(int i) {
jit::tracer::ArgumentStash::stashValue(
signature.params[i].name, idx, var, c10::IntType::get());
}
if (torch::is_symint_node(py::handle(args[i]))) {
return py::handle(args[i]).cast<c10::SymbolicIntNode*>()->toSymInt();
}
return c10::SymInt(THPUtils_unpackLong(args[i]));
}

View File

@ -280,6 +280,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._addmm_activation,
Tensor._nested_tensor_layer_norm,
Tensor.to_padded_tensor,
Tensor.sym_size
}