Support tensor.__getitem__() in TorchScript compilation (#73952)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73952

Reviewed By: tugsbayasgalan

Differential Revision: D34743346

Pulled By: gmagogsfm

fbshipit-source-id: 2273c289c2224166cb1eed10a138d4ac7043ed83
(cherry picked from commit 37aefb9a95e0df4586bb623a1aaa974fbe799687)
This commit is contained in:
gmagogsfm
2022-03-10 17:40:30 -08:00
committed by PyTorch MergeBot
parent 766eba60f7
commit fdd12a9f4c
4 changed files with 64 additions and 0 deletions

View File

@ -0,0 +1,39 @@
# Owner(s): ["oncall: jit"]
import os
import sys
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTensorMethods(JitTestCase):
def test_getitem(self):
def tensor_getitem(inp: torch.Tensor):
indices = torch.tensor([0, 2], dtype=torch.long)
return inp.__getitem__(indices)
inp = torch.rand(3, 4)
self.checkScript(tensor_getitem, (inp, ))
scripted = torch.jit.script(tensor_getitem)
FileCheck().check("aten::index").run(scripted.graph)
def test_getitem_invalid(self):
def tensor_getitem_invalid(inp: torch.Tensor):
return inp.__getitem__()
with self.assertRaisesRegexWithHighlight(
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"):
torch.jit.script(tensor_getitem_invalid)

View File

@ -76,6 +76,7 @@ from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU
from jit.test_device_analysis import TestDeviceAnalysis # noqa: F401
from jit.test_dce import TestDCE # noqa: F401
from jit.test_sparse import TestSparse # noqa: F401
from jit.test_tensor_methods import TestTensorMethods # noqa: F401
# Torch
from torch import Tensor

View File

@ -3555,6 +3555,22 @@ struct to_ir {
case prim::dict: {
return emitApplySpecialFormForDict(apply, type_hint);
}
case aten::index: {
const SourceRange& loc = apply.range();
auto select = Select(apply.callee());
auto self = emitSugaredExpr(select.value(), 1)->asValue(loc, method);
auto inputs = apply.inputs();
if (inputs.size() != 1) {
throw ErrorReport(apply)
<< "__getitem__ expected exactly 1 arguments, got "
<< inputs.size();
}
auto input =
emitSugaredExpr(apply.inputs()[0], 1)->asValue(loc, method);
return std::make_shared<SimpleValue>(emitIndex(loc, self, {input}));
}
default:
TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form);
}

View File

@ -221,6 +221,14 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
return SpecialFormValue::create(prim::tolist);
}
// Handle calling __getitem__() directly on a Tensor, it needs special
// handling because desired method name (`__getitem__`) doesn't match `aten`
// operator name of `aten::index`.
if (value_->type()->isSubtypeOf(*TensorType::get()) &&
field == "__getitem__") {
return SpecialFormValue::create(aten::index);
}
ErrorReport report(loc);
report << "'" << value_->type()->repr_str()
<< "' object has no attribute or method '" << field << "'.";