#include <torch/nativert/detail/ITree.h>

#include <string>
#include <unordered_map>

#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/Synchronized.h>

#include <nlohmann/json.hpp>

namespace torch::nativert::detail {

namespace {

inline constexpr int kDefaultTreeSpecSerializationProtocol = 1;

c10::IValue dynamicToIValue(const nlohmann::json& obj) {
  if (obj.is_string()) {
    return obj.get<std::string>();
  } else if (obj.is_number_integer()) {
    return obj.get<int64_t>();
  } else {
    TORCH_CHECK(false, "Unsupported dynamic type: ", obj);
  }
}

void itreeFlatten(
    const c10::IValue& nested,
    const ITreeSpec& spec,
    std::vector<c10::IValue>& ivalues) {
  if (spec.isIValue()) {
    ivalues.push_back(nested);
    return;
  }
  auto flattenFn = spec.nodeDefCache().flattenFn;
  flattenFn(nested, spec, ivalues);
}

class PytreeNodeRegistry {
 public:
  PytreeNodeRegistry() {
    // Register node defs for the built-in Python container types.
    registerNode(
        "builtins.tuple",
        NodeDef{
            [](const c10::IValue& nested,
               const ITreeSpec& spec,
               std::vector<c10::IValue>& ivalues) {
              const auto& tuple = nested.toTupleRef().elements();
              TORCH_CHECK(tuple.size() == spec.children().size());
              for (size_t i = 0; i < tuple.size(); i++) {
                itreeFlatten(tuple[i], spec.children(i), ivalues);
              }
            },
            [](std::vector<c10::IValue> flats,
               const nlohmann::json& obj) -> c10::IValue {
              TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
              return c10::ivalue::Tuple::create(std::move(flats));
            },
            [](ITreeMapNoReturnFn fn,
               const c10::IValue& nested,
               const ITreeSpec& spec) {
              const auto& tuple = nested.toTupleRef().elements();
              TORCH_CHECK(tuple.size() == spec.children().size());
              for (size_t i = 0; i < tuple.size(); i++) {
                ivalueApply(fn, tuple[i], spec.children(i));
              }
            }});
    const auto& tupleNodeDef = getNodeDef("builtins.tuple");
    registerNode(
        "collections.namedtuple",
        NodeDef{
            tupleNodeDef.flattenFn,
            [](std::vector<c10::IValue> flats,
               const nlohmann::json& obj) -> c10::IValue {
              TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
              return c10::ivalue::Tuple::create(std::move(flats));
            },
            tupleNodeDef.ivalueApplyFn,
            [](std::string_view context) {
              return nlohmann::json{context};
            }});
    registerNode(
        "builtins.list",
        NodeDef{
            [](const c10::IValue& nested,
               const ITreeSpec& spec,
               std::vector<c10::IValue>& ivalues) {
              auto list = nested.toListRef();
              for (size_t i = 0; i < list.size(); i++) {
                itreeFlatten(list[i], spec.children(i), ivalues);
              }
            },
            [](std::vector<c10::IValue> flats,
               const nlohmann::json& obj) -> c10::IValue {
              TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
              c10::List<c10::IValue> list(c10::AnyType::get());
              list.reserve(flats.size());
              for (auto& flat : flats) {
                list.push_back(std::move(flat));
              }
              return list;
            },
            [](ITreeMapNoReturnFn fn,
               const c10::IValue& nested,
               const ITreeSpec& spec) {
              auto list = nested.toListRef();
              for (size_t i = 0; i < list.size(); i++) {
                ivalueApply(fn, list[i], spec.children(i));
              }
            }});
    registerNode(
        "torch.fx.immutable_collections.immutable_list",
        getNodeDef("builtins.list"));
    registerNode(
        "builtins.dict",
        NodeDef{
            [](const c10::IValue& nested,
               const ITreeSpec& spec,
               std::vector<c10::IValue>& ivalues) {
              auto dict = nested.toGenericDict();
              const auto& contextKeys = spec.contextKeys();
              // Allow the dict to have fewer entries than the spec; each
              // missing key is filled with an empty tensor below.
              TORCH_CHECK(dict.size() <= contextKeys.size());
              size_t i = 0;
              for (const auto& key : contextKeys) {
                auto it = dict.find(key);
                if (it != dict.end()) {
                  itreeFlatten(it->value(), spec.children(i), ivalues);
                } else {
                  // When the dict has missing keys, fill the corresponding
                  // ivalues with empty tensors, which is required for
                  // validation.
                  for (size_t j = 0; j < spec.children(i).numIValues(); ++j) {
                    at::Tensor empty_tensor;
                    ivalues.emplace_back(std::move(empty_tensor));
                  }
                }
                i++;
              }
            },
            [](std::vector<c10::IValue> flats,
               const nlohmann::json& obj) -> c10::IValue {
              c10::Dict<c10::IValue, c10::IValue> dict(
                  c10::AnyType::get(), c10::AnyType::get());
              TORCH_CHECK(obj.is_array());
              TORCH_CHECK(obj.size() == flats.size());
              dict.reserve(flats.size());
              for (size_t i = 0; i < flats.size(); i++) {
                dict.insert(dynamicToIValue(obj[i]), std::move(flats[i]));
              }
              return dict;
            },
            [](ITreeMapNoReturnFn fn,
               const c10::IValue& nested,
               const ITreeSpec& spec) {
              auto dict = nested.toGenericDict();
              const auto& contextKeys = spec.contextKeys();
              size_t i = 0;
              for (const auto& key : contextKeys) {
                if (spec.children(i).isUsed()) {
                  auto it = dict.find(key);
                  if (it != dict.end()) {
                    ivalueApply(fn, it->value(), spec.children(i));
                  } else {
                    TORCH_CHECK(false, "input arg is missing key ", key);
                  }
                }
                i++;
              }
            }});
    registerNode(
        "torch.fx.immutable_collections.immutable_dict",
        getNodeDef("builtins.dict"));
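    // Note: for the dict-like nodes above, the serialized context is the
    // JSON array of keys (e.g. ["a", "b"]); ITreeSpec::contextKeys() maps
    // each entry through dynamicToIValue() to produce the IValue keys that
    // the flatten/apply lambdas iterate over.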
    // Register JaggedTensor pytree node
    registerNode(
        "torchrec.sparse.jagged_tensor.JaggedTensor",
        NodeDef{
            [](const c10::IValue& nested,
               const ITreeSpec& spec,
               std::vector<c10::IValue>& ivalues) {
              // JaggedTensor has 4 fields: _values, _weights, _lengths,
              // _offsets. All fields are optional torch.Tensor except
              // _values.
              TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
              const auto& obj = nested.toObjectRef();
              // Extract the tensor fields in order: _values, _weights,
              // _lengths, _offsets.
              TORCH_CHECK(
                  spec.children().size() == 4,
                  "JaggedTensor should have 4 children");
              // Flatten each tensor field.
              itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
              itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
              itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
              itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
            },
            [](std::vector<c10::IValue> flats,
               const nlohmann::json& obj) -> c10::IValue {
              // Reconstruct JaggedTensor from the flattened tensors. This is
              // a simplified reconstruction; in practice this would need to
              // call the actual JaggedTensor constructor.
              TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
              TORCH_CHECK(
                  flats.size() == 4, "JaggedTensor expects 4 tensor fields");
              // Return a generic tuple for now; the actual implementation
              // would need to construct the JaggedTensor custom class.
              return c10::ivalue::Tuple::create(std::move(flats));
            },
            [](ITreeMapNoReturnFn fn,
               const c10::IValue& nested,
               const ITreeSpec& spec) {
              TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
              const auto& obj = nested.toObjectRef();
              TORCH_CHECK(
                  spec.children().size() == 4,
                  "JaggedTensor should have 4 children");
              // Apply the function to each tensor field.
              ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
              ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
              ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
              ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
            }});
    // Register KeyedJaggedTensor pytree node
    registerNode(
        "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
        NodeDef{
            [](const c10::IValue& nested,
               const ITreeSpec& spec,
               std::vector<c10::IValue>& ivalues) {
              // KeyedJaggedTensor has 6 tensor fields plus a keys context.
              // Fields: _values, _weights, _lengths, _offsets,
              // _stride_per_key_per_rank, _inverse_indices tensor.
              TORCH_CHECK(
                  nested.isObject(), "Expected KeyedJaggedTensor object");
              const auto& obj = nested.toObjectRef();
              // Extract the tensor fields in order.
              TORCH_CHECK(
                  spec.children().size() == 6,
                  "KeyedJaggedTensor should have 6 children");
              // Flatten each tensor field.
              itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
              itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
              itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
              itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
              itreeFlatten(
                  obj.getAttr("_stride_per_key_per_rank"),
                  spec.children(4),
                  ivalues);
              // For _inverse_indices, extract the tensor part (the second
              // element of the tuple).
              auto inverse_indices = obj.getAttr("_inverse_indices");
              if (!inverse_indices.isNone()) {
                auto tuple = inverse_indices.toTuple();
                itreeFlatten(tuple->elements()[1], spec.children(5), ivalues);
              } else {
                // Handle the None case by adding a null tensor.
                itreeFlatten(c10::IValue(), spec.children(5), ivalues);
              }
            },
            [](std::vector<c10::IValue> flats,
               const nlohmann::json& obj) -> c10::IValue {
              // Reconstruct KeyedJaggedTensor from the flattened tensors and
              // the keys context.
              TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
              TORCH_CHECK(
                  flats.size() == 6,
                  "KeyedJaggedTensor expects 6 tensor fields");
              // The context should contain the keys list. Return a generic
              // tuple for now; the actual implementation would need to
              // construct the KeyedJaggedTensor custom class.
              return c10::ivalue::Tuple::create(std::move(flats));
            },
            [](ITreeMapNoReturnFn fn,
               const c10::IValue& nested,
               const ITreeSpec& spec) {
              TORCH_CHECK(
                  nested.isObject(), "Expected KeyedJaggedTensor object");
              const auto& obj = nested.toObjectRef();
              TORCH_CHECK(
                  spec.children().size() == 6,
                  "KeyedJaggedTensor should have 6 children");
              // Apply the function to each tensor field.
              ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
              ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
              ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
              ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
              ivalueApply(
                  fn,
                  obj.getAttr("_stride_per_key_per_rank"),
                  spec.children(4));
              // For _inverse_indices, apply to the tensor part (the second
              // element of the tuple).
              auto inverse_indices = obj.getAttr("_inverse_indices");
              if (!inverse_indices.isNone()) {
                auto tuple = inverse_indices.toTuple();
                ivalueApply(fn, tuple->elements()[1], spec.children(5));
              } else {
                // Handle the None case.
                ivalueApply(fn, c10::IValue(), spec.children(5));
              }
            },
            [](std::string_view context) {
              // The context contains the keys list as JSON.
              return nlohmann::json::parse(context);
            }});
  }
  bool hasNodeDef(std::string_view typeName) const {
    return registry_.find(std::string{typeName}) != registry_.end();
  }
  const NodeDef& getNodeDef(std::string_view typeName) const {
    return registry_.at(std::string{typeName});
  }
  void registerNode(std::string_view typeName, NodeDef nodeDef) {
    TORCH_CHECK(!hasNodeDef(typeName));
    registry_.emplace(typeName, nodeDef);
  }

 private:
  std::unordered_map<std::string, NodeDef> registry_;
};

c10::Synchronized<PytreeNodeRegistry>& getPytreeNodeRegistry() {
  static auto* registry = new c10::Synchronized<PytreeNodeRegistry>();
  return *registry;
}
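// Illustrative sketch only: out-of-tree code can plug in its own node type
// through the public registerPytreeNode() defined below. "my.lib.MyPair" and
// its "first"/"second" attributes are hypothetical, and the body mirrors the
// built-in registrations above (the itreeFlatten/ivalueApply helpers used
// here are internal to this file):
//
//   registerPytreeNode(
//       "my.lib.MyPair",
//       NodeDef{
//           [](const c10::IValue& nested,
//              const ITreeSpec& spec,
//              std::vector<c10::IValue>& ivalues) {
//             const auto& obj = nested.toObjectRef();
//             itreeFlatten(obj.getAttr("first"), spec.children(0), ivalues);
//             itreeFlatten(obj.getAttr("second"), spec.children(1), ivalues);
//           },
//           [](std::vector<c10::IValue> flats,
//              const nlohmann::json& obj) -> c10::IValue {
//             return c10::ivalue::Tuple::create(std::move(flats));
//           },
//           [](ITreeMapNoReturnFn fn,
//              const c10::IValue& nested,
//              const ITreeSpec& spec) {
//             const auto& obj = nested.toObjectRef();
//             ivalueApply(fn, obj.getAttr("first"), spec.children(0));
//             ivalueApply(fn, obj.getAttr("second"), spec.children(1));
//           }});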
ITreeSpec makeITreeSpec(
    const nlohmann::json& obj,
    const std::vector<const Value*>& values,
    int start) {
  TORCH_CHECK(obj.is_object());
  TORCH_CHECK(obj.find("type") != obj.end());
  if (obj["type"].is_null()) {
    TORCH_CHECK(obj["children_spec"].empty());
    TORCH_CHECK(obj["context"].is_null());
    const Value* value = values[start];
    if (value) {
      bool isUsed = !value->users().empty();
      return ITreeSpec(value, isUsed);
    } else {
      return ITreeSpec(value, false);
    }
  }
  const auto name = obj["type"].get<std::string>();
  NodeDef nodeDefCache;
  getPytreeNodeRegistry().withLock([&](auto& registry) {
    TORCH_CHECK(registry.hasNodeDef(name), "Unknown pytree node type: ", name);
    nodeDefCache = registry.getNodeDef(name);
  });
  auto context = nodeDefCache.contextLoadFn(obj["context"].get<std::string>());
  const auto& childrenSpec = obj["children_spec"];
  TORCH_CHECK(childrenSpec.is_array());
  std::vector<ITreeSpec> children;
  int offset = 0;
  for (const auto& child : childrenSpec) {
    children.push_back(makeITreeSpec(child, values, start + offset));
    // NOLINTNEXTLINE(*-narrowing-conversions)
    offset += children.back().numIValues();
  }
  return ITreeSpec(name, context, std::move(children), nodeDefCache);
}

} // namespace

void registerPytreeNode(std::string_view typeName, NodeDef nodeDef) {
  getPytreeNodeRegistry().withLock([&](auto& registry) {
    registry.registerNode(typeName, std::move(nodeDef));
  });
}

ITreeSpec itreeSpecLoads(
    std::string_view json,
    const std::vector<const Value*>& values) {
  const auto obj = nlohmann::json::parse(json);
  TORCH_CHECK(obj.is_array());
  TORCH_CHECK(obj.size() == 2);
  TORCH_CHECK(obj[0].get<int>() == kDefaultTreeSpecSerializationProtocol);
  auto result = makeITreeSpec(obj[1], values, 0);
  TORCH_CHECK(result.numIValues() == values.size());
  return result;
}
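// A minimal sketch of the serialized treespec accepted by itreeSpecLoads()
// above, inferred from the checks in makeITreeSpec (field values are
// illustrative): a two-element array of [protocol, spec], where leaf nodes
// have a null "type" and non-leaf nodes carry their context as a JSON string
// to be parsed by the node's contextLoadFn:
//
//   [1, {"type": "builtins.tuple",
//        "context": "null",
//        "children_spec": [
//          {"type": null, "context": null, "children_spec": []},
//          {"type": null, "context": null, "children_spec": []}]}]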
c10::IValue itreeUnflatten(
    std::vector<c10::IValue> ivalues,
    const ITreeSpec& spec) {
  RECORD_USER_SCOPE("nativert::itreeUnflatten");
  TORCH_CHECK(ivalues.size() == spec.numIValues());
  if (spec.isIValue()) {
    return std::move(ivalues[0]);
  }
  auto unflattenFn = spec.nodeDefCache().unflattenFn;
  if (spec.allIValues()) {
    return unflattenFn(std::move(ivalues), spec.context());
  }
  size_t start = 0;
  std::vector<c10::IValue> childrenPytrees;
  for (const auto& child : spec.children()) {
    if (child.isIValue()) {
      childrenPytrees.push_back(std::move(ivalues[start]));
      start++;
      continue;
    }
    size_t numIValues = child.numIValues();
    std::vector<c10::IValue> slice(
        // NOLINTNEXTLINE(*-narrowing-conversions)
        std::make_move_iterator(ivalues.begin() + start),
        // NOLINTNEXTLINE(*-narrowing-conversions)
        std::make_move_iterator(ivalues.begin() + start + numIValues));
    childrenPytrees.push_back(itreeUnflatten(std::move(slice), child));
    start += numIValues;
  }
  return unflattenFn(std::move(childrenPytrees), spec.context());
}

std::vector<c10::IValue> itreeFlatten(
    const c10::IValue& nested,
    const ITreeSpec& spec) {
  std::vector<c10::IValue> ivalues;
  ivalues.reserve(spec.numIValues());
  itreeFlatten(nested, spec, ivalues);
  return ivalues;
}

std::vector<c10::IValue> itreeFlattenFromArgs(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs,
    const ITreeSpec& spec) {
  RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs");
  TORCH_CHECK(!spec.isIValue());
  TORCH_CHECK(spec.children().size() == 2);
  std::vector<c10::IValue> ivalues;
  ivalues.reserve(spec.numIValues());
  const auto& specArgs = spec.children(0);
  TORCH_CHECK(!specArgs.isIValue());
  TORCH_CHECK(specArgs.children().size() == args.size());
  for (size_t i = 0; i < args.size(); i++) {
    itreeFlatten(args[i], specArgs.children(i), ivalues);
  }
  const auto& specKwargs = spec.children(1);
  TORCH_CHECK(!specKwargs.isIValue());
  TORCH_CHECK(specKwargs.context().size() == kwargs.size());
  for (size_t i = 0; i < specKwargs.context().size(); i++) {
    itreeFlatten(
        kwargs.at(specKwargs.context()[i].get_ref<const std::string&>()),
        specKwargs.children(i),
        ivalues);
  }
  return ivalues;
}

void ivalueApplyFromArgs(
    ITreeMapNoReturnFn fn,
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs,
    const ITreeSpec& spec) {
  RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs");
  TORCH_CHECK(!spec.isIValue());
  TORCH_CHECK(spec.children().size() == 2);
  const auto& specArgs = spec.children(0);
  TORCH_CHECK(!specArgs.isIValue());
  TORCH_CHECK(specArgs.children().size() == args.size());
  for (size_t i = 0; i < args.size(); i++) {
    ivalueApply(fn, args[i], specArgs.children(i));
  }
  const auto& specKwargs = spec.children(1);
  TORCH_CHECK(!specKwargs.isIValue());
  const auto& ctx = specKwargs.context();
  TORCH_CHECK(ctx.size() == kwargs.size());
  for (size_t i = 0; i < ctx.size(); i++) {
    ivalueApply(
        fn,
        kwargs.at(ctx[i].get_ref<const std::string&>()),
        specKwargs.children(i));
  }
}
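// Both itreeFlattenFromArgs() and ivalueApplyFromArgs() above expect the
// top-level spec to mirror a Python (args, kwargs) call signature:
// children(0) is the spec for the positional-args tuple and children(1) is
// the spec for the kwargs dict, whose context supplies the keyword names
// looked up in the kwargs map.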
std::vector<at::Tensor> itreeFlattenToTensorList(
    const c10::IValue& nested,
    const ITreeSpec& spec) {
  auto flats = itreeFlatten(nested, spec);
  std::vector<at::Tensor> tensors;
  tensors.reserve(flats.size());
  for (const auto& flat : flats) {
    tensors.push_back(flat.toTensor());
  }
  return tensors;
}

c10::IValue itreeMap(
    ITreeMapFn f,
    const c10::IValue& nested,
    const ITreeSpec& spec) {
  const auto flats = itreeFlatten(nested, spec);
  std::vector<c10::IValue> mapped;
  mapped.reserve(flats.size());
  for (const auto& flat : flats) {
    mapped.push_back(f(flat));
  }
  return itreeUnflatten(std::move(mapped), spec);
}

c10::IValue argsToIValue(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  for (const auto& [key, arg] : kwargs) {
    dict.insert(key, arg);
  }
  return c10::ivalue::Tuple::create({c10::ivalue::Tuple::create(args), dict});
}

std::pair<
    std::vector<c10::IValue>,
    std::unordered_map<std::string, c10::IValue>>
itreeMapArgs(
    ITreeMapFn f,
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs,
    const ITreeSpec& spec) {
  const auto val = argsToIValue(args, kwargs);
  const auto mapVal = itreeMap(f, val, spec);
  auto mapArgs =
      mapVal.toTupleRef().elements()[0].toTupleRef().elements().vec();
  std::unordered_map<std::string, c10::IValue> mapKwargs;
  for (const auto& entry : mapVal.toTupleRef().elements()[1].toGenericDict()) {
    mapKwargs.emplace(entry.key().toStringRef(), entry.value());
  }
  return {std::move(mapArgs), std::move(mapKwargs)};
}

void ivalueApply(
    ITreeMapNoReturnFn fn,
    const c10::IValue& nested,
    const ITreeSpec& spec) {
  if (spec.isIValue()) {
    if (spec.isUsed()) {
      fn(nested, spec.value());
    }
    return;
  }
  auto ivalueApplyFn = spec.nodeDefCache().ivalueApplyFn;
  ivalueApplyFn(fn, nested, spec);
}

nlohmann::json defaultContextLoadFn(std::string_view context) {
  return nlohmann::json::parse(context);
}

ITreeSpec::ITreeSpec(
    std::string_view uniformName,
    nlohmann::json context,
    std::vector<ITreeSpec> children,
    NodeDef nodeDefCache)
    : uniformName_(uniformName),
      context_(std::move(context)),
      children_(std::move(children)),
      nodeDefCache_(nodeDefCache),
      numIValues_(0),
      value_(nullptr),
      isUsed_(false) {
  for (auto& child : children_) {
    numIValues_ += child.numIValues();
    allIValues_ &= child.isIValue();
    isUsed_ |= child.isUsed();
  }
  if (uniformName_ == "builtins.dict" ||
      uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
    for (const auto& keyObj : context_) {
      contextKeys_.push_back(dynamicToIValue(keyObj));
    }
  }
}

c10::TypePtr ITreeSpec::toAtenType() const {
  if (isIValue()) {
    return c10::AnyType::get();
  } else if (uniformName_ == "builtins.tuple") {
    std::vector<c10::TypePtr> childrenType;
    childrenType.reserve(children_.size());
    for (const auto& childrenSpec : children_) {
      childrenType.emplace_back(childrenSpec.toAtenType());
    }
    return c10::TupleType::create(std::move(childrenType));
  } else if (
      uniformName_ == "builtins.list" ||
      uniformName_ == "torch.fx.immutable_collections.immutable_list") {
    if (children_.empty()) {
      return c10::ListType::create(c10::AnyType::get());
    } else {
      return c10::ListType::create(children_[0].toAtenType());
    }
  } else if (
      uniformName_ == "builtins.dict" ||
      uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
    if (children_.empty()) {
      return c10::DictType::create(c10::AnyType::get(), c10::AnyType::get());
    } else {
      return c10::DictType::create(
          dynamicToIValue(context_[0]).type(), children_[0].toAtenType());
    }
  } else {
    TORCH_CHECK(false, "Unsupported uniform name: ", uniformName());
  }
}

} // namespace torch::nativert::detail
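// Illustrative round trip (hypothetical caller code; `nested` and `spec` are
// assumptions, with `spec` typically coming from itreeSpecLoads):
//
//   std::vector<c10::IValue> leaves = itreeFlatten(nested, spec);
//   // ... inspect or transform the leaves ...
//   c10::IValue rebuilt = itreeUnflatten(std::move(leaves), spec);
//
//   // Or map a function over every leaf in one shot:
//   c10::IValue mapped = itreeMap(
//       [](const c10::IValue& leaf) { return leaf; }, nested, spec);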