[Static Runtime] Schema checks for index_put (#84152)

Summary:
`index_put` can take a list of tensors (`Tensor[] indices`), but Static Runtime always tried to convert its indices argument to a list of optional tensors (`Tensor?[] indices`). This was causing crashes for some users. Add schema checks to prevent this, and add a new overload for the `Tensor[]` case.
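
As a rough illustration (not code from this PR; the module name is made up), this is the kind of TorchScript graph that hits the `Tensor[]` overload, mirroring the new unit test:

```python
from typing import List

import torch

# Sketch of a graph whose indices input is List[Tensor] rather than
# List[Optional[Tensor]]; Static Runtime previously called toOptionalTensorList()
# on it unconditionally and crashed.
class IndexPutList(torch.nn.Module):
    def forward(self, a: torch.Tensor, indices: List[torch.Tensor],
                values: torch.Tensor, accumulate: bool) -> torch.Tensor:
        return torch.index_put(a, indices, values, accumulate).clone()

scripted = torch.jit.script(IndexPutList())
out = scripted(torch.randn(2), [torch.tensor([0])], torch.randn(1), False)
```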

Also, I found a clear bug in the JIT interpreter (the out-of-place `aten::index_put` overload was calling `at::index_put_` and mutating its argument when it's not supposed to), so I fixed that too.
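
For context, a minimal sketch (again, not from this PR) of the in-place vs. out-of-place distinction behind that interpreter fix:

```python
import torch

# torch.index_put returns a new tensor; torch.Tensor.index_put_ writes into its
# receiver. The out-of-place interpreter op must use the former so the input is
# never mutated.
a = torch.zeros(3)
b = torch.index_put(a, (torch.tensor([0]),), torch.tensor([1.0]))  # `a` stays all zeros
a.index_put_((torch.tensor([1]),), torch.tensor([2.0]))            # `a` is mutated in place
```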

Test Plan: New unit test

Differential Revision: D39072214

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84152
Approved by: https://github.com/tenpercent
Author:    Mike Iovine
Date:      2022-08-31 01:20:14 +00:00
Committed: PyTorch MergeBot
Parent:    7532d5b125
Commit:    db7784e722

4 changed files with 47 additions and 24 deletions


@@ -2509,11 +2509,20 @@ TEST(StaticRuntime, Index_Put) {
   )JIT";
   auto a = at::randn({2});
-  auto indicies_a = std::make_tuple(torch::tensor({0}, at::kLong));
+  auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong));
   auto values_a = at::randn({1});
-  std::vector<IValue> args0{a, indicies_a, values_a, false};
+  std::vector<IValue> args0{a, indices_a, values_a, false};
   testStaticRuntime(index_put_str, args0);
+
+  const auto index_put_non_optional_str = R"JIT(
+    def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool):
+        return torch.index_put(a, indices, values, accumulate).clone()
+  )JIT";
+
+  auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
+  std::vector<IValue> args1{a, indices_b, values_a, false};
+  testStaticRuntime(index_put_non_optional_str, args1);
 }
 
 TEST(StaticRuntime, Item) {


@@ -1090,7 +1090,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
           }
           auto self = pop(stack).toTensor();
           auto result =
-              at::index_put_(self, opt_list_indices, values, accumulate);
+              at::index_put(self, opt_list_indices, values, accumulate);
           push(stack, std::move(result));
         },
         aliasAnalysisFromSchema()),


@@ -269,10 +269,9 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
   };
 });
 
-REGISTER_NATIVE_OPERATOR_FUNCTOR(
-    aten::index_put,
-    aten_index_put,
-    [](Node* n) -> SROperator {
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
+  if (n->matches(torch::schema(
+          "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
     return [](ProcessedNode* p_node) {
       const auto& self = p_node->Input(0).toTensor();
       const auto& indices = p_node->Input(1).toOptionalTensorList();
@@ -281,7 +280,30 @@
       p_node->Output(0) =
           at::native::index_put(self, indices, values, accumulate);
     };
-});
+  }
+
+  if (n->matches(torch::schema(
+          "aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
+    return [](ProcessedNode* p_node) {
+      const auto& self = p_node->Input(0).toTensor();
+      const auto indices = p_node->Input(1).toTensorList();
+
+      c10::List<c10::optional<at::Tensor>> opt_list_indices;
+      opt_list_indices.reserve(indices.size());
+      for (const auto& ten : indices) {
+        opt_list_indices.push_back(ten);
+      }
+
+      const auto& values = p_node->Input(2).toTensor();
+      const auto accumulate = p_node->Input(3).toBool();
+      p_node->Output(0) =
+          at::native::index_put(self, opt_list_indices, values, accumulate);
+    };
+  }
+
+  LogAndDumpSchema(n);
+  return nullptr;
+});
 
 REGISTER_NATIVE_OPERATOR_FUNCTOR(
     aten::item,


@@ -13278,14 +13278,6 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_index_put,
            skips=(
                DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
-               # RuntimeError: The following operation failed in the TorchScript interpreter.
-               # Traceback of TorchScript (most recent call last):
-               #   File "<string>", line 3, in forward
-               #     def the_method(i0, i1: List[torch.Tensor], i2):
-               #         return torch.index_put(i0, i1, i2, accumulate=False)
-               #                ~~~~~~~~~~~~~~~ <--- HERE
-               # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
-               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
            )),
    OpInfo('sort',
           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),