#include <c10/util/irange.h>

#include <torch/csrc/jit/runtime/argument_spec.h>

namespace torch {
namespace jit {
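
// scan() recursively walks an input type and appends instructions describing
// how to traverse the corresponding IValue at runtime: Tensors and Optionals
// become SPECIALIZE_* instructions, Tuples and class Objects open
// ENTER_*/LEAVE regions, and anything else is recorded as SKIP.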
void ArgumentSpecCreator::scan(
    const TypePtr& typ,
    size_t depth,
    const WrittenSlots& written_slots) {
  auto finishAggregate = [&](size_t pos) {
    // It is possible that, after all the work we did to scan this aggregate,
    // we found no tensors or optionals to specialize. In this case, just
    // generate a skip for the whole aggregate.
    bool any_spec = std::any_of(
        instructions_.begin() + pos, instructions_.end(), [](Inst i) {
          return i == SPECIALIZE_TENSOR || i == SPECIALIZE_OPTIONAL ||
              i == SPECIALIZE_OPTIONAL_TENSOR;
        });
    if (!any_spec) {
      instructions_[pos] = SKIP;
      instructions_.resize(pos + 1);
    } else {
      instructions_.emplace_back(LEAVE);
    }
  };
  // The simple VM that scans instructions_ has a limited stack depth;
  // this prevents going deeper than that.
  if (depth >= ARG_SPEC_DEPTH_LIMIT) {
    instructions_.emplace_back(SKIP);
  }
  if (typ->isSubtypeOf(*TensorType::get())) {
    num_tensors_++;
    instructions_.emplace_back(SPECIALIZE_TENSOR);
  } else if (typ->isSubtypeOf(*OptionalType::ofTensor())) {
    num_tensors_++;
    num_optionals_++;
    instructions_.emplace_back(SPECIALIZE_OPTIONAL_TENSOR);
  } else if (typ->kind() == TypeKind::OptionalType) {
    // note that Optional[Tuple] or Optional[Class] will just register
    // as optional (previously they didn't at all, so it's not a regression).
    num_optionals_++;
    instructions_.emplace_back(SPECIALIZE_OPTIONAL);
  } else if (auto tup = typ->cast<TupleType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_TUPLE);
    for (const auto& elem : tup->containedTypes()) {
      scan(elem, depth + 1, written_slots);
    }
    finishAggregate(pos);
  } else if (auto cls = typ->cast<ClassType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_OBJECT);
    for (size_t i = 0; i < cls->numAttributes(); ++i) {
      auto key =
          cls->name()->qualifiedName() + cls->getAttributes().at(i).getName();
      // it is only safe to specialize a slot if no one has written to it
      if (!written_slots.count(key)) {
        scan(cls->containedTypes().at(i), depth + 1, written_slots);
      } else {
        instructions_.emplace_back(SKIP);
      }
    }
    finishAggregate(pos);
  } else {
    instructions_.emplace_back(SKIP);
  }
};
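
// As an illustration (not produced by this file itself): an input list typed
// (Tensor, Tuple[Tensor, int], float) would compile roughly to
//   SpecializeTensor Tuple[ SpecializeTensor Skip ] Skip
// which is the textual form printed by dump() below.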

// This is a coarse-grained guarantee that the slots of a class will not be
// modified by the function. It works fine for things that used to be
// read-only modules, but will be overly conservative when some classes are
// written to. Doing alias analysis and looking for writes to the class would
// be more accurate.
static void scanWrittenSlots(
    Block* block,
    ArgumentSpecCreator::WrittenSlots& written_slots) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::SetAttr) {
      if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
        written_slots.insert(cls->name()->qualifiedName() + n->s(attr::name));
      }
    }
    for (Block* subblock : n->blocks()) {
      scanWrittenSlots(subblock, written_slots);
    }
    if (n->hasAttribute(attr::Subgraph)) {
      scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
    }
  }
}
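
// The constructor compiles the traversal program once per graph: it first
// collects which class slots the graph may write to, then scans every graph
// input type so that create() and specializeTypes() can replay the same
// instruction sequence.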
ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
    : num_inputs_(graph.inputs().size()) {
  WrittenSlots written_slots;
  scanWrittenSlots(graph.block(), written_slots);
  for (Value* input : graph.inputs()) {
    scan(input->type(), 0, written_slots);
  }
}

void ArgumentSpecCreator::dump() const {
  for (Inst inst : instructions_) {
    switch (inst) {
      case LEAVE:
        std::cout << "] ";
        break;
      case ENTER_TUPLE:
        std::cout << "Tuple[";
        break;
      case ENTER_OBJECT:
        std::cout << "Object[";
        break;
      case SKIP:
        std::cout << "Skip ";
        break;
      case SPECIALIZE_TENSOR:
        std::cout << "SpecializeTensor ";
        break;
      case SPECIALIZE_OPTIONAL_TENSOR:
        std::cout << "SpecializeOptionalTensor ";
        break;
      case SPECIALIZE_OPTIONAL:
        std::cout << "SpecializeOptional ";
        break;
    }
  }
  std::cout << "\n";
}
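
// create() replays the instruction program over the actual IValues at the end
// of the interpreter stack, recording tensor metadata (via addTensor) and
// optional presence (via addOptional) in the returned ArgumentSpec.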
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
    const {
  ArgumentSpec spec(num_tensors_, num_optionals_);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  const IValue* stack[ARG_SPEC_DEPTH_LIMIT]; // The stack of IValue lists
  // The stack gets initialized with the input list
  stack[0] = last(input, num_inputs_).begin();
  size_t stack_top = 0; // offset to the top of the stack
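  // Each stack entry is a cursor into a flat sequence of IValues (the input
  // list, a tuple's elements, or an object's slots); ENTER_* pushes a new
  // cursor and LEAVE pops it, mirroring the nesting recorded by scan().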
  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_OPTIONAL_TENSOR: {
        // consume a tensor optional and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        auto& arg = *stack[stack_top]++;
        spec.addOptional(arg);
        if (!arg.isNone()) {
          spec.addTensor(arg, with_grad);
        }
      } break;
      case SPECIALIZE_TENSOR:
        // consume a tensor and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        spec.addTensor(*stack[stack_top]++, with_grad);
        break;
      case SPECIALIZE_OPTIONAL:
        // consume a non-tensor optional and add to the argspec
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        spec.addOptional(*stack[stack_top]++);
        break;
      case ENTER_TUPLE: {
        // consume tuple
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isTuple(), "Expected Tuple but got ", iv->tagKind());
        auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
        auto tup_ptr = &p->elements()[0];
        // push list of tuple elements to the stack
        stack[++stack_top] = tup_ptr;
      } break;
      case ENTER_OBJECT: {
        // consume object
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isObject(), "Expected Object but got ", iv->tagKind());
        auto obj_ptr = &iv->toObjectRef().slots()[0];
        // push list of object elements to the stack
        stack[++stack_top] = obj_ptr;
      } break;
      case SKIP:
        // consume and skip an element
        // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
        stack[stack_top]++;
        break;
      case LEAVE:
        --stack_top;
        break;
    }
  }
  return spec;
}

// For every input of a given graph, returns the most detailed type that can
// be inferred for it based on this ArgumentSpec.
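// specializeTypes() walks the same instruction program with two parallel
// stacks: input_stack tracks the original types being traversed, while
// result_stack accumulates the refined types; ENTER_* pushes a deferred
// "aggregate creator" that rebuilds the Tuple/Class type from its refined
// elements on LEAVE.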
void ArgumentSpecCreator::specializeTypes(
    Graph& graph,
    const ArgumentSpec& spec) const {
  auto input_types =
      fmap(graph.inputs(), [](Value* input) { return input->type(); });
  std::vector<std::vector<TypePtr>> result_stack;
  result_stack.emplace_back();
  std::vector<const TypePtr*> input_stack = {input_types.data()};
  std::vector<std::function<TypePtr()>> aggregate_creators;

  size_t tensor_arg_spec_offset =
      0; // number of specialized tensors seen so far
  size_t optional_arg_spec_offset =
      0; // number of specialized optionals seen so far

  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_OPTIONAL_TENSOR: {
        auto& input_type = *input_stack.back()++;
        auto is_present = spec.isPresent(optional_arg_spec_offset++);
        if (!is_present) {
          result_stack.back().emplace_back(input_type);
          break;
        }
        auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
        AT_ASSERT(arg.defined());
        result_stack.back().emplace_back(arg.toType());
      } break;
      case SPECIALIZE_TENSOR: {
        input_stack.back()++;
        auto& arg = spec.tensorAt(tensor_arg_spec_offset++);
        if (!arg.defined()) {
          result_stack.back().emplace_back(TensorType::get()->withUndefined());
        } else {
          result_stack.back().emplace_back(arg.toType());
        }
      } break;
      case SPECIALIZE_OPTIONAL: {
        auto is_present = spec.isPresent(optional_arg_spec_offset++);
        auto ot = (*input_stack.back()++)->expect<OptionalType>();
        if (!is_present) {
          result_stack.back().emplace_back(ot);
        } else {
          result_stack.back().emplace_back(ot->getElementType());
        }
      } break;
      case ENTER_TUPLE: {
        auto tup = (*input_stack.back()++)->expect<TupleType>();
        input_stack.emplace_back(tup->elements().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&] { return TupleType::create(result_stack.back()); });
      } break;
      case ENTER_OBJECT: {
        auto cls = (*input_stack.back()++)->expect<ClassType>();
        input_stack.emplace_back(cls->containedTypes().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&result_stack, cls] { return cls->refine(result_stack.back()); });
      } break;
      case SKIP:
        result_stack.back().emplace_back(*input_stack.back()++);
        break;
      case LEAVE:
        TypePtr result = aggregate_creators.back()();
        result_stack.pop_back();
        aggregate_creators.pop_back();
        input_stack.pop_back();
        result_stack.back().emplace_back(std::move(result));
        break;
    }
  }
  AT_ASSERT(result_stack.size() == 1);
  // FIXME: by doing this only on the inputs, we only capture graph inputs and
  // not optionals in tuples or objects. For that to work, we would have to
  // investigate the uses of the inputs in detail to change the
  // accesses/unwrapping.
  auto inputs = graph.inputs();
  for (const auto i : c10::irange(inputs.size())) {
    auto t = result_stack.back()[i];
    if (auto ot = t->cast<OptionalType>()) {
      // if an optional input hasn't been specialized above, it is None
      // so we disconnect the input here and replace its uses with
      // a constant
      WithInsertPoint guard(*graph.nodes().begin());
      auto c = graph.insertConstant({});
      inputs[i]->replaceAllUsesWith(c);
    } else {
      inputs[i]->setType(t);
    }
  }
}
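
// A minimal usage sketch (assumed; the surrounding callers, such as the JIT
// graph executor, are not part of this file, and the variable names here are
// illustrative only):
//
//   ArgumentSpecCreator arg_spec_creator(*graph);           // once per graph
//   ArgumentSpec spec =
//       arg_spec_creator.create(/*with_grad=*/true, stack);  // per call
//   arg_spec_creator.specializeTypes(*graph_copy, spec);     // per new spec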

} // namespace jit
} // namespace torch