mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support tensor.__getitem__() in TorchScript compilation (#73952)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73952 Reviewed By: tugsbayasgalan Differential Revision: D34743346 Pulled By: gmagogsfm fbshipit-source-id: 2273c289c2224166cb1eed10a138d4ac7043ed83 (cherry picked from commit 37aefb9a95e0df4586bb623a1aaa974fbe799687)
This commit is contained in:
committed by
PyTorch MergeBot
parent
766eba60f7
commit
fdd12a9f4c
39
test/jit/test_tensor_methods.py
Normal file
39
test/jit/test_tensor_methods.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing import FileCheck
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
class TestTensorMethods(JitTestCase):
|
||||
def test_getitem(self):
|
||||
def tensor_getitem(inp: torch.Tensor):
|
||||
indices = torch.tensor([0, 2], dtype=torch.long)
|
||||
return inp.__getitem__(indices)
|
||||
|
||||
inp = torch.rand(3, 4)
|
||||
self.checkScript(tensor_getitem, (inp, ))
|
||||
|
||||
scripted = torch.jit.script(tensor_getitem)
|
||||
FileCheck().check("aten::index").run(scripted.graph)
|
||||
|
||||
def test_getitem_invalid(self):
|
||||
def tensor_getitem_invalid(inp: torch.Tensor):
|
||||
return inp.__getitem__()
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"):
|
||||
torch.jit.script(tensor_getitem_invalid)
|
@ -76,6 +76,7 @@ from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU
|
||||
from jit.test_device_analysis import TestDeviceAnalysis # noqa: F401
|
||||
from jit.test_dce import TestDCE # noqa: F401
|
||||
from jit.test_sparse import TestSparse # noqa: F401
|
||||
from jit.test_tensor_methods import TestTensorMethods # noqa: F401
|
||||
|
||||
# Torch
|
||||
from torch import Tensor
|
||||
|
@ -3555,6 +3555,22 @@ struct to_ir {
|
||||
case prim::dict: {
|
||||
return emitApplySpecialFormForDict(apply, type_hint);
|
||||
}
|
||||
case aten::index: {
|
||||
const SourceRange& loc = apply.range();
|
||||
auto select = Select(apply.callee());
|
||||
auto self = emitSugaredExpr(select.value(), 1)->asValue(loc, method);
|
||||
|
||||
auto inputs = apply.inputs();
|
||||
if (inputs.size() != 1) {
|
||||
throw ErrorReport(apply)
|
||||
<< "__getitem__ expected exactly 1 arguments, got "
|
||||
<< inputs.size();
|
||||
}
|
||||
auto input =
|
||||
emitSugaredExpr(apply.inputs()[0], 1)->asValue(loc, method);
|
||||
|
||||
return std::make_shared<SimpleValue>(emitIndex(loc, self, {input}));
|
||||
}
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form);
|
||||
}
|
||||
|
@ -221,6 +221,14 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||
return SpecialFormValue::create(prim::tolist);
|
||||
}
|
||||
|
||||
// Handle calling __getitem__() directly on a Tensor, it needs special
|
||||
// handling because desired method name (`__getitem__`) doesn't match `aten`
|
||||
// operator name of `aten::index`.
|
||||
if (value_->type()->isSubtypeOf(*TensorType::get()) &&
|
||||
field == "__getitem__") {
|
||||
return SpecialFormValue::create(aten::index);
|
||||
}
|
||||
|
||||
ErrorReport report(loc);
|
||||
report << "'" << value_->type()->repr_str()
|
||||
<< "' object has no attribute or method '" << field << "'.";
|
||||
|
Reference in New Issue
Block a user