[PyTorch] Adopt IValue::toTupleRef() where obvious (#65505)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65505

Generated with

`fastmod -m 'toTuple\(\)(\s*)->' 'toTupleRef()${1}.'`

, followed by

`fastmod '(std::move\(.*)toTupleRef\(\).' '${1}toTuple()->'`

to unbreak 2 callsites.
ghstack-source-id: 142065835

Test Plan: CI

Reviewed By: gchanan

Differential Revision: D31131025

fbshipit-source-id: 54457ae5bbeb38db9c7f196d469b98521c3d3f34
This commit is contained in:
Scott Wolchok
2021-11-02 10:13:02 -07:00
committed by Facebook GitHub Bot
parent eb1b8a2160
commit 82f7f8d471
45 changed files with 96 additions and 95 deletions

View File

@ -332,7 +332,7 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
JIValue::javaClassStatic(),
facebook::jni::make_jstring(ivalue.toStringRef()));
} else if (ivalue.isTuple()) {
auto elementsVec = ivalue.toTuple()->elements();
auto elementsVec = ivalue.toTupleRef().elements();
static auto jMethodTupleArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(

View File

@ -125,7 +125,7 @@ TypePtr IValue::type() const {
case Tag::Capsule:
return CapsuleType::get();
case Tag::Tuple:
return toTuple()->type();
return toTupleRef().type();
case Tag::Generator:
return GeneratorType::get();
case Tag::Quantizer:
@ -147,7 +147,7 @@ void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
case Tag::GenericList: {
c10::ArrayRef<IValue> elems;
if (isTuple()) {
elems = this->toTuple()->elements();
elems = this->toTupleRef().elements();
} else {
elems = this->toListRef();
}
@ -196,7 +196,7 @@ void IValue::getSubValues(HashAliasedIValues& subValues) const {
subValues.insert(*this);
c10::ArrayRef<IValue> elems;
if (isTuple()) {
elems = this->toTuple()->elements();
elems = this->toTupleRef().elements();
} else {
elems = this->toListRef();
}
@ -542,7 +542,7 @@ std::ostream& IValue::repr(
case IValue::Tag::Bool:
return out << (v.toBool() ? "True" : "False");
case IValue::Tag::Tuple: {
const auto& elements = v.toTuple()->elements();
const auto& elements = v.toTupleRef().elements();
const auto& finish = elements.size() == 1 ? ",)" : ")";
return printList(out, elements, "(", finish, formatter);
}
@ -633,7 +633,7 @@ IValueComparator getLessThanComparator(const IValue& v) {
}
if (v.isTuple()) {
const auto& elements = v.toTuple()->elements();
const auto& elements = v.toTupleRef().elements();
size_t n = elements.size();
std::vector<IValueComparator> elements_lts;
@ -643,8 +643,8 @@ IValueComparator getLessThanComparator(const IValue& v) {
}
return [elements_lts=std::move(elements_lts), n](const IValue& a, const IValue& b) {
const auto& a_elements = a.toTuple()->elements();
const auto& b_elements = b.toTuple()->elements();
const auto& a_elements = a.toTupleRef().elements();
const auto& b_elements = b.toTupleRef().elements();
for (const auto i : c10::irange(n)) {
if (elements_lts[i](a_elements[i], b_elements[i])) {
@ -728,7 +728,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
case IValue::Tag::Bool:
return out << (v.toBool() ? "True" : "False");
case IValue::Tag::Tuple: {
const auto& elements = v.toTuple()->elements();
const auto& elements = v.toTupleRef().elements();
const auto& finish = elements.size() == 1 ? ",)" : ")";
return printList(out, elements, "(", finish, formatter);
}
@ -803,7 +803,7 @@ IValue IValue::deepcopy(
break;
case IValue::Tag::Tuple: {
std::vector<IValue> copied_tuple;
for (const auto& e : toTuple()->elements()) {
for (const auto& e : toTupleRef().elements()) {
copied_tuple.push_back(e.deepcopy(memo));
}
copy = IValue(ivalue::Tuple::create(copied_tuple));

View File

@ -332,7 +332,7 @@ struct TORCH_API TupleElements {
// It would be nice to make this noncopyable to prevent people from
// writing code like `auto output =
// forward(...).toTuple()->elements()` (which does refcount bumps on
// forward(...).toTupleRef().elements()` (which does refcount bumps on
// each element, unlike the more efficient but verbose
// ```
// auto outputIntrusivePtr = forward(...).toTuple();
@ -1716,7 +1716,7 @@ template <
guts::negation<std::is_constructible<IValue, Args>>...>::value,
std::nullptr_t> = nullptr>
std::tuple<Args...> generic_to(IValue ivalue, _fake_type<std::tuple<Args...>>) {
const auto& vals = ivalue.toTuple()->elements();
const auto& vals = ivalue.toTupleRef().elements();
TORCH_CHECK(vals.size() == sizeof...(Args));
return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
}

View File

@ -80,7 +80,7 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
// determine the version based on IValue contents
int version = -1;
if (v.isTuple()) {
const auto& elements = v.toTuple()->elements();
const auto& elements = v.toTupleRef().elements();
if (elements.size() > 0) {
auto firstElement = elements[0];
if (firstElement.isTensor()) {
@ -105,7 +105,7 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
if (version == 1) {
// version 1 - convert to version 3 manually
const auto& elements = v.toTuple()->elements();
const auto& elements = v.toTupleRef().elements();
at::Tensor weight = elements[0].toTensor();
c10::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
@ -149,7 +149,7 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
return std::tie(version, config_vals, tensors);
} else if (version == 2) {
// version 2
const auto& elements = v.toTuple()->elements();
const auto& elements = v.toTupleRef().elements();
std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
std::vector<c10::optional<at::Tensor>> optional;

View File

@ -51,7 +51,7 @@ TEST(IValueTest, Basic) {
at::ivalue::Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
ASSERT_EQ(foo.use_count(), 3);
ASSERT_TRUE(the_list.isTuple());
auto first = the_list.toTuple()->elements()[1];
auto first = the_list.toTupleRef().elements()[1];
ASSERT_EQ(first.toInt(), 4);
// Make sure toTupleRef has test coverage too.
first = the_list.toTupleRef().elements()[1];
@ -89,8 +89,8 @@ TEST(IValueTest, Basic) {
IValue complex_tuple(
at::ivalue::Tuple::create({IValue(c10::complex<double>(3.4, 4.7)), IValue(foo1)}));
ASSERT_TRUE(complex_tuple.isTuple());
ASSERT_EQ(complex_tuple.toTuple()->elements()[0].toComplexDouble(), c10::complex<double>(3.4, 4.7));
ASSERT_EQ(complex_tuple.toTuple()->elements()[1], foo1);
ASSERT_EQ(complex_tuple.toTupleRef().elements()[0].toComplexDouble(), c10::complex<double>(3.4, 4.7));
ASSERT_EQ(complex_tuple.toTupleRef().elements()[1], foo1);
}
TEST(IValueTest, BasicStorage) {

View File

@ -866,7 +866,7 @@ TEST(StaticRuntime, DeepWide) {
// run static runtime
std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
auto outputs = smod(input_tensors, {}).toTuple()->elements();
auto outputs = smod(input_tensors, {}).toTupleRef().elements();
ASSERT_TRUE(outputs.size() > 0);
at::Tensor output_2 = outputs[0].toTensor();
smod.runtime().check_for_memory_leak();
@ -1002,7 +1002,7 @@ TEST(StaticRuntime, CleanUpMemory) {
// run static runtime
std::vector<c10::IValue> input_tensors(
{ad_emb_packed, user_emb, wide});
auto outputs = runtime(input_tensors, {}).toTuple()->elements();
auto outputs = runtime(input_tensors, {}).toTupleRef().elements();
ASSERT_TRUE(outputs.size() > 0);
auto output_2 = outputs[0].toTensor();
runtime.check_for_memory_leak();
@ -1062,12 +1062,12 @@ TEST(
{
IValue tuple = runtime(args, {});
ASSERT_TRUE(tuple.isTuple());
ASSERT_EQ(tuple.toTuple()->elements().size(), 1);
ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
// Do not manage intput value.
EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
// Do not manage direct output value.
EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
IValue element = tuple.toTuple()->elements()[0];
IValue element = tuple.toTupleRef().elements()[0];
// Tensor to be managed, but not yet from the profile run.
EXPECT_FALSE(runtime.isManagedOutputTensor(element));
tuple = IValue();
@ -1078,12 +1078,12 @@ TEST(
{
IValue tuple = runtime(args, {});
ASSERT_TRUE(tuple.isTuple());
ASSERT_EQ(tuple.toTuple()->elements().size(), 1);
ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
// Do not manage intput value.
EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
// Do not manage direct output value.
EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
IValue element = tuple.toTuple()->elements()[0];
IValue element = tuple.toTupleRef().elements()[0];
// Tensor to be managed, but not yet from the profile run.
EXPECT_TRUE(runtime.isManagedOutputTensor(element));
tuple = IValue();

View File

@ -142,8 +142,8 @@ void compareResults(
return;
} else if (expect.isTuple()) {
EXPECT_TRUE(actual.isTuple());
auto lhs = expect.toTuple()->elements();
auto rhs = actual.toTuple()->elements();
auto lhs = expect.toTupleRef().elements();
auto rhs = actual.toTupleRef().elements();
EXPECT_TRUE(lhs.size() == rhs.size());
for (size_t i = 0; i < lhs.size(); i++) {
compareResults(lhs[i], rhs[i]);

View File

@ -256,7 +256,7 @@ int main(int argc, char** argv) {
std::cerr << "Model has only " << all_inputs.size() << " bundled inputs." << std::endl;
return 1;
}
inputs = all_inputs.get(FLAGS_use_bundled_input).toTuple()->elements();
inputs = all_inputs.get(FLAGS_use_bundled_input).toTupleRef().elements();
}
#ifdef BUILD_LITE_INTERPRETER

View File

@ -25,7 +25,7 @@ TEST(BackendTest, ToBackend) {
std::vector<IValue> inputs;
inputs.emplace_back(2.0 * torch::ones({}));
inputs.emplace_back(1.0 * torch::ones({}));
auto ref = m.forward(inputs).toTuple()->elements().vec();
auto ref = m.forward(inputs).toTupleRef().elements().vec();
c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
@ -75,7 +75,7 @@ TEST(BackendTest, ToBackend) {
return (_10, _12)
*/
auto res = lm.forward(inputs).toTuple()->elements().vec();
auto res = lm.forward(inputs).toTupleRef().elements().vec();
AT_ASSERT(res[0].toTensor().equal(ref[0].toTensor()));
AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor()));
}
@ -96,7 +96,7 @@ TEST(BackendTest, ToBackendNotAvailable) {
std::vector<IValue> inputs;
inputs.emplace_back(2.0 * torch::ones({}));
inputs.emplace_back(1.0 * torch::ones({}));
auto ref = m.forward(inputs).toTuple()->elements().vec();
auto ref = m.forward(inputs).toTupleRef().elements().vec();
c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
@ -110,7 +110,7 @@ TEST(BackendTest, ToBackendNotAvailable) {
// Validate exception is thrown when trying to execute and
// the backend is not available.
ASSERT_THROWS_WITH_MESSAGE(
lm.forward(inputs).toTuple()->elements(), "Backend is not available.");
lm.forward(inputs).toTupleRef().elements(), "Backend is not available.");
}
TEST(BackendTest, TestCompiler) {

View File

@ -95,8 +95,8 @@ class BackendWithCompiler : public PyTorchBackendInterface {
auto start_us = autograd::profiler::getTime() / 1000;
for (const auto& token : handle.toList()) {
IValue val = token;
auto instruction = val.toTuple()->elements()[0].toStringRef();
auto debug_handle = val.toTuple()->elements()[1].toInt();
auto instruction = val.toTupleRef().elements()[0].toStringRef();
auto debug_handle = val.toTupleRef().elements()[1].toInt();
double const_val = 1.0;
auto start_time_us = autograd::profiler::getTime() / 1000;
try {

View File

@ -181,7 +181,7 @@ TEST(LiteInterpreterTest, Tuple) {
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.get_method("forward")(inputs);
AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
}
TEST(LiteInterpreterTest, Dict) {
@ -535,7 +535,7 @@ void runAndCheckTorchScriptModel(
Module m_mobile = load(input_model_stream);
auto actual_result = m_mobile.forward(input_data);
const auto& actual_result_list = actual_result.toTuple()->elements();
const auto& actual_result_list = actual_result.toTupleRef().elements();
compareModelOutput(actual_result_list, expect_result_list);
}
@ -552,7 +552,7 @@ void runAndCheckBytecodeModel(
Module m_mobile = load(input_model_stream);
auto actual_result = m_mobile.forward(input_data);
const auto& actual_result_list = actual_result.toTuple()->elements();
const auto& actual_result_list = actual_result.toTupleRef().elements();
compareModelOutput(actual_result_list, expect_result_list);
}

View File

@ -130,7 +130,7 @@ TEST(SerializationTest, TestJitStream_CUDA) {
model = torch::jit::load("saved_stream_model.pt");
auto output = model.forward(inputs);
const auto& list_of_elements = output.toTuple()->elements();
const auto& list_of_elements = output.toTupleRef().elements();
auto is_stream_s = list_of_elements[0].toBool();
// a,b: These are the two input tensors

View File

@ -49,7 +49,7 @@ void testSerializationInterop() {
std::istream_iterator<char>());
IValue ivalue = pickle_load(input);
auto elements = ivalue.toTuple()->elements();
auto elements = ivalue.toTupleRef().elements();
auto ones = torch::ones({2, 2});
AT_ASSERT(ones.equal(elements.at(0).toTensor()));

View File

@ -14,7 +14,7 @@ void load_serialized_lowered_module_and_execute(const std::string& path) {
std::vector<torch::jit::IValue> inputs{tensor, tensor};
auto output = module.forward(inputs);
AT_ASSERT(output.isTuple());
auto output_elements = output.toTuple()->elements();
auto output_elements = output.toTupleRef().elements();
for (auto& e : output_elements) {
AT_ASSERT(e.isTensor());
}

View File

@ -199,8 +199,8 @@ struct Benchmark {
eg = I.global("builtins", "tuple")(
I.self.attr("load_pickle")({"model", "example.pkl"}))
.toIValue()
.toTuple()
->elements();
.toTupleRef()
.elements();
}
// NOLINTNEXTLINE(bugprone-branch-clone)

View File

@ -29,12 +29,12 @@ void compare_torchpy_jit(const char* model_filename, const char* jit_filename) {
eg = I.self.attr("load_pickle")({"model", "example.pkl"}).toIValue();
}
at::Tensor output = model(eg.toTuple()->elements()).toTensor();
at::Tensor output = model(eg.toTupleRef().elements()).toTensor();
// Reference
auto ref_model = torch::jit::load(jit_filename);
at::Tensor ref_output =
ref_model.forward(eg.toTuple()->elements()).toTensor();
ref_model.forward(eg.toTupleRef().elements()).toTensor();
ASSERT_TRUE(ref_output.allclose(output, 1e-03, 1e-05));
}

View File

@ -41,7 +41,7 @@ TEST(TorchDeployGPUTest, SimpleModel) {
{
auto I = p.acquireSession();
auto eg = I.self.attr("load_pickle")({"model", "example.pkl"}).toIValue();
inputs = eg.toTuple()->elements();
inputs = eg.toTupleRef().elements();
inputs[0] = inputs[0].toTensor().to("cuda");
}
at::Tensor output = model(inputs).toTensor();

View File

@ -55,7 +55,7 @@ std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
const auto& tupleElements = tuple.toTuple()->elements();
const auto& tupleElements = tuple.toTupleRef().elements();
// Build PropagateGradientsReq.
TORCH_INTERNAL_ASSERT(tupleElements.size() >= 3);

View File

@ -191,7 +191,7 @@ py::object PyRRef::toHere(const float timeoutSeconds) const {
if (rref_->isPyObj()) {
// python_rpc_handler deserialization will acquires GIL.
auto rfr_values = value.toTuple()->elements().vec();
auto rfr_values = value.toTupleRef().elements().vec();
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
auto ret = pythonRpcHandler.deserialize(
SerializedPyObj::fromIValues(std::move(rfr_values)));

View File

@ -42,7 +42,7 @@ std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
auto values = value.toTuple()->elements().vec();
auto values = value.toTupleRef().elements().vec();
// remove the last elements from values and convert it back to an RRef
TORCH_INTERNAL_ASSERT(

View File

@ -129,7 +129,7 @@ std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
auto values = value.toTuple()->elements().vec();
auto values = value.toTupleRef().elements().vec();
return fromIValues(values);
}

View File

@ -76,7 +76,7 @@ std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
auto values = value.toTuple()->elements().vec();
auto values = value.toTupleRef().elements().vec();
return fromIValues(values);
}

View File

@ -60,7 +60,7 @@ GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) {
TORCH_INTERNAL_ASSERT(
ivalue.isTuple(),
"GloballyUniqueId::fromIValue expected ivalue to be a tuple.");
const auto& ivalues = ivalue.toTuple()->elements();
const auto& ivalues = ivalue.toTupleRef().elements();
TORCH_CHECK(
ivalues.size() == 2,
"Constructing GloballyUniqueId from ivalue "

View File

@ -517,7 +517,7 @@ std::vector<at::IValue> readWrappedPayload(
additionalPayloadSize,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
tensorTable);
std::vector<at::IValue> tupleElements = tuple.toTuple()->elements().vec();
std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec();
payload.resize(payload.size() - additionalPayloadSize);
return tupleElements;
}

View File

@ -242,7 +242,7 @@ IValue Module::operator()(std::vector<IValue> inputs) {
IValue result = Method(_ivalue(), pre_hook)({tuple_input});
if (!result.isNone()) {
if (result.isTuple()) {
inputs = result.toTuple()->elements().vec();
inputs = result.toTupleRef().elements().vec();
} else {
inputs = {result};
}

View File

@ -143,7 +143,7 @@ struct API_AVAILABLE(ios(11.0), macos(10.13)) CoreMLExecutorWrapper
for (int i = 0; i < inputs.size(); ++i) {
auto val = inputs.get(i);
if (val.isTuple()) {
auto tuples = val.toTuple()->elements();
auto tuples = val.toTupleRef().elements();
for (auto& ival : tuples) {
TORCH_CHECK(ival.isTensor());
auto tensor = ival.toTensor();

View File

@ -342,7 +342,7 @@ std::vector<IValue> ScriptTypeParser::evaluateDefaults(
// recursively initialize stuff in DecomposeOps.
GraphOptimizerEnabledGuard guard(false);
cu.get_function(def.name().name()).run(stack);
return stack.at(0).toTuple()->elements().vec();
return stack.at(0).toTupleRef().elements().vec();
}
std::vector<Argument> ScriptTypeParser::parseArgsFromDecl(

View File

@ -122,7 +122,7 @@ Value* TracingState::getValue(const IValue& var) {
} else if (var.isTuple()) {
return graph
->insertNode(graph->createTuple(fmap(
var.toTuple()->elements(),
var.toTupleRef().elements(),
[&](const IValue& val) { return getValue(val); })))
->output();
} else if (var.isGenericDict()) {
@ -270,7 +270,7 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
[&](const IValue& ival) { return getOutput(ival, i); })))
->output();
} else if (iv.isTuple()) {
const auto& tuple = iv.toTuple()->elements();
const auto& tuple = iv.toTupleRef().elements();
auto tuple_node = graph->createTuple(
fmap(tuple, [&](const IValue& ival) { return getOutput(ival, i); }));
graph->insertNode(tuple_node);
@ -547,7 +547,7 @@ void TracingState::setValue(const IValue& v, Value* value) {
setValue(outputs.get(i), unpack_node->outputs()[i]);
}
} else if (v.isTuple()) {
const auto& outputs = v.toTuple()->elements();
const auto& outputs = v.toTupleRef().elements();
Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
for (const auto i : c10::irange(outputs.size())) {
setValue(outputs[i], unpack_node->outputs()[i]);

View File

@ -27,7 +27,7 @@ bool insertableIValue(const IValue& ivalue) {
if (ivalue.isList() || ivalue.isTuple()) {
c10::ArrayRef<IValue> elems;
if (ivalue.isTuple()) {
elems = ivalue.toTuple()->elements();
elems = ivalue.toTupleRef().elements();
} else {
elems = ivalue.toListRef();
}

View File

@ -107,8 +107,8 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) {
return attributesEqual(a1.toListRef(), a2.toListRef());
}
if (a1.isTuple()) {
at::ArrayRef<IValue> a1_elem = a1.toTuple()->elements();
at::ArrayRef<IValue> a2_elem = a2.toTuple()->elements();
at::ArrayRef<IValue> a1_elem = a1.toTupleRef().elements();
at::ArrayRef<IValue> a2_elem = a2.toTupleRef().elements();
return attributesEqual(a1_elem, a2_elem);
}
if (a1.isGenericDict()) {

View File

@ -159,11 +159,11 @@ std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
// loop over all the functions in the bytecode
for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
// descend to the operators list
const auto& method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
auto operators_tuple = method_tuple.at(1).toTuple()->elements()[1];
auto operators = operators_tuple.toTuple()->elements()[1];
for (auto& op_tuple : operators.toTuple()->elements()) {
const auto& op = op_tuple.toTuple()->elements();
const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
auto operators_tuple = method_tuple.at(1).toTupleRef().elements()[1];
auto operators = operators_tuple.toTupleRef().elements()[1];
for (auto& op_tuple : operators.toTupleRef().elements()) {
const auto& op = op_tuple.toTupleRef().elements();
// grab name
std::string op_name = op.at(0).toStringRef();
@ -226,11 +226,11 @@ std::unordered_set<std::string> _get_mobile_model_contained_types(
// the hash to record which types are parsed.
std::unordered_set<std::string> parsed_type_names_records;
for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
const auto& method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
auto type_table_tuple =
method_tuple.at(1).toTuple()->elements()[BYTECODE_INDEX_TYPE];
method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE];
const auto& type_table =
type_table_tuple.toTuple()->elements()[1].toTuple()->elements();
type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements();
// type_table is a list of IValue, and each IValue is a string,
// for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"

View File

@ -29,7 +29,7 @@ std::vector<std::vector<at::IValue>> MobileModelRunner::
"but got a ",
input.tagKind(),
" instead");
ret.push_back(input.toTuple()->elements());
ret.push_back(input.toTupleRef().elements());
}
return ret;

View File

@ -97,13 +97,14 @@ Function::Function(const c10::IValue& value) {
parameters_ = dict.at("parameters").toList();
// input_specs_
for (const auto& input_value : dict.at("input_specs").toTuple()->elements()) {
for (const auto& input_value :
dict.at("input_specs").toTupleRef().elements()) {
input_specs_.emplace_back(input_value);
}
// output_specs_
for (const auto& output_value :
dict.at("output_specs").toTuple()->elements()) {
dict.at("output_specs").toTupleRef().elements()) {
output_specs_.emplace_back(output_value);
}
@ -228,8 +229,8 @@ c10::impl::GenericList Function::run(
}
CompilationUnit::CompilationUnit(const c10::IValue& value) {
const auto& root = value.toTuple()->elements();
const auto& functions = root[1].toTuple()->elements();
const auto& root = value.toTupleRef().elements();
const auto& functions = root[1].toTupleRef().elements();
for (const auto& function : functions) {
register_function(std::make_unique<Function>(function));
}

View File

@ -85,8 +85,8 @@ void parseInstructions(
debugHandlesTableElements,
"function_debug_handles",
BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
.toTuple()
->elements())[0]
.toTupleRef()
.elements())[0]
.toIntList();
TORCH_CHECK(
debug_handles_list.size() == ins_list.size(),

View File

@ -551,7 +551,7 @@ inline Stack toTraceableStack(const py::tuple& inputs) {
info.type()->repr_str(),
"' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
" Tuples of Tensors can be traced");
return info.toTuple()->elements().vec();
return info.toTupleRef().elements().vec();
}
inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {

View File

@ -77,7 +77,7 @@ void prepare_and_call_rpc_op(
userCallableStack.reserve(functionSchema.arguments().size());
// Move args from Tuple IValue to Stack.
for (auto& elem : argsTupleIValue.toTuple()->elements()) {
for (auto& elem : argsTupleIValue.toTupleRef().elements()) {
push(userCallableStack, std::move(elem));
}

View File

@ -1317,7 +1317,7 @@ void dictConstructFromList(Stack& stack) {
tup_type->elements().at(0), tup_type->elements().at(1));
dict.reserve(list.size());
for (IValue input : list) {
const auto& tup = input.toTuple()->elements();
const auto& tup = input.toTupleRef().elements();
dict.insert_or_assign(tup[0], tup[1]);
}
push(stack, dict);

View File

@ -47,7 +47,7 @@ Operation createStaticSubgraphRuntime(const Node* node) {
torch::jit::drop(stack, num_inputs);
if (module->num_outputs() > 1) {
for (auto& o : outputs.toTuple()->elements()) {
for (auto& o : outputs.toTupleRef().elements()) {
push_one(stack, std::move(o));
}
} else {

View File

@ -1181,7 +1181,7 @@ bool display_ivalue(const IValue& iv) {
std::cout << "Dict {" << iv.toGenericDict().size() << "}\n";
return true;
} else if (iv.isTuple()) {
std::cout << "Tuple {" << iv.toTuple()->elements().size() << "}\n";
std::cout << "Tuple {" << iv.toTupleRef().elements().size() << "}\n";
return true;
} else if (iv.isInt()) {
std::cout << "int {" << iv.toInt() << "}\n";

View File

@ -60,7 +60,7 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim_TupleUnpack,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& elems = p_node->Input(0).toTuple()->elements();
const auto& elems = p_node->Input(0).toTupleRef().elements();
const size_t num_outputs = p_node->outputs().size();
TORCH_CHECK(
num_outputs == elems.size(),

View File

@ -218,7 +218,7 @@ void percentFormat(Stack& stack, size_t num_inputs) {
auto args = last(stack, num_inputs - 1)[0];
auto args_size = 1; // assumed size
if (args.isTuple()) {
args_size = args.toTuple()->elements().size();
args_size = args.toTupleRef().elements().size();
}
std::stringstream ss;
size_t used_args = 0;
@ -243,7 +243,7 @@ void percentFormat(Stack& stack, size_t num_inputs) {
char key = format_str.at(format_idx);
IValue arg;
if (args.isTuple()) {
arg = args.toTuple()->elements()[used_args];
arg = args.toTupleRef().elements()[used_args];
} else {
arg = args;
}

View File

@ -181,7 +181,7 @@ c10::optional<ModuleInstanceInfo> InlinedCallStackDeserializer::
if (it != cached_module_instance_info_.end()) {
return it->second;
}
const auto& tup_elems = iv.toTuple()->elements();
const auto& tup_elems = iv.toTupleRef().elements();
TORCH_CHECK(tup_elems.size() == 2);
std::string type_name = tup_elems[0].toString()->string();
std::string instance_name = tup_elems[1].toString()->string();
@ -218,7 +218,7 @@ ska::flat_hash_map<int64_t, DebugInfoTuple> CallStackDebugInfoUnpickler::
ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
auto ivalues = std::move(*std::move(ival).toTuple()).elements();
for (auto& val : ivalues) {
const auto& tup_elems = val.toTuple()->elements();
const auto& tup_elems = val.toTupleRef().elements();
TORCH_CHECK(
tup_elems.size() == 4,
"Pickled map must have four elements: "

View File

@ -967,16 +967,16 @@ void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
BytecodeExportSet exportSet = moduleMethodsTuple(m, dummy_uniquer);
pushFunctionToIValues(std::move(exportSet), elements, dummy, dummy_uniquer);
for (const auto& element : elements) {
auto table = element.toTuple()->elements()[1];
auto table = element.toTupleRef().elements()[1];
auto row =
table.toTuple()->elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
table.toTupleRef().elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
TORCH_INTERNAL_ASSERT(
row->elements().at(0).toStringRef() == "operators",
"Expected operators but found ",
row->elements().at(0).toStringRef());
const auto& ops_list = row->elements().at(1).toTuple()->elements();
const auto& ops_list = row->elements().at(1).toTupleRef().elements();
for (const auto& op : ops_list) {
const auto& op_item = op.toTuple()->elements();
const auto& op_item = op.toTupleRef().elements();
TORCH_CHECK(
op_item.size() >= 2,
"There should be either two parts (name and overload name), ",

View File

@ -24,7 +24,7 @@ class SourceRangeSerializer {
};
SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
const auto& tup_elems = iv.toTuple()->elements();
const auto& tup_elems = iv.toTupleRef().elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::shared_ptr<SourceView> source_ = deserialize_source(tup_elems[0]);
int64_t start_ = tup_elems[1].toInt();
@ -117,7 +117,7 @@ void ConcreteSourceRangeUnpickler::unpickle() {
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
const auto& tup_elems = val.toTuple()->elements();
const auto& tup_elems = val.toTupleRef().elements();
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
unpickled_records->emplace_back(offset, std::move(source_range));

View File

@ -248,7 +248,7 @@ inline void append<bool>(std::vector<bool>& a, bool&& e) {
}
static std::vector<int64_t> tupleToIntList(const IValue& v) {
return fmap(v.toTuple()->elements(), [](const IValue& v) -> int64_t {
return fmap(v.toTupleRef().elements(), [](const IValue& v) -> int64_t {
return v.toInt();
});
}
@ -534,7 +534,7 @@ void Unpickler::readGlobal(
if (class_name == "build_tensor_from_id") {
globals_.emplace_back([this] {
// Pop reduce arg off the stack
auto data = stack_.back().toTuple()->elements().at(0);
auto data = stack_.back().toTupleRef().elements().at(0);
stack_.pop_back();
TORCH_CHECK(
!tensor_table_.empty(),
@ -583,7 +583,7 @@ void Unpickler::readGlobal(
// Unpickle a list specialization (e.g. List[Tensor], List[int], ...)
globals_.emplace_back([this, elem_type] {
// Pop reduce arg off the stack
auto data = stack_.back().toTuple()->elements().at(0).toList();
auto data = stack_.back().toTupleRef().elements().at(0).toList();
stack_.pop_back();
data.unsafeSetElementType(elem_type);
stack_.emplace_back(std::move(data));
@ -620,7 +620,7 @@ void Unpickler::readGlobal(
});
} else if (module_name == "torch" && class_name == "device") {
globals_.emplace_back([this] {
auto device_string = stack_.back().toTuple()->elements().at(0);
auto device_string = stack_.back().toTupleRef().elements().at(0);
stack_.pop_back();
stack_.emplace_back(c10::Device(device_string.toStringRef()));
});