mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Static Runtime] Schema checks for index_put (#84152)
Summary: `index_put` can take a list of tensors, but Static Runtime always tries to convert its argument to a list of optional tensors. This was causing crashes for some users. Add some schema checks to prevent this, and add a new overload for the new case. Also, I found a clear bug in the JIT interpreter (mutating the argument when its not supposed to), so I fixed that too. Test Plan: New unit test Differential Revision: D39072214 Pull Request resolved: https://github.com/pytorch/pytorch/pull/84152 Approved by: https://github.com/tenpercent
This commit is contained in:
committed by
PyTorch MergeBot
parent
7532d5b125
commit
db7784e722
@ -2509,11 +2509,20 @@ TEST(StaticRuntime, Index_Put) {
|
||||
)JIT";
|
||||
|
||||
auto a = at::randn({2});
|
||||
auto indicies_a = std::make_tuple(torch::tensor({0}, at::kLong));
|
||||
auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong));
|
||||
auto values_a = at::randn({1});
|
||||
|
||||
std::vector<IValue> args0{a, indicies_a, values_a, false};
|
||||
std::vector<IValue> args0{a, indices_a, values_a, false};
|
||||
testStaticRuntime(index_put_str, args0);
|
||||
|
||||
const auto index_put_non_optional_str = R"JIT(
|
||||
def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool):
|
||||
return torch.index_put(a, indices, values, accumulate).clone()
|
||||
)JIT";
|
||||
|
||||
auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
|
||||
std::vector<IValue> args1{a, indices_b, values_a, false};
|
||||
testStaticRuntime(index_put_non_optional_str, args1);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Item) {
|
||||
|
@ -1090,7 +1090,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
}
|
||||
auto self = pop(stack).toTensor();
|
||||
auto result =
|
||||
at::index_put_(self, opt_list_indices, values, accumulate);
|
||||
at::index_put(self, opt_list_indices, values, accumulate);
|
||||
push(stack, std::move(result));
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
|
@ -269,10 +269,9 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
};
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
aten::index_put,
|
||||
aten_index_put,
|
||||
[](Node* n) -> SROperator {
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
|
||||
if (n->matches(torch::schema(
|
||||
"aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
const auto& indices = p_node->Input(1).toOptionalTensorList();
|
||||
@ -281,7 +280,30 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
p_node->Output(0) =
|
||||
at::native::index_put(self, indices, values, accumulate);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
if (n->matches(torch::schema(
|
||||
"aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
const auto indices = p_node->Input(1).toTensorList();
|
||||
|
||||
c10::List<c10::optional<at::Tensor>> opt_list_indices;
|
||||
opt_list_indices.reserve(indices.size());
|
||||
for (const auto& ten : indices) {
|
||||
opt_list_indices.push_back(ten);
|
||||
}
|
||||
|
||||
const auto& values = p_node->Input(2).toTensor();
|
||||
const auto accumulate = p_node->Input(3).toBool();
|
||||
p_node->Output(0) =
|
||||
at::native::index_put(self, opt_list_indices, values, accumulate);
|
||||
};
|
||||
}
|
||||
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
aten::item,
|
||||
|
@ -13278,14 +13278,6 @@ op_db: List[OpInfo] = [
|
||||
sample_inputs_func=sample_inputs_index_put,
|
||||
skips=(
|
||||
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
|
||||
# RuntimeError: The following operation failed in the TorchScript interpreter.
|
||||
# Traceback of TorchScript (most recent call last):
|
||||
# File "<string>", line 3, in forward
|
||||
# def the_method(i0, i1: List[torch.Tensor], i2):
|
||||
# return torch.index_put(i0, i1, i2, accumulate=False)
|
||||
# ~~~~~~~~~~~~~~~ <--- HERE
|
||||
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
)),
|
||||
OpInfo('sort',
|
||||
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
|
Reference in New Issue
Block a user