Files
pytorch/torch/csrc/jit/runtime/register_ops_utils.cpp
Yanan Cao c22bbb2124 [JIT] Add Type::repr_str to return human-readable str (#39544)
Summary:
Clearly expressing that a type is inferred by PyTorch, rather than explicitly annotated by the user, makes many error messages more user-friendly.

Currently Type has two string conversion methods. str() for IR printing and python_str() for serialization and error message generation. If we want to include more information in type printing while maintaining serialization/deserialization correctness, we need to split python_str() into annotation_str() and repr_str().

annotation_str() is solely responsible for serialization; it strictly matches the format of a Python type annotation. repr_str() is responsible for generating a human-readable string for error messages that includes information like "this type is inferred, not explicitly annotated".

Closes https://github.com/pytorch/pytorch/issues/39449
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39544

Differential Revision: D21978759

Pulled By: gmagogsfm

fbshipit-source-id: 733566f5a62e748b5ca4bb3c5943ebb6d5b664d0
2020-06-10 12:01:24 -07:00

503 lines
13 KiB
C++

#include <torch/csrc/jit/runtime/register_ops_utils.h>
namespace torch {
namespace jit {
// IValue specialization: the result list is constructed from the element
// type alone; no elements are added here.
template <>
c10::impl::GenericList make_result_list<IValue>(const TypePtr& elemType) {
  c10::impl::GenericList result(elemType);
  return result;
}
// Tensor specialization of list.index(x).
// Pops the element, then the list, and pushes the position of the first
// list entry for which `elem.eq(entry).is_nonzero()` holds. Raises via
// AT_ERROR when no entry matches.
template <>
int listIndex<at::Tensor>(Stack& stack) {
  at::Tensor needle = pop(stack).to<at::Tensor>();
  c10::List<at::Tensor> haystack = pop(stack).to<c10::List<at::Tensor>>();

  const auto matches = [&needle](const at::Tensor& candidate) {
    return needle.eq(candidate).is_nonzero();
  };
  const auto hit = std::find_if(haystack.begin(), haystack.end(), matches);

  if (hit == haystack.end()) {
    AT_ERROR("'", needle, "' is not in list");
  }
  push(stack, static_cast<int64_t>(std::distance(haystack.begin(), hit)));
  return 0;
}
// Tensor specialization of list.count(x).
// Pops the element, then the list, and pushes how many list entries satisfy
// `elem.eq(entry).is_nonzero()`.
template <>
int listCount<at::Tensor>(Stack& stack) {
  at::Tensor target = pop(stack).to<at::Tensor>();
  c10::List<at::Tensor> items = pop(stack).to<c10::List<at::Tensor>>();

  const auto equals_target = [&target](const at::Tensor& item) {
    return target.eq(item).is_nonzero();
  };
  const int64_t occurrences =
      std::count_if(items.begin(), items.end(), equals_target);

  push(stack, occurrences);
  return 0;
}
// Tensor specialization of list == list.
// The list popped first is passed as the second argument of
// tensor_list_equal; the comparison result is pushed back.
template <>
int listEq<at::Tensor>(Stack& stack) {
  c10::List<at::Tensor> rhs = pop(stack).to<c10::List<at::Tensor>>();
  c10::List<at::Tensor> lhs = pop(stack).to<c10::List<at::Tensor>>();

  push(stack, tensor_list_equal(lhs, rhs));
  return 0;
}
// Tensor specialization of list != list.
// Pushes the negation of tensor_list_equal for the two popped lists (the
// list popped first is passed as the second argument).
template <>
int listNe<at::Tensor>(Stack& stack) {
  c10::List<at::Tensor> rhs = pop(stack).to<c10::List<at::Tensor>>();
  c10::List<at::Tensor> lhs = pop(stack).to<c10::List<at::Tensor>>();

  const bool differ = !tensor_list_equal(lhs, rhs);
  push(stack, differ);
  return 0;
}
// Tensor specialization of list.sort(reverse).
// Pops the `reverse` flag, then the list, and sorts the list in place.
// Nothing is pushed back onto the stack.
template <>
int listSort<at::Tensor>(Stack& stack) {
  const bool reverse = pop(stack).toBool();
  c10::List<at::Tensor> list = pop(stack).toTensorList();
  std::sort(
      list.begin(),
      list.end(),
      [reverse](const at::Tensor& a, const at::Tensor& b) -> bool {
        // Same underlying tensor: never "less than" itself. This also
        // skips a tensor comparison for the trivial case.
        if (a.getIntrusivePtr() == b.getIntrusivePtr()) {
          return false;
        }
        // std::sort requires a strict weak ordering, i.e. comp(x, y) and
        // comp(y, x) must both be false for equivalent elements. Negating
        // `a < b` (as `lt ^ reverse` did) yields `a >= b`, which is true in
        // both directions for equal values held by distinct tensors.
        // Swapping the operands gives a proper strict "greater than".
        return reverse ? b.lt(a).is_nonzero() : a.lt(b).is_nonzero();
      });
  return 0;
}
// Tensor specialization of sorted(list).
// Pops the list, sorts a copy of it with `a.lt(b).is_nonzero()` as the
// ordering, and pushes the sorted copy. The popped list is not modified.
template <>
int listCopyAndSort<at::Tensor>(Stack& stack) {
  c10::List<at::Tensor> source = pop(stack).toTensorList();
  auto sorted = source.copy();

  const auto ascending = [](const at::Tensor& lhs, const at::Tensor& rhs) {
    return lhs.lt(rhs).is_nonzero();
  };
  std::sort(sorted.begin(), sorted.end(), ascending);

  push(stack, sorted);
  return 0;
}
// Tensor specialization of list.remove(x).
// Pops the element, then the list, and erases the first list entry for which
// `elem.eq(entry).is_nonzero()` holds. Raises via AT_ERROR when there is no
// such entry. Nothing is pushed back onto the stack.
template <>
int listRemove<at::Tensor>(Stack& stack) {
  at::Tensor target = pop(stack).to<at::Tensor>();
  c10::List<at::Tensor> items = pop(stack).to<c10::List<at::Tensor>>();

  const auto equals_target = [&target](const at::Tensor& item) {
    return target.eq(item).is_nonzero();
  };
  auto hit = std::find_if(items.begin(), items.end(), equals_target);

  if (hit == items.end()) {
    AT_ERROR("list.remove(x): x not in list");
  }
  items.erase(hit);
  return 0;
}
// Validates that tensor `t` may be used where a scalar argument is expected.
// Throws std::runtime_error when:
//   * t requires grad,
//   * t.sizes() is not empty (i.e. t is not 0-dimensional), or
//   * toInt is set and t's scalar type is not integral (bool excluded).
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt) {
  if (t.requires_grad()) {
    throw std::runtime_error(
        "Cannot input a tensor that requires grad as a scalar argument");
  }

  const auto rank = t.sizes().size();
  if (rank != 0) {
    throw std::runtime_error(
        "Cannot input a tensor of dimension other than 0 as a scalar argument");
  }

  // The dtype restriction only applies to integral targets.
  if (!toInt) {
    return;
  }
  const auto dtype = t.scalar_type();
  if (isIntegralType(dtype, /*includeBool=*/false)) {
    return;
  }

  std::stringstream message;
  message << "Cannot input a tensor of type " << dtype
          << " as an integral argument";
  throw std::runtime_error(message.str());
}
// Recursively converts tensor memory starting at `data` into an IValue: a
// (possibly nested) list when `ty` is a ListType, or a single scalar when
// `ty` is int, float or bool.
//
//   data            - address of the first element handled by this call
//   cur_dim         - index into `sizes` / `strides` for this invocation
//   num_tensor_dims - forwarded unchanged to recursive calls; not otherwise
//                     read in this function
//   ty              - requested output type for this level
//   scalar_ty       - tensor dtype; only consulted to choose between reading
//                     a float and a double for FloatType output
//   sizes, strides  - indexed by cur_dim to get element count and step
//   element_size    - multiplied by strides[cur_dim] to advance `data`
//
// Raises via TORCH_CHECK when the leaf type is not int, float or bool.
IValue tensorToListRecursive(
    char* data,
    int64_t cur_dim,
    int64_t num_tensor_dims,
    TypePtr ty,
    at::ScalarType scalar_ty,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    size_t element_size) {
  // If ty is a ListType, get the element type.
  if (auto list_type = ty->cast<ListType>()) {
    ty = list_type->getElementType();
  } else {
    // If the output type is a scalar, read one scalar of the right type
    // from `data` and return it.
    if (ty == IntType::get()) {
      int64_t scalar = *reinterpret_cast<int64_t*>(data);
      return IValue(scalar);
    } else if (ty == FloatType::get()) {
      TORCH_INTERNAL_ASSERT(
          scalar_ty == at::ScalarType::Float ||
              scalar_ty == at::ScalarType::Double,
          "Unexpected scalar type for Tensor");
      double scalar = scalar_ty == at::ScalarType::Float
          ? *reinterpret_cast<float*>(data)
          : *reinterpret_cast<double*>(data);
      return IValue(scalar);
    } else if (ty == BoolType::get()) {
      bool scalar = *reinterpret_cast<bool*>(data);
      return IValue(scalar);
    } else {
      TORCH_CHECK(
          false,
          ty->repr_str(),
          " is not one of the supported types for tolist: int, float, bool");
    }
  }

  // Make the result list consisting of elements of type ty. Since this
  // invocation is processing dimension cur_dim, there will be sizes[cur_dim]
  // output elements.
  auto result = c10::impl::GenericList(ty);
  result.reserve(sizes[cur_dim]);

  // Since ty was a list type, tensorToListRecursive needs to be called
  // recursively on each slice of the tensor in the current dimension.
  for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
    auto inner_result = tensorToListRecursive(
        data,
        cur_dim + 1,
        num_tensor_dims,
        ty,
        scalar_ty,
        sizes,
        strides,
        element_size);

    if (inner_result.isList()) {
      result.emplace_back(inner_result.toList());
    } else if (inner_result.isDouble()) {
      result.emplace_back(inner_result.toDouble());
    } else if (inner_result.isInt()) {
      result.emplace_back(inner_result.toInt());
    } else if (inner_result.isBool()) {
      result.emplace_back(inner_result.toBool());
    } else {
      // The condition must be an explicit `false`: a bare string literal is
      // a non-null pointer and would make the assertion always pass.
      TORCH_INTERNAL_ASSERT(
          false, "Unknown return type for tensorToListRecursive");
    }

    // Step to the next slice along the current dimension.
    data += strides[cur_dim] * element_size;
  }

  return result;
}
// Throws c10::Error unless `a` is a finite value that fits in int64_t.
void checkDoubleInRange(double a) {
  // int64_t's maximum (2^63 - 1) is not representable as a double and rounds
  // up to 2^63, so the upper bound must be exclusive (>=): 2^63 itself does
  // not fit in int64_t. The minimum (-2^63) is exactly representable, so the
  // lower bound stays inclusive.
  if (std::isnan(a) || std::isinf(a) ||
      a >= static_cast<double>(std::numeric_limits<int64_t>::max()) ||
      a < static_cast<double>(std::numeric_limits<int64_t>::min())) {
    throw c10::Error(
        "Cannot convert float " + c10::to_string(a) + " to integer", "");
  }
}
// Recursive helper used by loop() for factorial().
//   * m <= n + 1  -> n
//   * m == n + 2  -> n * m
//   * otherwise   -> split at k, the midpoint (n + m) / 2 lowered by one when
//                    it is even, and multiply partProduct(n, k) by
//                    partProduct(k + 2, m).
int64_t partProduct(int n, int m) {
  if (m <= (n + 1)) {
    return static_cast<int64_t>(n);
  }
  if (m == (n + 2)) {
    return static_cast<int64_t>(n) * m;
  }

  const int midpoint = (n + m) / 2;
  const int k = (midpoint & 1) ? midpoint : midpoint - 1;

  const int64_t lower = partProduct(n, k);
  const int64_t upper = partProduct(k + 2, m);
  return lower * upper;
}
// Recursive accumulator used by factorial().
// For n > 2 it first recurses on n / 2, then multiplies `p` by a partProduct
// whose bounds are derived from n, and finally multiplies `r` by the updated
// `p`. For n <= 2 neither `p` nor `r` is touched.
void loop(int n, int64_t& p, int64_t& r) {
  if (n > 2) {
    const int half = n / 2;
    loop(half, p, r);

    const int lower_bound = half + 1 + (half & 1);
    const int upper_bound = n - 1 + (n & 1);
    p *= partProduct(lower_bound, upper_bound);
    r *= p;
  }
}
// Returns v minus the value produced by the bit-counting reduction below
// (pairs -> nibbles -> bytes -> folded bytes, keeping the low 8 bits).
// For non-negative v that value is the number of set bits in v.
int nminussumofbits(int v) {
  long bits = static_cast<long>(v);

  // Collapse each 2-bit field to the count of its set bits.
  bits -= (0xaaaaaaaa & bits) >> 1; // NOLINT
  // Sum adjacent 2-bit counts into 4-bit fields.
  bits = (bits & 0x33333333) + ((bits >> 2) & 0x33333333); // NOLINT
  // Sum adjacent 4-bit counts into bytes.
  bits = (bits + (bits >> 4)) & 0x0f0f0f0f; // NOLINT
  // Fold the per-byte counts together.
  bits += bits >> 8; // NOLINT
  bits += bits >> 16; // NOLINT

  const int bit_total = static_cast<int>(bits & 0xff); // NOLINT
  return v - bit_total;
}
// Computes n! for n >= 0; throws std::runtime_error for negative n.
// loop() accumulates into `result`, which is then shifted left by
// nminussumofbits(n).
// NOTE: no range check is performed here; n! no longer fits in int64_t for
// n > 20.
int64_t factorial(int n) {
  if (n < 0) {
    throw std::runtime_error("factorial() not defined for negative values");
  }

  int64_t partial = 1;
  int64_t result = 1;
  loop(n, partial, result);

  const int shift = nminussumofbits(n);
  return result << shift;
}
// Scales x by the radToDeg constant (defined outside this block).
double degrees(double x) {
  const double in_degrees = x * radToDeg;
  return in_degrees;
}
// Scales x by the degToRad constant (defined outside this block).
double radians(double x) {
  const double in_radians = x * degToRad;
  return in_radians;
}
// Maps a negative index to its offset from the end of a list of size
// list_size; non-negative indexes are returned unchanged. No bounds check is
// done: the result may still be negative or >= list_size.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx < 0 ? list_size + idx : idx;
}
// list.append(el): pops the element, then the list, appends the element and
// pushes the list back onto the stack.
int listAppend(Stack& stack) {
  IValue new_item = pop(stack).to<IValue>();
  c10::List<IValue> target = pop(stack).to<c10::List<IValue>>();

  target.push_back(std::move(new_item));

  push(stack, std::move(target));
  return 0;
}
// list.reverse(): pops the list and reverses it in place. Nothing is pushed
// back onto the stack.
int listReverse(Stack& stack) {
  c10::List<IValue> target = pop(stack).to<c10::List<IValue>>();
  auto first = target.begin();
  auto last = target.end();
  std::reverse(first, last);
  return 0;
}
// Shared implementation of pop-style list operations.
// Pops the index, then the list. If the list is empty, raises via AT_ERROR
// with `empty_message`. Otherwise pushes the element obtained from
// getItem(list, idx) and erases the entry at the normalized index.
int listPopImpl(Stack& stack, const char* empty_message) {
  const int64_t raw_idx = pop(stack).to<int64_t>();
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();

  const int64_t length = list.size();
  if (length == 0) {
    AT_ERROR(empty_message);
  }

  // getItem is handed the raw (possibly negative) index.
  // NOTE(review): presumably it normalizes and bounds-checks it — confirm.
  push(stack, getItem(list, raw_idx));

  const auto victim = list.begin() + normalizeIndex(raw_idx, length);
  list.erase(victim);
  return 0;
}
// list.pop(idx): delegates to listPopImpl with the empty-list message.
int listPop(Stack& stack) {
  const char* const empty_message = "pop from empty list";
  return listPopImpl(stack, empty_message);
}
// list.clear(): pops the list and removes all of its elements. Nothing is
// pushed back onto the stack.
int listClear(Stack& stack) {
  c10::List<IValue> target = pop(stack).to<c10::List<IValue>>();
  target.clear();
  return 0;
}
// del list[idx]: reuses listPopImpl to remove the entry, then discards the
// removed element that listPopImpl pushed onto the stack.
int listDelete(Stack& stack) {
  const char* const empty_message = "pop index out of range";
  listPopImpl(stack, empty_message);
  pop(stack);
  return 0;
}
// list.insert(idx, elem): pops the element, the index and the list (in that
// order) and inserts the element. A normalized index below zero inserts at
// the front; one at or past the end appends. Nothing is pushed back onto the
// stack.
int listInsert(Stack& stack) {
  IValue elem = pop(stack).to<IValue>();
  const int64_t raw_idx = pop(stack).to<int64_t>();
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();

  const int64_t length = list.size();
  const int64_t position = normalizeIndex(raw_idx, length);

  if (position < 0) {
    list.insert(list.begin(), elem);
  } else if (position >= length) {
    list.push_back(elem);
  } else {
    list.insert(list.begin() + position, elem);
  }
  return 0;
}
// a.extend(b): pops b, then a, and appends every element of b to a.
// Nothing is pushed back onto the stack.
int listExtend(Stack& stack) {
  c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
  c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();

  // Snapshot b's length before appending. List handles refer to shared
  // storage (other ops in this file mutate a popped handle in place), so for
  // `x.extend(x)` every push_back on `a` would also grow `b`; re-reading
  // b.size() in the loop condition would then never terminate.
  const size_t b_size = b.size();

  a.reserve(a.size() + b_size);
  for (size_t i = 0; i < b_size; ++i) {
    a.push_back(b.get(i));
  }
  return 0;
}
// list.copy(): pops the list and pushes the result of its copy() method.
int listCopy(Stack& stack) {
  c10::List<IValue> original = pop(stack).to<c10::List<IValue>>();
  auto duplicate = original.copy();
  push(stack, std::move(duplicate));
  return 0;
}
// list[idx]: pops the index, then the list, and pushes the element returned
// by getItem(list, idx).
int listSelect(Stack& stack) {
  const int64_t position = pop(stack).to<int64_t>();
  c10::List<IValue> source = pop(stack).to<c10::List<IValue>>();
  push(stack, getItem(source, position));
  return 0;
}
// len(list): pops the list and pushes its size as an int64_t.
int listLen(Stack& stack) {
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
  push(stack, static_cast<int64_t>(list.size()));
  return 0;
}
// list(list): pops the list and pushes the result of its copy() method.
int listList(Stack& stack) {
  c10::List<IValue> original = pop(stack).to<c10::List<IValue>>();
  auto duplicate = original.copy();
  push(stack, std::move(duplicate));
  return 0;
}
// a + b: pops b, then a, and pushes a list holding a's elements followed by
// b's.
int listAdd(Stack& stack) {
  c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
  c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();

  // When this function holds the only reference to `a`, its storage can be
  // reused for the result; otherwise work on a copy so other holders of `a`
  // do not observe the append. The result is initialized directly from one
  // of the two, which avoids first allocating an empty list that would be
  // overwritten immediately.
  c10::List<IValue> ret = (a.use_count() == 1) ? std::move(a) : a.copy();

  ret.append(std::move(b));

  push(stack, std::move(ret));
  return 0;
}
// a += b: pops b, then a, appends b's elements to a and pushes a back onto
// the stack.
int listInplaceAdd(Stack& stack) {
  c10::List<IValue> addend = pop(stack).to<c10::List<IValue>>();
  c10::List<IValue> target = pop(stack).to<c10::List<IValue>>();

  target.append(std::move(addend));

  push(stack, std::move(target));
  return 0;
}
// list *= n: pops n, then the list, and repeats the list's contents in
// place. n <= 0 clears the list and n == 1 leaves it unchanged. The list is
// pushed back onto the stack.
int listMulIntLeftInPlace(Stack& stack) {
  const int64_t n = pop(stack).to<int64_t>();
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();

  if (n <= 0) {
    list.clear();
  } else {
    // Only the elements present before any appending are repeated, so the
    // length is captured once up front.
    const size_t original_len = list.size();
    for (int64_t rep = 1; rep < n; ++rep) {
      for (size_t pos = 0; pos < original_len; ++pos) {
        list.push_back(list.get(pos));
      }
    }
  }

  push(stack, std::move(list));
  return 0;
}
// list * n: pops n, then the list, and pushes a new list holding the
// original contents repeated n times. n <= 0 yields an empty list.
int listMulIntLeft(Stack& stack) {
  const int64_t n = pop(stack).to<int64_t>();
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();

  c10::List<IValue> ret = make_result_list<IValue>(list.elementType());

  // Only reserve and copy for a positive count. Multiplying the unsigned
  // list size by a negative n would wrap around to an enormous value and
  // make reserve() fail instead of producing an empty result.
  if (n > 0) {
    ret.reserve(list.size() * static_cast<size_t>(n));
    for (int64_t i = 0; i < n; ++i) {
      for (IValue e : list) {
        ret.push_back(std::move(e));
      }
    }
  }

  push(stack, std::move(ret));
  return 0;
}
// n * list: same result as listMulIntLeft, but the operands are popped in
// the opposite order (list first, then n). n <= 0 yields an empty list.
int listMulIntRight(Stack& stack) {
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();
  const int64_t n = pop(stack).to<int64_t>();

  c10::List<IValue> ret = make_result_list<IValue>(list.elementType());

  // Only reserve and copy for a positive count. Multiplying the unsigned
  // list size by a negative n would wrap around to an enormous value and
  // make reserve() fail instead of producing an empty result.
  if (n > 0) {
    ret.reserve(list.size() * static_cast<size_t>(n));
    for (int64_t i = 0; i < n; ++i) {
      for (IValue e : list) {
        ret.push_back(std::move(e));
      }
    }
  }

  push(stack, std::move(ret));
  return 0;
}
// list[start:end:step]: pops step, end, start and the list (in that order)
// and pushes a new list with the selected elements.
//
// start is clamped below at 0 and end is clamped above at the list size,
// after negative values are normalized relative to the end of the list. An
// empty range (end <= start) yields an empty list. A zero step on a
// non-empty range raises via TORCH_CHECK.
//
// NOTE(review): a negative step gets no special handling; the loop below
// only walks forward from start towards end — confirm callers never pass
// one.
int listSlice(Stack& stack) {
  const int64_t step = pop(stack).to<int64_t>();
  const int64_t end = pop(stack).to<int64_t>();
  const int64_t start = pop(stack).to<int64_t>();
  c10::List<IValue> list = pop(stack).to<c10::List<IValue>>();

  const int64_t list_size = list.size();

  // clamp start and end to the bounds of the list
  const auto normalized_start =
      std::max((int64_t)0, normalizeIndex(start, list_size));
  const auto normalized_end =
      std::min(list_size, normalizeIndex(end, list_size));

  c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
  if (normalized_end <= normalized_start) {
    // early exit if the slice is trivially empty
    push(stack, std::move(sliced_list));
    return 0;
  }

  // With a zero step the index below would never advance and the loop would
  // append the same element until memory is exhausted.
  TORCH_CHECK(step != 0, "slice step cannot be zero");

  sliced_list.reserve(normalized_end - normalized_start);

  for (auto i = normalized_start; i < normalized_end; i += step) {
    sliced_list.push_back(list.get(i));
  }

  push(stack, std::move(sliced_list));
  return 0;
}
// list[idx] = value: pops the value, the index and the list (in that order),
// stores the value through setItem and pushes the list back onto the stack.
int listSetItem(Stack& stack) {
  IValue new_value = pop(stack).to<IValue>();
  const int64_t position = pop(stack).to<int64_t>();
  c10::List<IValue> target = pop(stack).to<c10::List<IValue>>();

  setItem(target, position, std::move(new_value));

  push(stack, std::move(target));
  return 0;
}
} // namespace jit
} // namespace torch