[tensorexpr] check for index out of bounds in ir_eval (#68858)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68858

When executing with ir_eval (SimpleIREvaluator), check that load and store indices are within the buffer's bounds, and throw malformed_input when they are not.
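For illustration, a minimal sketch of the new behavior (my own example, modeled on the TEST(Expr, OutOfBounds) case added below; it assumes the tensorexpr headers and namespace): a store whose loop runs past the buffer's extent now throws malformed_input under ir_eval instead of silently writing out of bounds.

#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
using namespace torch::jit::tensorexpr;

void oob_store_now_throws() {
  ExprHandle N(10), start(0), stop(15);
  VarHandle i("i", kInt);
  BufHandle X("X", {N}, kInt);
  // X has 10 elements, but the loop stores to X[0] .. X[14].
  auto stmt = For::make(i, start, stop, Store::make(X, {i}, i));
  std::vector<int> data(20, 0); // over-allocated on purpose
  try {
    SimpleIREvaluator(stmt, {X})(data);
  } catch (const malformed_input& e) {
    // Reached once i hits 10:
    // "Index out of bounds in check_bounds. Index: 10; bounds: [0, 10)."
  }
}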

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D32657881

Pulled By: davidberard98

fbshipit-source-id: 62dd0f85bb182b34e9c9f795ff761081290f6922
Author: David Berard
Date: 2021-12-16 09:25:35 -08:00
Committed by: Facebook GitHub Bot
Commit: 8c7f4a0d0b (parent: 76d282d447)
4 changed files with 179 additions and 2 deletions

@@ -556,6 +556,65 @@ TEST(Expr, DynamicShapeAdd) {
  testWithSize(37);
}

TEST(Expr, OutOfBounds) {
  ExprHandle N(10);
  ExprHandle start(0);
  ExprHandle stop(15);
  VarHandle i("i", kInt);
  BufHandle X("X", {N}, kInt);
  auto body = Store::make(X, {i}, i);
  auto stmt = For::make(i, start, stop, body);
  PaddedBuffer<int> data(20);
  EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
}

TEST(Expr, OutOfBounds2d) {
  std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
  for (auto sizes : size_options) {
    ExprHandle N(sizes.first);
    ExprHandle M(sizes.second);
    ExprHandle start(0);
    ExprHandle stopInner(15);
    ExprHandle stopOuter(15);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);
    BufHandle X("X", {N, M}, kInt);
    auto body = Store::make(X, {i, j}, i);
    auto inner = For::make(j, start, stopInner, body);
    auto stmt = For::make(i, start, stopOuter, inner);
    PaddedBuffer<int> data(400);
    EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
  }
}

TEST(Expr, OutOfBounds2dFlattenedIndex) {
  ExprHandle buf_size(149);
  ExprHandle start(0);
  ExprHandle stopInner(15);
  ExprHandle stopOuter(10);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);
  BufHandle X("X", {buf_size}, kInt);
  auto idx = Add::make(Mul::make(i, stopInner), j);
  auto body = Store::make(X, {idx}, i);
  auto inner = For::make(j, start, stopInner, body);
  auto stmt = For::make(i, start, stopOuter, inner);
  PaddedBuffer<int> data(400);
  EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
}

void testCond01() {
  const int N = 16;
  PaddedBuffer<float> a_v(N);

@@ -1677,5 +1677,34 @@ TEST_F(Kernel, DISABLED_FlattenVectorize) {
#endif
}

TEST_F(Kernel, Strided1dWithinBounds) {
  auto ir = R"IR(
    graph(%0 : Float(3, strides=[1], device=cpu),
          %1 : Float(3, strides=[2], device=cpu)):
      %2 : int = prim::Constant[value=1]()
      %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2)
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

  auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2)});
  auto expect = a + b;

  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);

  auto output = stack[0].toTensor();
  for (size_t i = 0; i < 3; ++i) {
    CHECK_EQ(((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]);
  }
}

} // namespace jit
} // namespace torch
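A note on why the stride-2 input above stays within bounds: when an access uses a single flattened index but the buffer carries explicit strides, the new check_bounds in the evaluator (see the eval hunk below) takes the bound to be one past the largest reachable flat offset, i.e. 1 + sum_i (dim_i - 1) * stride_i. A standalone sketch of that computation (a hypothetical helper for illustration, not part of this diff):

#include <cstdint>
#include <vector>

// One past the largest flat offset reachable in a strided buffer.
int64_t flat_index_bound(const std::vector<int64_t>& dims,
                         const std::vector<int64_t>& strides) {
  int64_t bound = 1;
  for (size_t d = 0; d < dims.size(); ++d) {
    bound += (dims[d] - 1) * strides[d];
  }
  return bound;
}

For the strided input in Strided1dWithinBounds (dims {3}, strides {2}) this gives 1 + 2*2 = 5, so the flat offsets 0, 2, and 4 that the kernel actually reads are all accepted.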

@@ -69,8 +69,8 @@ class TestTensorExprPyBind(JitTestCase):
    def test_dynamic_shape(self):
        dN = te.VarHandle(torch.int32)
-       A = te.BufHandle(torch.float64)
-       B = te.BufHandle(torch.float64)
+       A = te.BufHandle([dN], torch.float64)
+       B = te.BufHandle([dN], torch.float64)

        def compute(i):
            return A.load(i) - B.load(i)
@@ -92,6 +92,32 @@ class TestTensorExprPyBind(JitTestCase):
        test_with_shape(8)
        test_with_shape(31)

    def test_dynamic_shape_2d(self):
        dN = te.VarHandle(torch.int32)
        dM = te.VarHandle(torch.int32)
        A = te.BufHandle([dN, dM], torch.float64)
        B = te.BufHandle([dN, dM], torch.float64)

        def compute(i, j):
            return A.load([i, j]) - B.load([i, j])

        C = te.Compute("C", [dN, dM], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])

        def test_with_shape(n, m):
            tA = torch.randn(n, m, dtype=torch.double)
            tB = torch.randn(n, m, dtype=torch.double)
            tC = torch.empty(n, m, dtype=torch.double)
            cg.call([tA, tB, tC, n, m])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(2, 4)
        test_with_shape(5, 3)

    def test_dtype_error(self):
        te.BufHandle("a", [1], torch.float32)  # ok
        self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))

@@ -672,6 +672,65 @@ class SimpleIREvaluatorImpl : public IRVisitor {
    return {};
  }

  void check_bounds_throw(int64_t idx, int64_t bound, const BufPtr& buf) {
    std::stringstream ss;
    ss << "Index out of bounds in check_bounds. Index: " << idx
       << "; bounds: [0, " << bound << ").";
    throw malformed_input(ss.str(), buf);
  }

  void check_bounds(const BufPtr& buf, const std::vector<ExprPtr>& indices) {
    const std::vector<ExprPtr>& dims = buf->dims();
    if (dims.size() != indices.size()) {
      // indices are flattened, but not buffer
      if (indices.size() == 1) {
        if (dims.size() != buf->strides().size()) {
          throw malformed_input(
              "Number of dimensions did not match number of strides", buf);
        }
        size_t buf_size = 1;
        if (dims.size() > 0) {
          ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1));
          ExprHandle negative_one = ExprHandle(immLike(dims[0], -1));
          for (const auto& i : c10::irange(dims.size())) {
            buf_size_expr = buf_size_expr +
                ((negative_one + ExprHandle(dims[i])) *
                 ExprHandle(buf->strides()[i]));
          }
          buf_size_expr.node()->accept(this);
          buf_size = value().intValue();
        }
        indices[0]->accept(this);
        const auto& index_values = indexVec(value());
        for (auto& j : index_values) {
          if (j < 0 || j >= buf_size) {
            check_bounds_throw(j, buf_size, buf);
          }
        }
        return;
      }
      throw malformed_input(
          "dimensions and indices mismatch in check_bounds. Buf has " +
              std::to_string(dims.size()) + " dimensions and indices has " +
              std::to_string(indices.size()) + " dimensions.",
          buf);
    }
    for (const auto& i : c10::irange(dims.size())) {
      auto opt_dim = intValue(dims[i]);
      if (!opt_dim) {
        continue;
      }
      auto dim_bound = *opt_dim;
      indices[i]->accept(this);
      const auto& ithDimIndices = indexVec(value());
      for (auto& j : ithDimIndices) {
        if (j < 0 || j >= dim_bound) {
          check_bounds_throw(j, dim_bound, buf);
        }
      }
    }
  }

  TORCH_API void visit(LoadPtr v) override {
    auto iter = buffer_mapping_.find(v->buf());
    if (iter == buffer_mapping_.end()) {
@@ -679,6 +738,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
    }
    void* ptr = iter->second;

    check_bounds(v->buf(), v->indices());

    ExprPtr flat_idx =
        flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
    flat_idx->accept(this);
@@ -722,6 +783,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
    void* ptr = iter->second;

    check_bounds(v->buf(), v->indices());

    ExprPtr flat_idx =
        flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
    flat_idx->accept(this);
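As the two hunks above show, both visit(Load) and visit(Store) now call check_bounds before flattening the index, so out-of-range reads are caught as well as writes. A rough sketch of an out-of-bounds load surfacing the same error (my own example, not one of the committed tests; assumes the tensorexpr headers and namespace):

#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
using namespace torch::jit::tensorexpr;

void oob_load_now_throws() {
  ExprHandle n(8), m(12), start(0);
  VarHandle i("i", kInt);
  BufHandle Y("Y", {n}, kInt);
  BufHandle Z("Z", {m}, kInt);
  // Z[i] = Y[i]; once i reaches 8 the load from Y is out of bounds.
  auto stmt = For::make(i, start, m, Store::make(Z, {i}, Load::make(Y, {i})));
  std::vector<int> y(8, 1), z(12, 0);
  try {
    SimpleIREvaluator(stmt, {Y, Z})(y, z);
  } catch (const malformed_input& e) {
    // check_bounds in visit(LoadPtr) rejects index 8 against bounds [0, 8).
  }
}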