mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Differential Revision: D80656182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161114 Approved by: https://github.com/henryoier
628 lines
23 KiB
C++
628 lines
23 KiB
C++
#include <ATen/record_function.h>
|
|
#include <torch/nativert/detail/ITree.h>
|
|
|
|
#include <iterator>
|
|
#include <string_view>
|
|
|
|
#include <ATen/core/ivalue.h>
|
|
#include <c10/util/Synchronized.h>
|
|
#include <nlohmann/json.hpp>
|
|
|
|
namespace torch::nativert::detail {
|
|
|
|
namespace {
|
|
// Serialization protocol version this file understands. itreeSpecLoads()
// rejects any serialized TreeSpec whose leading element differs from it.
inline constexpr int kDefaultTreeSpecSerializationProtocol = 1;
|
|
|
|
// Converts a JSON scalar into an IValue. Only integers and strings are
// supported; any other JSON type fails with TORCH_CHECK.
c10::IValue dynamicToIValue(const nlohmann::json& obj) {
  if (obj.is_number_integer()) {
    return obj.get<int64_t>();
  }
  if (obj.is_string()) {
    return obj.get<std::string>();
  }
  TORCH_CHECK(false, "Unsupported dynamic type: ", obj);
}
|
|
|
|
void itreeFlatten(
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec,
|
|
std::vector<c10::IValue>& ivalues) {
|
|
if (spec.isIValue()) {
|
|
ivalues.push_back(nested);
|
|
return;
|
|
}
|
|
auto flattenFn = spec.nodeDefCache().flattenFn;
|
|
flattenFn(nested, spec, ivalues);
|
|
}
|
|
|
|
class PytreeNodeRegistry {
|
|
public:
|
|
PytreeNodeRegistry() {
|
|
// Add some law of physics here.
|
|
registerNode(
|
|
"builtins.tuple",
|
|
NodeDef{
|
|
[](const c10::IValue& nested,
|
|
const ITreeSpec& spec,
|
|
std::vector<c10::IValue>& ivalues) {
|
|
const auto& tuple = nested.toTupleRef().elements();
|
|
TORCH_CHECK(tuple.size() == spec.children().size());
|
|
for (size_t i = 0; i < tuple.size(); i++) {
|
|
itreeFlatten(tuple[i], spec.children(i), ivalues);
|
|
}
|
|
},
|
|
[](std::vector<c10::IValue> flats,
|
|
const nlohmann::json& obj) -> c10::IValue {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
|
|
return c10::ivalue::Tuple::create(std::move(flats));
|
|
},
|
|
[](ITreeMapNoReturnFn fn,
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec) {
|
|
const auto& tuple = nested.toTupleRef().elements();
|
|
TORCH_CHECK(tuple.size() == spec.children().size());
|
|
for (size_t i = 0; i < tuple.size(); i++) {
|
|
ivalueApply(fn, tuple[i], spec.children(i));
|
|
}
|
|
}});
|
|
const auto& tupleNodeDef = getNodeDef("builtins.tuple");
|
|
registerNode(
|
|
"collections.namedtuple",
|
|
NodeDef{
|
|
tupleNodeDef.flattenFn,
|
|
[](std::vector<c10::IValue> flats,
|
|
const nlohmann::json& obj) -> c10::IValue {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
|
|
return c10::ivalue::Tuple::create(std::move(flats));
|
|
},
|
|
tupleNodeDef.ivalueApplyFn,
|
|
[](std::string_view context) { return nlohmann::json{context}; }});
|
|
registerNode(
|
|
"builtins.list",
|
|
NodeDef{
|
|
[](const c10::IValue& nested,
|
|
const ITreeSpec& spec,
|
|
std::vector<c10::IValue>& ivalues) {
|
|
auto list = nested.toListRef();
|
|
for (size_t i = 0; i < list.size(); i++) {
|
|
itreeFlatten(list[i], spec.children(i), ivalues);
|
|
}
|
|
},
|
|
[](std::vector<c10::IValue> flats,
|
|
const nlohmann::json& obj) -> c10::IValue {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
|
|
c10::List<c10::IValue> list(c10::AnyType::get());
|
|
list.reserve(flats.size());
|
|
for (auto& flat : flats) {
|
|
list.push_back(std::move(flat));
|
|
}
|
|
return list;
|
|
},
|
|
[](ITreeMapNoReturnFn fn,
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec) {
|
|
auto list = nested.toListRef();
|
|
for (size_t i = 0; i < list.size(); i++) {
|
|
ivalueApply(fn, list[i], spec.children(i));
|
|
}
|
|
}});
|
|
registerNode(
|
|
"torch.fx.immutable_collections.immutable_list",
|
|
getNodeDef("builtins.list"));
|
|
registerNode(
|
|
"builtins.dict",
|
|
NodeDef{
|
|
[](const c10::IValue& nested,
|
|
const ITreeSpec& spec,
|
|
std::vector<c10::IValue>& ivalues) {
|
|
auto dict = nested.toGenericDict();
|
|
const auto& contextKeys = spec.contextKeys();
|
|
// allow the dict size less than the spec, missing key will be
|
|
// filled with empty tensor
|
|
TORCH_CHECK(dict.size() <= contextKeys.size());
|
|
size_t i = 0;
|
|
for (const auto& key : contextKeys) {
|
|
auto it = dict.find(key);
|
|
|
|
if (it != dict.end()) {
|
|
itreeFlatten(it->value(), spec.children(i), ivalues);
|
|
} else {
|
|
// when we have a dict with missing keys, we fill the missing
|
|
// ivalues with an empty tensor which is required for
|
|
// validation
|
|
for (size_t j = 0; j < spec.children(i).numIValues(); ++j) {
|
|
at::Tensor empty_tensor;
|
|
ivalues.emplace_back(std::move(empty_tensor));
|
|
}
|
|
}
|
|
i++;
|
|
}
|
|
},
|
|
[](std::vector<c10::IValue> flats,
|
|
const nlohmann::json& obj) -> c10::IValue {
|
|
c10::Dict<c10::IValue, c10::IValue> dict(
|
|
c10::AnyType::get(), c10::AnyType::get());
|
|
TORCH_CHECK(obj.is_array());
|
|
TORCH_CHECK(obj.size() == flats.size());
|
|
dict.reserve(flats.size());
|
|
for (size_t i = 0; i < flats.size(); i++) {
|
|
dict.insert(dynamicToIValue(obj[i]), std::move(flats[i]));
|
|
}
|
|
return dict;
|
|
},
|
|
[](ITreeMapNoReturnFn fn,
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec) {
|
|
auto dict = nested.toGenericDict();
|
|
const auto& contextKeys = spec.contextKeys();
|
|
|
|
size_t i = 0;
|
|
for (const auto& key : contextKeys) {
|
|
if (spec.children(i).isUsed()) {
|
|
auto it = dict.find(key);
|
|
if (it != dict.end()) {
|
|
ivalueApply(fn, it->value(), spec.children(i));
|
|
} else {
|
|
TORCH_CHECK(false, "input arg is missing key ", key);
|
|
}
|
|
}
|
|
i++;
|
|
}
|
|
}});
|
|
registerNode(
|
|
"torch.fx.immutable_collections.immutable_dict",
|
|
getNodeDef("builtins.dict"));
|
|
// Register JaggedTensor pytree node
|
|
registerNode(
|
|
"torchrec.sparse.jagged_tensor.JaggedTensor",
|
|
NodeDef{
|
|
[](const c10::IValue& nested,
|
|
const ITreeSpec& spec,
|
|
std::vector<c10::IValue>& ivalues) {
|
|
// JaggedTensor has 4 fields: _values, _weights, _lengths,
|
|
// _offsets All fields are optional torch.Tensor except _values
|
|
TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
|
|
const auto& obj = nested.toObjectRef();
|
|
|
|
// Extract the tensor fields in order: _values, _weights,
|
|
// _lengths, _offsets
|
|
TORCH_CHECK(
|
|
spec.children().size() == 4,
|
|
"JaggedTensor should have 4 children");
|
|
|
|
// Flatten each tensor field
|
|
itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
|
|
itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
|
|
itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
|
|
itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
|
|
},
|
|
[](std::vector<c10::IValue> flats,
|
|
const nlohmann::json& obj) -> c10::IValue {
|
|
// Reconstruct JaggedTensor from flattened tensors
|
|
// This is a simplified reconstruction - in practice would need
|
|
// to call the actual JaggedTensor constructor
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
|
|
TORCH_CHECK(
|
|
flats.size() == 4, "JaggedTensor expects 4 tensor fields");
|
|
|
|
// Return a generic tuple for now - actual implementation would
|
|
// need to construct the JaggedTensor custom class
|
|
return c10::ivalue::Tuple::create(std::move(flats));
|
|
},
|
|
[](ITreeMapNoReturnFn fn,
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec) {
|
|
TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
|
|
const auto& obj = nested.toObjectRef();
|
|
|
|
TORCH_CHECK(
|
|
spec.children().size() == 4,
|
|
"JaggedTensor should have 4 children");
|
|
|
|
// Apply function to each tensor field
|
|
ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
|
|
ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
|
|
ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
|
|
ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
|
|
}});
|
|
|
|
// Register KeyedJaggedTensor pytree node
|
|
registerNode(
|
|
"torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
|
|
NodeDef{
|
|
[](const c10::IValue& nested,
|
|
const ITreeSpec& spec,
|
|
std::vector<c10::IValue>& ivalues) {
|
|
// KeyedJaggedTensor has 6 tensor fields plus keys context
|
|
// Fields: _values, _weights, _lengths, _offsets,
|
|
// _stride_per_key_per_rank, _inverse_indices tensor
|
|
TORCH_CHECK(
|
|
nested.isObject(), "Expected KeyedJaggedTensor object");
|
|
const auto& obj = nested.toObjectRef();
|
|
|
|
// Extract the tensor fields in order
|
|
TORCH_CHECK(
|
|
spec.children().size() == 6,
|
|
"KeyedJaggedTensor should have 6 children");
|
|
|
|
// Flatten each tensor field
|
|
itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
|
|
itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
|
|
itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
|
|
itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
|
|
itreeFlatten(
|
|
obj.getAttr("_stride_per_key_per_rank"),
|
|
spec.children(4),
|
|
ivalues);
|
|
// For _inverse_indices, we need to extract the tensor part
|
|
// (second element of tuple)
|
|
auto inverse_indices = obj.getAttr("_inverse_indices");
|
|
if (!inverse_indices.isNone()) {
|
|
auto tuple = inverse_indices.toTuple();
|
|
itreeFlatten(tuple->elements()[1], spec.children(5), ivalues);
|
|
} else {
|
|
// Handle None case by adding a null tensor
|
|
itreeFlatten(c10::IValue(), spec.children(5), ivalues);
|
|
}
|
|
},
|
|
[](std::vector<c10::IValue> flats,
|
|
const nlohmann::json& obj) -> c10::IValue {
|
|
// Reconstruct KeyedJaggedTensor from flattened tensors and keys
|
|
// context
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
|
|
TORCH_CHECK(
|
|
flats.size() == 6,
|
|
"KeyedJaggedTensor expects 6 tensor fields");
|
|
|
|
// The context should contain the keys list
|
|
// Return a generic tuple for now - actual implementation would
|
|
// need to construct the KeyedJaggedTensor custom class
|
|
return c10::ivalue::Tuple::create(std::move(flats));
|
|
},
|
|
[](ITreeMapNoReturnFn fn,
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec) {
|
|
TORCH_CHECK(
|
|
nested.isObject(), "Expected KeyedJaggedTensor object");
|
|
const auto& obj = nested.toObjectRef();
|
|
|
|
TORCH_CHECK(
|
|
spec.children().size() == 6,
|
|
"KeyedJaggedTensor should have 6 children");
|
|
|
|
// Apply function to each tensor field
|
|
ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
|
|
ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
|
|
ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
|
|
ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
|
|
ivalueApply(
|
|
fn,
|
|
obj.getAttr("_stride_per_key_per_rank"),
|
|
spec.children(4));
|
|
// For _inverse_indices, we need to apply to the tensor part
|
|
// (second element of tuple)
|
|
auto inverse_indices = obj.getAttr("_inverse_indices");
|
|
if (!inverse_indices.isNone()) {
|
|
auto tuple = inverse_indices.toTuple();
|
|
ivalueApply(fn, tuple->elements()[1], spec.children(5));
|
|
} else {
|
|
// Handle None case
|
|
ivalueApply(fn, c10::IValue(), spec.children(5));
|
|
}
|
|
},
|
|
[](std::string_view context) {
|
|
// Context contains the keys list as JSON
|
|
return nlohmann::json::parse(context);
|
|
}});
|
|
}
|
|
bool hasNodeDef(std::string_view typeName) const {
|
|
return registry_.find(std::string{typeName}) != registry_.end();
|
|
}
|
|
const NodeDef& getNodeDef(std::string_view typeName) const {
|
|
return registry_.at(std::string{typeName});
|
|
}
|
|
void registerNode(std::string_view typeName, NodeDef nodeDef) {
|
|
TORCH_CHECK(!hasNodeDef(typeName));
|
|
registry_.emplace(typeName, nodeDef);
|
|
}
|
|
|
|
private:
|
|
std::unordered_map<std::string, NodeDef> registry_;
|
|
};
|
|
|
|
// Returns the process-wide registry wrapped in c10::Synchronized. The
// instance is created on first use with `new` and is not deleted here, so no
// destructor runs for it during static destruction.
c10::Synchronized<PytreeNodeRegistry>& getPytreeNodeRegistry() {
  static c10::Synchronized<PytreeNodeRegistry>* const instance =
      new c10::Synchronized<PytreeNodeRegistry>();
  return *instance;
}
|
|
|
|
// Recursively builds the ITreeSpec described by one serialized TreeSpec node.
//
//   obj    - JSON object with the fields "type", "context", "children_spec".
//   values - graph values for the leaves of the whole tree, in flatten order.
//   start  - index into `values` of the first leaf covered by this node.
//
// A node whose "type" is null is a leaf bound to values[start]; it is marked
// used only when that value is non-null and has users. Any other type name
// must be present in the pytree node registry. Fails with TORCH_CHECK when a
// field is missing, the type is unknown, or the spec describes a leaf that
// `values` does not provide.
ITreeSpec makeITreeSpec(
    const nlohmann::json& obj,
    const std::vector<const Value*>& values,
    int start) {
  TORCH_CHECK(obj.is_object());

  // Checked field access: the const operator[] of nlohmann::json must not be
  // used with a key that is absent, so look the key up explicitly.
  const auto field = [&obj](const char* key) -> const nlohmann::json& {
    const auto it = obj.find(key);
    TORCH_CHECK(it != obj.end(), "TreeSpec node is missing field: ", key);
    return *it;
  };

  const auto& type = field("type");
  if (type.is_null()) {
    TORCH_CHECK(field("children_spec").empty());
    TORCH_CHECK(field("context").is_null());
    // A malformed spec may describe more leaves than there are values; catch
    // that here rather than reading past the end of `values`.
    TORCH_CHECK(
        start >= 0 && static_cast<size_t>(start) < values.size(),
        "TreeSpec leaf index ",
        start,
        " is out of range for ",
        values.size(),
        " values");

    const Value* value = values[static_cast<size_t>(start)];
    const bool isUsed = value != nullptr && !value->users().empty();
    return ITreeSpec(value, isUsed);
  }

  const auto name = type.get<std::string>();
  NodeDef nodeDefCache;
  // Copy the NodeDef out while holding the registry lock so the spec can use
  // it afterwards without locking.
  getPytreeNodeRegistry().withLock([&](auto& registry) {
    TORCH_CHECK(registry.hasNodeDef(name), "Unknown pytree node type: ", name);
    nodeDefCache = registry.getNodeDef(name);
  });
  auto context =
      nodeDefCache.contextLoadFn(field("context").get<std::string>());

  const auto& childrenSpec = field("children_spec");
  TORCH_CHECK(childrenSpec.is_array());
  std::vector<ITreeSpec> children;
  children.reserve(childrenSpec.size());
  // Each child consumes as many consecutive entries of `values` as it has
  // leaves; `offset` tracks how many have been consumed so far.
  int offset = 0;
  for (const auto& child : childrenSpec) {
    children.push_back(makeITreeSpec(child, values, start + offset));
    // NOLINTNEXTLINE(*-narrowing-conversions)
    offset += children.back().numIValues();
  }

  return ITreeSpec(
      name, std::move(context), std::move(children), nodeDefCache);
}
|
|
|
|
} // namespace
|
|
|
|
// Adds `nodeDef` to the shared pytree node registry under `typeName`, taking
// the registry lock for the duration of the insertion.
void registerPytreeNode(std::string_view typeName, NodeDef nodeDef) {
  auto& sharedRegistry = getPytreeNodeRegistry();
  sharedRegistry.withLock([&typeName, &nodeDef](auto& nodes) {
    nodes.registerNode(typeName, std::move(nodeDef));
  });
}
|
|
|
|
// Parses a serialized TreeSpec and binds its leaves to `values`.
//
// `json` must be a two-element JSON array: the serialization protocol version
// followed by the root spec node. Fails with TORCH_CHECK when the envelope is
// malformed, the protocol version is not the supported one, or the number of
// leaves in the spec differs from values.size().
ITreeSpec itreeSpecLoads(
    std::string_view json,
    const std::vector<const Value*>& values) {
  const auto obj = nlohmann::json::parse(json);
  TORCH_CHECK(obj.is_array(), "Serialized TreeSpec must be a JSON array");
  TORCH_CHECK(
      obj.size() == 2,
      "Serialized TreeSpec must have exactly 2 elements [protocol, spec], got ",
      obj.size());

  const auto protocol = obj[0].get<int64_t>();
  TORCH_CHECK(
      protocol == kDefaultTreeSpecSerializationProtocol,
      "Unsupported TreeSpec serialization protocol ",
      protocol,
      ", expected ",
      kDefaultTreeSpecSerializationProtocol);

  auto result = makeITreeSpec(obj[1], values, 0);

  TORCH_CHECK(
      result.numIValues() == values.size(),
      "TreeSpec describes ",
      result.numIValues(),
      " leaves but ",
      values.size(),
      " values were provided");
  return result;
}
|
|
|
|
// Rebuilds the nested value described by `spec` from its flattened leaves.
// `ivalues` must hold exactly spec.numIValues() entries. A leaf spec returns
// the single entry; a container whose children are all leaves hands the whole
// vector to its unflatten callback; otherwise each child is rebuilt from its
// own consecutive run of leaves first.
c10::IValue itreeUnflatten(
    std::vector<c10::IValue> ivalues,
    const ITreeSpec& spec) {
  RECORD_USER_SCOPE("nativert::itreeUnflatten");
  TORCH_CHECK(ivalues.size() == spec.numIValues());
  if (spec.isIValue()) {
    return std::move(ivalues[0]);
  }

  const auto unflattenFn = spec.nodeDefCache().unflattenFn;
  if (spec.allIValues()) {
    return unflattenFn(std::move(ivalues), spec.context());
  }

  using Distance = std::vector<c10::IValue>::difference_type;
  std::vector<c10::IValue> rebuilt;
  // `cursor` points at the first leaf not yet handed to a child.
  auto cursor = ivalues.begin();
  for (const auto& child : spec.children()) {
    if (child.isIValue()) {
      rebuilt.push_back(std::move(*cursor));
      ++cursor;
    } else {
      const auto runEnd = cursor + static_cast<Distance>(child.numIValues());
      // Move the child's leaves into their own vector and recurse.
      rebuilt.push_back(itreeUnflatten(
          std::vector<c10::IValue>(
              std::make_move_iterator(cursor),
              std::make_move_iterator(runEnd)),
          child));
      cursor = runEnd;
    }
  }
  return unflattenFn(std::move(rebuilt), spec.context());
}
|
|
|
|
std::vector<c10::IValue> itreeFlatten(
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec) {
|
|
std::vector<c10::IValue> ivalues;
|
|
ivalues.reserve(spec.numIValues());
|
|
itreeFlatten(nested, spec, ivalues);
|
|
return ivalues;
|
|
}
|
|
|
|
std::vector<c10::IValue> itreeFlattenFromArgs(
|
|
const std::vector<c10::IValue>& args,
|
|
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
|
const ITreeSpec& spec) {
|
|
RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs");
|
|
TORCH_CHECK(!spec.isIValue());
|
|
TORCH_CHECK(spec.children().size() == 2);
|
|
|
|
std::vector<c10::IValue> ivalues;
|
|
ivalues.reserve(spec.numIValues());
|
|
const auto& specArgs = spec.children(0);
|
|
TORCH_CHECK(!specArgs.isIValue());
|
|
TORCH_CHECK(specArgs.children().size() == args.size());
|
|
for (size_t i = 0; i < args.size(); i++) {
|
|
itreeFlatten(args[i], specArgs.children(i), ivalues);
|
|
}
|
|
|
|
const auto& specKwargs = spec.children(1);
|
|
TORCH_CHECK(!specKwargs.isIValue());
|
|
TORCH_CHECK(specKwargs.context().size() == kwargs.size());
|
|
for (size_t i = 0; i < specKwargs.context().size(); i++) {
|
|
itreeFlatten(
|
|
kwargs.at(specKwargs.context()[i].get_ref<const std::string&>()),
|
|
specKwargs.children(i),
|
|
ivalues);
|
|
}
|
|
return ivalues;
|
|
}
|
|
|
|
// Visits the leaves of a call's positional and keyword arguments with `fn`,
// following `spec`. `spec` must be a two-child container: child 0 describes
// the positional arguments and child 1 the keyword arguments, whose context
// lists the keyword names. A keyword name missing from `kwargs` makes
// unordered_map::at throw.
void ivalueApplyFromArgs(
    ITreeMapNoReturnFn fn,
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs,
    const ITreeSpec& spec) {
  RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs");
  TORCH_CHECK(!spec.isIValue());
  TORCH_CHECK(spec.children().size() == 2);

  // Positional arguments.
  const auto& specArgs = spec.children(0);
  TORCH_CHECK(!specArgs.isIValue());
  TORCH_CHECK(specArgs.children().size() == args.size());
  size_t position = 0;
  for (const auto& arg : args) {
    ivalueApply(fn, arg, specArgs.children(position));
    ++position;
  }

  // Keyword arguments, visited in the order recorded in the spec context.
  const auto& specKwargs = spec.children(1);
  TORCH_CHECK(!specKwargs.isIValue());

  const auto& ctx = specKwargs.context();
  TORCH_CHECK(ctx.size() == kwargs.size());

  const size_t numKwargs = ctx.size();
  for (size_t k = 0; k != numKwargs; ++k) {
    const auto& name = ctx[k].get_ref<const std::string&>();
    ivalueApply(fn, kwargs.at(name), specKwargs.children(k));
  }
}
|
|
|
|
std::vector<at::Tensor> itreeFlattenToTensorList(
|
|
const c10::IValue& nested,
|
|
const ITreeSpec& spec) {
|
|
auto flats = itreeFlatten(nested, spec);
|
|
std::vector<at::Tensor> tensors;
|
|
tensors.reserve(flats.size());
|
|
for (const auto& flat : flats) {
|
|
tensors.push_back(flat.toTensor());
|
|
}
|
|
return tensors;
|
|
}
|
|
|
|
// Applies `f` to every leaf of `nested` and returns a value with the same
// structure (as described by `spec`) holding the results.
c10::IValue itreeMap(
    ITreeMapFn f,
    const c10::IValue& nested,
    const ITreeSpec& spec) {
  const auto leaves = itreeFlatten(nested, spec);
  const size_t numLeaves = leaves.size();

  std::vector<c10::IValue> results;
  results.reserve(numLeaves);
  for (size_t idx = 0; idx != numLeaves; ++idx) {
    results.push_back(f(leaves[idx]));
  }
  return itreeUnflatten(std::move(results), spec);
}
|
|
|
|
// Packs positional and keyword arguments into a single IValue shaped as
// (tuple(args), dict(kwargs)); the dict is typed string -> Any.
c10::IValue argsToIValue(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  for (const auto& kwarg : kwargs) {
    dict.insert(kwarg.first, kwarg.second);
  }
  return c10::ivalue::Tuple::create({c10::ivalue::Tuple::create(args), dict});
}
|
|
|
|
std::
|
|
pair<std::vector<c10::IValue>, std::unordered_map<std::string, c10::IValue>>
|
|
itreeMapArgs(
|
|
ITreeMapFn f,
|
|
const std::vector<c10::IValue>& args,
|
|
const std::unordered_map<std::string, c10::IValue>& kwargs,
|
|
const ITreeSpec& spec) {
|
|
const auto val = argsToIValue(args, kwargs);
|
|
const auto mapVal = itreeMap(f, val, spec);
|
|
auto mapArgs =
|
|
mapVal.toTupleRef().elements()[0].toTupleRef().elements().vec();
|
|
std::unordered_map<std::string, c10::IValue> mapKwargs;
|
|
for (const auto& entry : mapVal.toTupleRef().elements()[1].toGenericDict()) {
|
|
mapKwargs.emplace(entry.key().toStringRef(), entry.value());
|
|
}
|
|
return {std::move(mapArgs), std::move(mapKwargs)};
|
|
}
|
|
|
|
// Visits the leaves of `nested` following `spec`. For a leaf spec, `fn` is
// called with the value and the spec's bound graph value, but only when the
// spec is marked used. Container specs delegate to their apply callback.
void ivalueApply(
    ITreeMapNoReturnFn fn,
    const c10::IValue& nested,
    const ITreeSpec& spec) {
  if (!spec.isIValue()) {
    spec.nodeDefCache().ivalueApplyFn(fn, nested, spec);
    return;
  }
  if (spec.isUsed()) {
    fn(nested, spec.value());
  }
}
|
|
|
|
// Default context loader: parses the serialized context string as JSON.
nlohmann::json defaultContextLoadFn(std::string_view context) {
  auto parsed = nlohmann::json::parse(context);
  return parsed;
}
|
|
|
|
// Builds a container spec node. Aggregates over the children: the total leaf
// count, whether every child is a leaf, and whether any child is used. For
// dict-like nodes the JSON context keys are also converted to IValues and
// stored in contextKeys_.
ITreeSpec::ITreeSpec(
    std::string_view uniformName,
    nlohmann::json context,
    std::vector<ITreeSpec> children,
    NodeDef nodeDefCache)
    : uniformName_(uniformName),
      context_(std::move(context)),
      children_(std::move(children)),
      nodeDefCache_(nodeDefCache),
      numIValues_(0),
      value_(nullptr),
      isUsed_(false) {
  for (const auto& child : children_) {
    numIValues_ += child.numIValues();
    if (!child.isIValue()) {
      allIValues_ = false;
    }
    if (child.isUsed()) {
      isUsed_ = true;
    }
  }

  const bool isDictLike = uniformName_ == "builtins.dict" ||
      uniformName_ == "torch.fx.immutable_collections.immutable_dict";
  if (!isDictLike) {
    return;
  }
  for (const auto& keyObj : context_) {
    contextKeys_.push_back(dynamicToIValue(keyObj));
  }
}
|
|
|
|
// Maps this spec to an ATen type. Leaves map to Any; tuples map to a tuple of
// their children's types; lists and dicts take their element/value type from
// the first child (and the dict key type from the first context key), or Any
// when empty. Any other container name fails with TORCH_CHECK.
c10::TypePtr ITreeSpec::toAtenType() const {
  if (isIValue()) {
    return c10::AnyType::get();
  }

  if (uniformName_ == "builtins.tuple") {
    std::vector<c10::TypePtr> elementTypes;
    elementTypes.reserve(children_.size());
    for (const auto& child : children_) {
      elementTypes.emplace_back(child.toAtenType());
    }
    return c10::TupleType::create(std::move(elementTypes));
  }

  const bool isListLike = uniformName_ == "builtins.list" ||
      uniformName_ == "torch.fx.immutable_collections.immutable_list";
  if (isListLike) {
    if (!children_.empty()) {
      return c10::ListType::create(children_[0].toAtenType());
    }
    return c10::ListType::create(c10::AnyType::get());
  }

  const bool isDictLike = uniformName_ == "builtins.dict" ||
      uniformName_ == "torch.fx.immutable_collections.immutable_dict";
  if (isDictLike) {
    if (!children_.empty()) {
      return c10::DictType::create(
          dynamicToIValue(context_[0]).type(), children_[0].toAtenType());
    }
    return c10::DictType::create(c10::AnyType::get(), c10::AnyType::get());
  }

  TORCH_CHECK(false, "Unsupported uniform name: ", uniformName());
}
|
|
|
|
} // namespace torch::nativert::detail
|