mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[tensorexpr] check for index out of bounds in ir_eval (#68858)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68858 when executing with ir_eval, check for index out of bounds. Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D32657881 Pulled By: davidberard98 fbshipit-source-id: 62dd0f85bb182b34e9c9f795ff761081290f6922
This commit is contained in:
committed by
Facebook GitHub Bot
parent
76d282d447
commit
8c7f4a0d0b
@ -556,6 +556,65 @@ TEST(Expr, DynamicShapeAdd) {
|
||||
testWithSize(37);
|
||||
}
|
||||
|
||||
TEST(Expr, OutOfBounds) {
|
||||
ExprHandle N(10);
|
||||
ExprHandle start(0);
|
||||
ExprHandle stop(15);
|
||||
VarHandle i("i", kInt);
|
||||
|
||||
BufHandle X("X", {N}, kInt);
|
||||
|
||||
auto body = Store::make(X, {i}, i);
|
||||
auto stmt = For::make(i, start, stop, body);
|
||||
|
||||
PaddedBuffer<int> data(20);
|
||||
|
||||
EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
|
||||
}
|
||||
|
||||
TEST(Expr, OutOfBounds2d) {
|
||||
std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
|
||||
for (auto sizes : size_options) {
|
||||
ExprHandle N(sizes.first);
|
||||
ExprHandle M(sizes.second);
|
||||
ExprHandle start(0);
|
||||
ExprHandle stopInner(15);
|
||||
ExprHandle stopOuter(15);
|
||||
VarHandle i("i", kInt);
|
||||
VarHandle j("j", kInt);
|
||||
|
||||
BufHandle X("X", {N, M}, kInt);
|
||||
|
||||
auto body = Store::make(X, {i, j}, i);
|
||||
auto inner = For::make(j, start, stopInner, body);
|
||||
auto stmt = For::make(i, start, stopOuter, inner);
|
||||
|
||||
PaddedBuffer<int> data(400);
|
||||
|
||||
EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Expr, OutOfBounds2dFlattenedIndex) {
|
||||
ExprHandle buf_size(149);
|
||||
ExprHandle start(0);
|
||||
ExprHandle stopInner(15);
|
||||
ExprHandle stopOuter(10);
|
||||
VarHandle i("i", kInt);
|
||||
VarHandle j("j", kInt);
|
||||
|
||||
BufHandle X("X", {buf_size}, kInt);
|
||||
|
||||
auto idx = Add::make(Mul::make(i, stopInner), j);
|
||||
auto body = Store::make(X, {idx}, i);
|
||||
auto inner = For::make(j, start, stopInner, body);
|
||||
auto stmt = For::make(i, start, stopOuter, inner);
|
||||
|
||||
PaddedBuffer<int> data(400);
|
||||
|
||||
EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
|
||||
}
|
||||
|
||||
void testCond01() {
|
||||
const int N = 16;
|
||||
PaddedBuffer<float> a_v(N);
|
||||
|
@ -1677,5 +1677,34 @@ TEST_F(Kernel, DISABLED_FlattenVectorize) {
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_F(Kernel, Strided1dWithinBounds) {
|
||||
auto ir = R"IR(
|
||||
graph(%0 : Float(3, strides=[1], device=cpu),
|
||||
%1 : Float(3, strides=[2], device=cpu)):
|
||||
%2 : int = prim::Constant[value=1]()
|
||||
%3 : Float(3, strides=[1]) = aten::add(%0, %1, %2)
|
||||
return (%3))IR";
|
||||
auto graph = std::make_shared<Graph>();
|
||||
std::unordered_map<std::string, Value*> vmap;
|
||||
parseIR(ir, graph.get(), vmap);
|
||||
TensorExprKernel k(graph);
|
||||
|
||||
auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat))
|
||||
.index({Slice(None, None, 2)});
|
||||
auto expect = a + b;
|
||||
|
||||
std::vector<at::Tensor> inputs = {a, b};
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(inputs);
|
||||
k.run(stack);
|
||||
|
||||
auto output = stack[0].toTensor();
|
||||
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
CHECK_EQ(((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -69,8 +69,8 @@ class TestTensorExprPyBind(JitTestCase):
|
||||
|
||||
def test_dynamic_shape(self):
|
||||
dN = te.VarHandle(torch.int32)
|
||||
A = te.BufHandle(torch.float64)
|
||||
B = te.BufHandle(torch.float64)
|
||||
A = te.BufHandle([dN], torch.float64)
|
||||
B = te.BufHandle([dN], torch.float64)
|
||||
|
||||
def compute(i):
|
||||
return A.load(i) - B.load(i)
|
||||
@ -92,6 +92,32 @@ class TestTensorExprPyBind(JitTestCase):
|
||||
test_with_shape(8)
|
||||
test_with_shape(31)
|
||||
|
||||
def test_dynamic_shape_2d(self):
|
||||
dN = te.VarHandle(torch.int32)
|
||||
dM = te.VarHandle(torch.int32)
|
||||
A = te.BufHandle([dN, dM], torch.float64)
|
||||
B = te.BufHandle([dN, dM], torch.float64)
|
||||
|
||||
def compute(i, j):
|
||||
return A.load([i, j]) - B.load([i, j])
|
||||
|
||||
C = te.Compute("C", [dN, dM], compute)
|
||||
|
||||
loopnest = te.LoopNest([C])
|
||||
loopnest.prepare_for_codegen()
|
||||
|
||||
cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])
|
||||
|
||||
def test_with_shape(n, m):
|
||||
tA = torch.randn(n, m, dtype=torch.double)
|
||||
tB = torch.randn(n, m, dtype=torch.double)
|
||||
tC = torch.empty(n, m, dtype=torch.double)
|
||||
cg.call([tA, tB, tC, n, m])
|
||||
torch.testing.assert_close(tA - tB, tC)
|
||||
|
||||
test_with_shape(2, 4)
|
||||
test_with_shape(5, 3)
|
||||
|
||||
def test_dtype_error(self):
|
||||
te.BufHandle("a", [1], torch.float32) # ok
|
||||
self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))
|
||||
|
@ -672,6 +672,65 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
return {};
|
||||
}
|
||||
|
||||
void check_bounds_throw(int64_t idx, int64_t bound, const BufPtr& buf) {
|
||||
std::stringstream ss;
|
||||
ss << "Index out of bounds in check_bounds. Index: " << idx
|
||||
<< "; bounds: [0, " << bound << ").";
|
||||
throw malformed_input(ss.str(), buf);
|
||||
}
|
||||
|
||||
void check_bounds(const BufPtr& buf, const std::vector<ExprPtr>& indices) {
|
||||
const std::vector<ExprPtr>& dims = buf->dims();
|
||||
if (dims.size() != indices.size()) {
|
||||
// indices are flattened, but not buffer
|
||||
if (indices.size() == 1) {
|
||||
if (dims.size() != buf->strides().size()) {
|
||||
throw malformed_input(
|
||||
"Number of dimensions did not match number of strides", buf);
|
||||
}
|
||||
size_t buf_size = 1;
|
||||
if (dims.size() > 0) {
|
||||
ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1));
|
||||
ExprHandle negative_one = ExprHandle(immLike(dims[0], -1));
|
||||
for (const auto& i : c10::irange(dims.size())) {
|
||||
buf_size_expr = buf_size_expr +
|
||||
((negative_one + ExprHandle(dims[i])) *
|
||||
ExprHandle(buf->strides()[i]));
|
||||
}
|
||||
buf_size_expr.node()->accept(this);
|
||||
buf_size = value().intValue();
|
||||
}
|
||||
indices[0]->accept(this);
|
||||
const auto& index_values = indexVec(value());
|
||||
for (auto& j : index_values) {
|
||||
if (j < 0 || j >= buf_size) {
|
||||
check_bounds_throw(j, buf_size, buf);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
throw malformed_input(
|
||||
"dimensions and indices mismatch in check_bounds. Buf has " +
|
||||
std::to_string(dims.size()) + " dimensions and indices has " +
|
||||
std::to_string(indices.size()) + " dimensions.",
|
||||
buf);
|
||||
}
|
||||
for (const auto& i : c10::irange(dims.size())) {
|
||||
auto opt_dim = intValue(dims[i]);
|
||||
if (!opt_dim) {
|
||||
continue;
|
||||
}
|
||||
auto dim_bound = *opt_dim;
|
||||
indices[i]->accept(this);
|
||||
const auto& ithDimIndices = indexVec(value());
|
||||
for (auto& j : ithDimIndices) {
|
||||
if (j < 0 || j >= dim_bound) {
|
||||
check_bounds_throw(j, dim_bound, buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(LoadPtr v) override {
|
||||
auto iter = buffer_mapping_.find(v->buf());
|
||||
if (iter == buffer_mapping_.end()) {
|
||||
@ -679,6 +738,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
void* ptr = iter->second;
|
||||
|
||||
check_bounds(v->buf(), v->indices());
|
||||
|
||||
ExprPtr flat_idx =
|
||||
flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
|
||||
flat_idx->accept(this);
|
||||
@ -722,6 +783,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
|
||||
void* ptr = iter->second;
|
||||
|
||||
check_bounds(v->buf(), v->indices());
|
||||
|
||||
ExprPtr flat_idx =
|
||||
flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
|
||||
flat_idx->accept(this);
|
||||
|
Reference in New Issue
Block a user