[fix] nn c++ : segfault in modulelist and moduledict (#93074)

Fixes https://github.com/pytorch/pytorch/issues/73565

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93074
Approved by: https://github.com/albanD
This commit is contained in:
Kshiteej K
2023-01-27 12:20:19 +00:00
committed by PyTorch MergeBot
parent 219e9533f0
commit 68a98537d5
4 changed files with 46 additions and 4 deletions

View File

@ -300,3 +300,9 @@ TEST_F(ModuleListTest, RangeBasedForLoop) {
module->pretty_print(buffer);
}
}
TEST_F(ModuleListTest, InvalidAt) {
torch::nn::ModuleList m(torch::nn::Linear(1, 2));
ASSERT_THROWS_WITH(
m->at<torch::nn::Dropout2dImpl>(0), "Unable to cast module");
}