mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Differential Revision: D80656182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161114 Approved by: https://github.com/henryoier
		
			
				
	
	
		
			628 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			628 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #include <ATen/record_function.h>
 | |
| #include <torch/nativert/detail/ITree.h>
 | |
| 
 | |
| #include <iterator>
 | |
| #include <string_view>
 | |
| 
 | |
| #include <ATen/core/ivalue.h>
 | |
| #include <c10/util/Synchronized.h>
 | |
| #include <nlohmann/json.hpp>
 | |
| 
 | |
| namespace torch::nativert::detail {
 | |
| 
 | |
| namespace {
 | |
// Protocol version that itreeSpecLoads requires as the first element of a
// serialized tree spec.
inline constexpr int kDefaultTreeSpecSerializationProtocol{1};
 | |
| 
 | |
| c10::IValue dynamicToIValue(const nlohmann::json& obj) {
 | |
|   if (obj.is_string()) {
 | |
|     return obj.get<std::string>();
 | |
|   } else if (obj.is_number_integer()) {
 | |
|     return obj.get<int64_t>();
 | |
|   } else {
 | |
|     TORCH_CHECK(false, "Unsupported dynamic type: ", obj);
 | |
|   }
 | |
| }
 | |
| 
 | |
| void itreeFlatten(
 | |
|     const c10::IValue& nested,
 | |
|     const ITreeSpec& spec,
 | |
|     std::vector<c10::IValue>& ivalues) {
 | |
|   if (spec.isIValue()) {
 | |
|     ivalues.push_back(nested);
 | |
|     return;
 | |
|   }
 | |
|   auto flattenFn = spec.nodeDefCache().flattenFn;
 | |
|   flattenFn(nested, spec, ivalues);
 | |
| }
 | |
| 
 | |
| class PytreeNodeRegistry {
 | |
|  public:
 | |
|   PytreeNodeRegistry() {
 | |
|     // Add some law of physics here.
 | |
|     registerNode(
 | |
|         "builtins.tuple",
 | |
|         NodeDef{
 | |
|             [](const c10::IValue& nested,
 | |
|                const ITreeSpec& spec,
 | |
|                std::vector<c10::IValue>& ivalues) {
 | |
|               const auto& tuple = nested.toTupleRef().elements();
 | |
|               TORCH_CHECK(tuple.size() == spec.children().size());
 | |
|               for (size_t i = 0; i < tuple.size(); i++) {
 | |
|                 itreeFlatten(tuple[i], spec.children(i), ivalues);
 | |
|               }
 | |
|             },
 | |
|             [](std::vector<c10::IValue> flats,
 | |
|                const nlohmann::json& obj) -> c10::IValue {
 | |
|               TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
 | |
|               return c10::ivalue::Tuple::create(std::move(flats));
 | |
|             },
 | |
|             [](ITreeMapNoReturnFn fn,
 | |
|                const c10::IValue& nested,
 | |
|                const ITreeSpec& spec) {
 | |
|               const auto& tuple = nested.toTupleRef().elements();
 | |
|               TORCH_CHECK(tuple.size() == spec.children().size());
 | |
|               for (size_t i = 0; i < tuple.size(); i++) {
 | |
|                 ivalueApply(fn, tuple[i], spec.children(i));
 | |
|               }
 | |
|             }});
 | |
|     const auto& tupleNodeDef = getNodeDef("builtins.tuple");
 | |
|     registerNode(
 | |
|         "collections.namedtuple",
 | |
|         NodeDef{
 | |
|             tupleNodeDef.flattenFn,
 | |
|             [](std::vector<c10::IValue> flats,
 | |
|                const nlohmann::json& obj) -> c10::IValue {
 | |
|               TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
 | |
|               return c10::ivalue::Tuple::create(std::move(flats));
 | |
|             },
 | |
|             tupleNodeDef.ivalueApplyFn,
 | |
|             [](std::string_view context) { return nlohmann::json{context}; }});
 | |
|     registerNode(
 | |
|         "builtins.list",
 | |
|         NodeDef{
 | |
|             [](const c10::IValue& nested,
 | |
|                const ITreeSpec& spec,
 | |
|                std::vector<c10::IValue>& ivalues) {
 | |
|               auto list = nested.toListRef();
 | |
|               for (size_t i = 0; i < list.size(); i++) {
 | |
|                 itreeFlatten(list[i], spec.children(i), ivalues);
 | |
|               }
 | |
|             },
 | |
|             [](std::vector<c10::IValue> flats,
 | |
|                const nlohmann::json& obj) -> c10::IValue {
 | |
|               TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
 | |
|               c10::List<c10::IValue> list(c10::AnyType::get());
 | |
|               list.reserve(flats.size());
 | |
|               for (auto& flat : flats) {
 | |
|                 list.push_back(std::move(flat));
 | |
|               }
 | |
|               return list;
 | |
|             },
 | |
|             [](ITreeMapNoReturnFn fn,
 | |
|                const c10::IValue& nested,
 | |
|                const ITreeSpec& spec) {
 | |
|               auto list = nested.toListRef();
 | |
|               for (size_t i = 0; i < list.size(); i++) {
 | |
|                 ivalueApply(fn, list[i], spec.children(i));
 | |
|               }
 | |
|             }});
 | |
|     registerNode(
 | |
|         "torch.fx.immutable_collections.immutable_list",
 | |
|         getNodeDef("builtins.list"));
 | |
|     registerNode(
 | |
|         "builtins.dict",
 | |
|         NodeDef{
 | |
|             [](const c10::IValue& nested,
 | |
|                const ITreeSpec& spec,
 | |
|                std::vector<c10::IValue>& ivalues) {
 | |
|               auto dict = nested.toGenericDict();
 | |
|               const auto& contextKeys = spec.contextKeys();
 | |
|               // allow the dict size less than the spec, missing key will be
 | |
|               // filled with empty tensor
 | |
|               TORCH_CHECK(dict.size() <= contextKeys.size());
 | |
|               size_t i = 0;
 | |
|               for (const auto& key : contextKeys) {
 | |
|                 auto it = dict.find(key);
 | |
| 
 | |
|                 if (it != dict.end()) {
 | |
|                   itreeFlatten(it->value(), spec.children(i), ivalues);
 | |
|                 } else {
 | |
|                   // when we have a dict with missing keys, we fill the missing
 | |
|                   // ivalues with an empty tensor which is required for
 | |
|                   // validation
 | |
|                   for (size_t j = 0; j < spec.children(i).numIValues(); ++j) {
 | |
|                     at::Tensor empty_tensor;
 | |
|                     ivalues.emplace_back(std::move(empty_tensor));
 | |
|                   }
 | |
|                 }
 | |
|                 i++;
 | |
|               }
 | |
|             },
 | |
|             [](std::vector<c10::IValue> flats,
 | |
|                const nlohmann::json& obj) -> c10::IValue {
 | |
|               c10::Dict<c10::IValue, c10::IValue> dict(
 | |
|                   c10::AnyType::get(), c10::AnyType::get());
 | |
|               TORCH_CHECK(obj.is_array());
 | |
|               TORCH_CHECK(obj.size() == flats.size());
 | |
|               dict.reserve(flats.size());
 | |
|               for (size_t i = 0; i < flats.size(); i++) {
 | |
|                 dict.insert(dynamicToIValue(obj[i]), std::move(flats[i]));
 | |
|               }
 | |
|               return dict;
 | |
|             },
 | |
|             [](ITreeMapNoReturnFn fn,
 | |
|                const c10::IValue& nested,
 | |
|                const ITreeSpec& spec) {
 | |
|               auto dict = nested.toGenericDict();
 | |
|               const auto& contextKeys = spec.contextKeys();
 | |
| 
 | |
|               size_t i = 0;
 | |
|               for (const auto& key : contextKeys) {
 | |
|                 if (spec.children(i).isUsed()) {
 | |
|                   auto it = dict.find(key);
 | |
|                   if (it != dict.end()) {
 | |
|                     ivalueApply(fn, it->value(), spec.children(i));
 | |
|                   } else {
 | |
|                     TORCH_CHECK(false, "input arg is missing key ", key);
 | |
|                   }
 | |
|                 }
 | |
|                 i++;
 | |
|               }
 | |
|             }});
 | |
|     registerNode(
 | |
|         "torch.fx.immutable_collections.immutable_dict",
 | |
|         getNodeDef("builtins.dict"));
 | |
|     // Register JaggedTensor pytree node
 | |
|     registerNode(
 | |
|         "torchrec.sparse.jagged_tensor.JaggedTensor",
 | |
|         NodeDef{
 | |
|             [](const c10::IValue& nested,
 | |
|                const ITreeSpec& spec,
 | |
|                std::vector<c10::IValue>& ivalues) {
 | |
|               // JaggedTensor has 4 fields: _values, _weights, _lengths,
 | |
|               // _offsets All fields are optional torch.Tensor except _values
 | |
|               TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
 | |
|               const auto& obj = nested.toObjectRef();
 | |
| 
 | |
|               // Extract the tensor fields in order: _values, _weights,
 | |
|               // _lengths, _offsets
 | |
|               TORCH_CHECK(
 | |
|                   spec.children().size() == 4,
 | |
|                   "JaggedTensor should have 4 children");
 | |
| 
 | |
|               // Flatten each tensor field
 | |
|               itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
 | |
|               itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
 | |
|               itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
 | |
|               itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
 | |
|             },
 | |
|             [](std::vector<c10::IValue> flats,
 | |
|                const nlohmann::json& obj) -> c10::IValue {
 | |
|               // Reconstruct JaggedTensor from flattened tensors
 | |
|               // This is a simplified reconstruction - in practice would need
 | |
|               // to call the actual JaggedTensor constructor
 | |
|               TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
 | |
|               TORCH_CHECK(
 | |
|                   flats.size() == 4, "JaggedTensor expects 4 tensor fields");
 | |
| 
 | |
|               // Return a generic tuple for now - actual implementation would
 | |
|               // need to construct the JaggedTensor custom class
 | |
|               return c10::ivalue::Tuple::create(std::move(flats));
 | |
|             },
 | |
|             [](ITreeMapNoReturnFn fn,
 | |
|                const c10::IValue& nested,
 | |
|                const ITreeSpec& spec) {
 | |
|               TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
 | |
|               const auto& obj = nested.toObjectRef();
 | |
| 
 | |
|               TORCH_CHECK(
 | |
|                   spec.children().size() == 4,
 | |
|                   "JaggedTensor should have 4 children");
 | |
| 
 | |
|               // Apply function to each tensor field
 | |
|               ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
 | |
|               ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
 | |
|               ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
 | |
|               ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
 | |
|             }});
 | |
| 
 | |
|     // Register KeyedJaggedTensor pytree node
 | |
|     registerNode(
 | |
|         "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
 | |
|         NodeDef{
 | |
|             [](const c10::IValue& nested,
 | |
|                const ITreeSpec& spec,
 | |
|                std::vector<c10::IValue>& ivalues) {
 | |
|               // KeyedJaggedTensor has 6 tensor fields plus keys context
 | |
|               // Fields: _values, _weights, _lengths, _offsets,
 | |
|               // _stride_per_key_per_rank, _inverse_indices tensor
 | |
|               TORCH_CHECK(
 | |
|                   nested.isObject(), "Expected KeyedJaggedTensor object");
 | |
|               const auto& obj = nested.toObjectRef();
 | |
| 
 | |
|               // Extract the tensor fields in order
 | |
|               TORCH_CHECK(
 | |
|                   spec.children().size() == 6,
 | |
|                   "KeyedJaggedTensor should have 6 children");
 | |
| 
 | |
|               // Flatten each tensor field
 | |
|               itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
 | |
|               itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
 | |
|               itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
 | |
|               itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
 | |
|               itreeFlatten(
 | |
|                   obj.getAttr("_stride_per_key_per_rank"),
 | |
|                   spec.children(4),
 | |
|                   ivalues);
 | |
|               // For _inverse_indices, we need to extract the tensor part
 | |
|               // (second element of tuple)
 | |
|               auto inverse_indices = obj.getAttr("_inverse_indices");
 | |
|               if (!inverse_indices.isNone()) {
 | |
|                 auto tuple = inverse_indices.toTuple();
 | |
|                 itreeFlatten(tuple->elements()[1], spec.children(5), ivalues);
 | |
|               } else {
 | |
|                 // Handle None case by adding a null tensor
 | |
|                 itreeFlatten(c10::IValue(), spec.children(5), ivalues);
 | |
|               }
 | |
|             },
 | |
|             [](std::vector<c10::IValue> flats,
 | |
|                const nlohmann::json& obj) -> c10::IValue {
 | |
|               // Reconstruct KeyedJaggedTensor from flattened tensors and keys
 | |
|               // context
 | |
|               TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
 | |
|               TORCH_CHECK(
 | |
|                   flats.size() == 6,
 | |
|                   "KeyedJaggedTensor expects 6 tensor fields");
 | |
| 
 | |
|               // The context should contain the keys list
 | |
|               // Return a generic tuple for now - actual implementation would
 | |
|               // need to construct the KeyedJaggedTensor custom class
 | |
|               return c10::ivalue::Tuple::create(std::move(flats));
 | |
|             },
 | |
|             [](ITreeMapNoReturnFn fn,
 | |
|                const c10::IValue& nested,
 | |
|                const ITreeSpec& spec) {
 | |
|               TORCH_CHECK(
 | |
|                   nested.isObject(), "Expected KeyedJaggedTensor object");
 | |
|               const auto& obj = nested.toObjectRef();
 | |
| 
 | |
|               TORCH_CHECK(
 | |
|                   spec.children().size() == 6,
 | |
|                   "KeyedJaggedTensor should have 6 children");
 | |
| 
 | |
|               // Apply function to each tensor field
 | |
|               ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
 | |
|               ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
 | |
|               ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
 | |
|               ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
 | |
|               ivalueApply(
 | |
|                   fn,
 | |
|                   obj.getAttr("_stride_per_key_per_rank"),
 | |
|                   spec.children(4));
 | |
|               // For _inverse_indices, we need to apply to the tensor part
 | |
|               // (second element of tuple)
 | |
|               auto inverse_indices = obj.getAttr("_inverse_indices");
 | |
|               if (!inverse_indices.isNone()) {
 | |
|                 auto tuple = inverse_indices.toTuple();
 | |
|                 ivalueApply(fn, tuple->elements()[1], spec.children(5));
 | |
|               } else {
 | |
|                 // Handle None case
 | |
|                 ivalueApply(fn, c10::IValue(), spec.children(5));
 | |
|               }
 | |
|             },
 | |
|             [](std::string_view context) {
 | |
|               // Context contains the keys list as JSON
 | |
|               return nlohmann::json::parse(context);
 | |
|             }});
 | |
|   }
 | |
|   bool hasNodeDef(std::string_view typeName) const {
 | |
|     return registry_.find(std::string{typeName}) != registry_.end();
 | |
|   }
 | |
|   const NodeDef& getNodeDef(std::string_view typeName) const {
 | |
|     return registry_.at(std::string{typeName});
 | |
|   }
 | |
|   void registerNode(std::string_view typeName, NodeDef nodeDef) {
 | |
|     TORCH_CHECK(!hasNodeDef(typeName));
 | |
|     registry_.emplace(typeName, nodeDef);
 | |
|   }
 | |
| 
 | |
|  private:
 | |
|   std::unordered_map<std::string, NodeDef> registry_;
 | |
| };
 | |
| 
 | |
| c10::Synchronized<PytreeNodeRegistry>& getPytreeNodeRegistry() {
 | |
|   static auto* registry = new c10::Synchronized<PytreeNodeRegistry>();
 | |
|   return *registry;
 | |
| }
 | |
| 
 | |
| ITreeSpec makeITreeSpec(
 | |
|     const nlohmann::json& obj,
 | |
|     const std::vector<const Value*>& values,
 | |
|     int start) {
 | |
|   TORCH_CHECK(obj.is_object());
 | |
|   TORCH_CHECK(obj.find("type") != obj.end());
 | |
|   if (obj["type"].is_null()) {
 | |
|     TORCH_CHECK(obj["children_spec"].empty());
 | |
|     TORCH_CHECK(obj["context"].is_null());
 | |
| 
 | |
|     const Value* value = values[start];
 | |
|     if (value) {
 | |
|       bool isUsed = !value->users().empty();
 | |
|       return ITreeSpec(value, isUsed);
 | |
|     } else {
 | |
|       return ITreeSpec(value, false);
 | |
|     }
 | |
|   }
 | |
|   const auto& name = obj["type"].get<std::string>();
 | |
|   NodeDef nodeDefCache;
 | |
|   getPytreeNodeRegistry().withLock([&](auto& registry) {
 | |
|     TORCH_CHECK(registry.hasNodeDef(name), "Unknown pytree node type: ", name);
 | |
|     nodeDefCache = registry.getNodeDef(name);
 | |
|   });
 | |
|   auto context = nodeDefCache.contextLoadFn(obj["context"].get<std::string>());
 | |
|   const auto& childrenSpec = obj["children_spec"];
 | |
|   TORCH_CHECK(childrenSpec.is_array());
 | |
|   std::vector<ITreeSpec> children;
 | |
|   int offset = 0;
 | |
|   for (const auto& child : childrenSpec) {
 | |
|     children.push_back(makeITreeSpec(child, values, start + offset));
 | |
|     // NOLINTNEXTLINE(*-narrowing-conversions)
 | |
|     offset += children.back().numIValues();
 | |
|   }
 | |
| 
 | |
|   return ITreeSpec(name, context, std::move(children), nodeDefCache);
 | |
| }
 | |
| 
 | |
| } // namespace
 | |
| 
 | |
| void registerPytreeNode(std::string_view typeName, NodeDef nodeDef) {
 | |
|   getPytreeNodeRegistry().withLock([&](auto& registry) {
 | |
|     registry.registerNode(typeName, std::move(nodeDef));
 | |
|   });
 | |
| }
 | |
| 
 | |
| ITreeSpec itreeSpecLoads(
 | |
|     std::string_view json,
 | |
|     const std::vector<const Value*>& values) {
 | |
|   const auto obj = nlohmann::json::parse(json);
 | |
|   TORCH_CHECK(obj.is_array());
 | |
|   TORCH_CHECK(obj.size() == 2);
 | |
|   TORCH_CHECK(obj[0].get<int64_t>() == kDefaultTreeSpecSerializationProtocol);
 | |
|   auto result = makeITreeSpec(obj[1], values, 0);
 | |
| 
 | |
|   TORCH_CHECK(result.numIValues() == values.size());
 | |
|   return result;
 | |
| }
 | |
| 
 | |
| c10::IValue itreeUnflatten(
 | |
|     std::vector<c10::IValue> ivalues,
 | |
|     const ITreeSpec& spec) {
 | |
|   RECORD_USER_SCOPE("nativert::itreeUnflatten");
 | |
|   TORCH_CHECK(ivalues.size() == spec.numIValues());
 | |
|   if (spec.isIValue()) {
 | |
|     return std::move(ivalues[0]);
 | |
|   }
 | |
|   auto unflattenFn = spec.nodeDefCache().unflattenFn;
 | |
|   if (spec.allIValues()) {
 | |
|     return unflattenFn(std::move(ivalues), spec.context());
 | |
|   }
 | |
|   size_t start = 0;
 | |
|   std::vector<c10::IValue> childrenPytrees;
 | |
|   for (const auto& child : spec.children()) {
 | |
|     if (child.isIValue()) {
 | |
|       childrenPytrees.push_back(std::move(ivalues[start]));
 | |
|       start++;
 | |
|       continue;
 | |
|     }
 | |
|     size_t numIValues = child.numIValues();
 | |
|     std::vector<c10::IValue> slice(
 | |
|         // NOLINTNEXTLINE(*-narrowing-conversions)
 | |
|         std::make_move_iterator(ivalues.begin() + start),
 | |
|         // NOLINTNEXTLINE(*-narrowing-conversions)
 | |
|         std::make_move_iterator(ivalues.begin() + start + numIValues));
 | |
|     childrenPytrees.push_back(itreeUnflatten(std::move(slice), child));
 | |
|     start += numIValues;
 | |
|   }
 | |
|   return unflattenFn(std::move(childrenPytrees), spec.context());
 | |
| }
 | |
| 
 | |
| std::vector<c10::IValue> itreeFlatten(
 | |
|     const c10::IValue& nested,
 | |
|     const ITreeSpec& spec) {
 | |
|   std::vector<c10::IValue> ivalues;
 | |
|   ivalues.reserve(spec.numIValues());
 | |
|   itreeFlatten(nested, spec, ivalues);
 | |
|   return ivalues;
 | |
| }
 | |
| 
 | |
| std::vector<c10::IValue> itreeFlattenFromArgs(
 | |
|     const std::vector<c10::IValue>& args,
 | |
|     const std::unordered_map<std::string, c10::IValue>& kwargs,
 | |
|     const ITreeSpec& spec) {
 | |
|   RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs");
 | |
|   TORCH_CHECK(!spec.isIValue());
 | |
|   TORCH_CHECK(spec.children().size() == 2);
 | |
| 
 | |
|   std::vector<c10::IValue> ivalues;
 | |
|   ivalues.reserve(spec.numIValues());
 | |
|   const auto& specArgs = spec.children(0);
 | |
|   TORCH_CHECK(!specArgs.isIValue());
 | |
|   TORCH_CHECK(specArgs.children().size() == args.size());
 | |
|   for (size_t i = 0; i < args.size(); i++) {
 | |
|     itreeFlatten(args[i], specArgs.children(i), ivalues);
 | |
|   }
 | |
| 
 | |
|   const auto& specKwargs = spec.children(1);
 | |
|   TORCH_CHECK(!specKwargs.isIValue());
 | |
|   TORCH_CHECK(specKwargs.context().size() == kwargs.size());
 | |
|   for (size_t i = 0; i < specKwargs.context().size(); i++) {
 | |
|     itreeFlatten(
 | |
|         kwargs.at(specKwargs.context()[i].get_ref<const std::string&>()),
 | |
|         specKwargs.children(i),
 | |
|         ivalues);
 | |
|   }
 | |
|   return ivalues;
 | |
| }
 | |
| 
 | |
| void ivalueApplyFromArgs(
 | |
|     ITreeMapNoReturnFn fn,
 | |
|     const std::vector<c10::IValue>& args,
 | |
|     const std::unordered_map<std::string, c10::IValue>& kwargs,
 | |
|     const ITreeSpec& spec) {
 | |
|   RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs");
 | |
|   TORCH_CHECK(!spec.isIValue());
 | |
|   TORCH_CHECK(spec.children().size() == 2);
 | |
| 
 | |
|   const auto& specArgs = spec.children(0);
 | |
|   TORCH_CHECK(!specArgs.isIValue());
 | |
|   TORCH_CHECK(specArgs.children().size() == args.size());
 | |
|   for (size_t i = 0; i < args.size(); i++) {
 | |
|     ivalueApply(fn, args[i], specArgs.children(i));
 | |
|   }
 | |
| 
 | |
|   const auto& specKwargs = spec.children(1);
 | |
|   TORCH_CHECK(!specKwargs.isIValue());
 | |
| 
 | |
|   const auto& ctx = specKwargs.context();
 | |
|   TORCH_CHECK(ctx.size() == kwargs.size());
 | |
| 
 | |
|   for (size_t i = 0; i < ctx.size(); i++) {
 | |
|     ivalueApply(
 | |
|         fn,
 | |
|         kwargs.at(ctx[i].get_ref<const std::string&>()),
 | |
|         specKwargs.children(i));
 | |
|   }
 | |
| }
 | |
| 
 | |
| std::vector<at::Tensor> itreeFlattenToTensorList(
 | |
|     const c10::IValue& nested,
 | |
|     const ITreeSpec& spec) {
 | |
|   auto flats = itreeFlatten(nested, spec);
 | |
|   std::vector<at::Tensor> tensors;
 | |
|   tensors.reserve(flats.size());
 | |
|   for (const auto& flat : flats) {
 | |
|     tensors.push_back(flat.toTensor());
 | |
|   }
 | |
|   return tensors;
 | |
| }
 | |
| 
 | |
| c10::IValue itreeMap(
 | |
|     ITreeMapFn f,
 | |
|     const c10::IValue& nested,
 | |
|     const ITreeSpec& spec) {
 | |
|   const auto flats = itreeFlatten(nested, spec);
 | |
|   std::vector<c10::IValue> mapped;
 | |
|   mapped.reserve(flats.size());
 | |
|   for (const auto& flat : flats) {
 | |
|     mapped.push_back(f(flat));
 | |
|   }
 | |
|   return itreeUnflatten(std::move(mapped), spec);
 | |
| }
 | |
| 
 | |
| c10::IValue argsToIValue(
 | |
|     const std::vector<c10::IValue>& args,
 | |
|     const std::unordered_map<std::string, c10::IValue>& kwargs) {
 | |
|   c10::Dict<c10::IValue, c10::IValue> dict(
 | |
|       c10::StringType::get(), c10::AnyType::get());
 | |
|   for (const auto& [key, arg] : kwargs) {
 | |
|     dict.insert(key, arg);
 | |
|   }
 | |
|   return c10::ivalue::Tuple::create({c10::ivalue::Tuple::create(args), dict});
 | |
| }
 | |
| 
 | |
| std::
 | |
|     pair<std::vector<c10::IValue>, std::unordered_map<std::string, c10::IValue>>
 | |
|     itreeMapArgs(
 | |
|         ITreeMapFn f,
 | |
|         const std::vector<c10::IValue>& args,
 | |
|         const std::unordered_map<std::string, c10::IValue>& kwargs,
 | |
|         const ITreeSpec& spec) {
 | |
|   const auto val = argsToIValue(args, kwargs);
 | |
|   const auto mapVal = itreeMap(f, val, spec);
 | |
|   auto mapArgs =
 | |
|       mapVal.toTupleRef().elements()[0].toTupleRef().elements().vec();
 | |
|   std::unordered_map<std::string, c10::IValue> mapKwargs;
 | |
|   for (const auto& entry : mapVal.toTupleRef().elements()[1].toGenericDict()) {
 | |
|     mapKwargs.emplace(entry.key().toStringRef(), entry.value());
 | |
|   }
 | |
|   return {std::move(mapArgs), std::move(mapKwargs)};
 | |
| }
 | |
| 
 | |
| void ivalueApply(
 | |
|     ITreeMapNoReturnFn fn,
 | |
|     const c10::IValue& nested,
 | |
|     const ITreeSpec& spec) {
 | |
|   if (spec.isIValue()) {
 | |
|     if (spec.isUsed()) {
 | |
|       fn(nested, spec.value());
 | |
|     }
 | |
|     return;
 | |
|   }
 | |
|   auto ivalueApplyFn = spec.nodeDefCache().ivalueApplyFn;
 | |
|   ivalueApplyFn(fn, nested, spec);
 | |
| }
 | |
| 
 | |
| nlohmann::json defaultContextLoadFn(std::string_view context) {
 | |
|   return nlohmann::json::parse(context);
 | |
| }
 | |
| 
 | |
| ITreeSpec::ITreeSpec(
 | |
|     std::string_view uniformName,
 | |
|     nlohmann::json context,
 | |
|     std::vector<ITreeSpec> children,
 | |
|     NodeDef nodeDefCache)
 | |
|     : uniformName_(uniformName),
 | |
|       context_(std::move(context)),
 | |
|       children_(std::move(children)),
 | |
|       nodeDefCache_(nodeDefCache),
 | |
|       numIValues_(0),
 | |
|       value_(nullptr),
 | |
|       isUsed_(false) {
 | |
|   for (auto& child : children_) {
 | |
|     numIValues_ += child.numIValues();
 | |
|     allIValues_ &= child.isIValue();
 | |
|     isUsed_ |= child.isUsed();
 | |
|   }
 | |
| 
 | |
|   if (uniformName_ == "builtins.dict" ||
 | |
|       uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
 | |
|     for (const auto& keyObj : context_) {
 | |
|       contextKeys_.push_back(dynamicToIValue(keyObj));
 | |
|     }
 | |
|   }
 | |
| }
 | |
| 
 | |
| c10::TypePtr ITreeSpec::toAtenType() const {
 | |
|   if (isIValue()) {
 | |
|     return c10::AnyType::get();
 | |
|   } else if (uniformName_ == "builtins.tuple") {
 | |
|     std::vector<c10::TypePtr> childrenType;
 | |
|     childrenType.reserve(children_.size());
 | |
|     for (const auto& childrenSpec : children_) {
 | |
|       childrenType.emplace_back(childrenSpec.toAtenType());
 | |
|     }
 | |
|     return c10::TupleType::create(std::move(childrenType));
 | |
|   } else if (
 | |
|       uniformName_ == "builtins.list" ||
 | |
|       uniformName_ == "torch.fx.immutable_collections.immutable_list") {
 | |
|     if (children_.empty()) {
 | |
|       return c10::ListType::create(c10::AnyType::get());
 | |
|     } else {
 | |
|       return c10::ListType::create(children_[0].toAtenType());
 | |
|     }
 | |
|   } else if (
 | |
|       uniformName_ == "builtins.dict" ||
 | |
|       uniformName_ == "torch.fx.immutable_collections.immutable_dict") {
 | |
|     if (children_.empty()) {
 | |
|       return c10::DictType::create(c10::AnyType::get(), c10::AnyType::get());
 | |
|     } else {
 | |
|       return c10::DictType::create(
 | |
|           dynamicToIValue(context_[0]).type(), children_[0].toAtenType());
 | |
|     }
 | |
|   } else {
 | |
|     TORCH_CHECK(false, "Unsupported uniform name: ", uniformName());
 | |
|   }
 | |
| }
 | |
| 
 | |
| } // namespace torch::nativert::detail
 |