Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-05 00:14:54 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68161

Continuing rollout of borrowing outputs for native ops.

ghstack-source-id: 143424920

Test Plan: Compare CMF local_ro perf again.

Previous diff:
```
I1110 22:05:23.245435 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.03272. Iters per second: 968.313
I1110 22:05:23.822196 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.06478. Iters per second: 939.163
I1110 22:05:24.395256 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.035. Iters per second: 966.186
I1110 22:05:24.964169 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.02786. Iters per second: 972.898
I1110 22:05:25.536558 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.03205. Iters per second: 968.946
I1110 22:05:26.109027 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.04256. Iters per second: 959.174
I1110 22:05:26.679611 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.03245. Iters per second: 968.567
I1110 22:05:27.253048 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.04493. Iters per second: 957.005
I1110 22:05:27.822629 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.0299. Iters per second: 970.971
I1110 22:05:28.393326 113949 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 1.03039. Iters per second: 970.509
I1110 22:05:28.393368 113949 PyTorchPredictorBenchLib.cpp:285] Mean milliseconds per iter: 1.03726, standard deviation: 0.0111053
```

This diff:
```
I1110 22:18:48.453075 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.931188. Iters per second: 1073.9
I1110 22:18:48.967614 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.933196. Iters per second: 1071.59
I1110 22:18:49.483338 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.932087. Iters per second: 1072.86
I1110 22:18:49.997144 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.930877. Iters per second: 1074.26
I1110 22:18:50.529383 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.936981. Iters per second: 1067.26
I1110 22:18:51.085038 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.953214. Iters per second: 1049.08
I1110 22:18:51.607192 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.940719. Iters per second: 1063.02
I1110 22:18:52.126169 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.942638. Iters per second: 1060.85
I1110 22:18:52.644445 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.937574. Iters per second: 1066.58
I1110 22:18:53.163486 191647 PyTorchPredictorBenchLib.cpp:274] PyTorch run finished. Milliseconds per iter: 0.941636. Iters per second: 1061.98
I1110 22:18:53.163537 191647 PyTorchPredictorBenchLib.cpp:285] Mean milliseconds per iter: 0.938011, standard deviation: 0.00691196
```

A 0.099 ms/iter (9.5%!) improvement over the previous diff.

Reviewed By: hlu1

Differential Revision: D32347900

fbshipit-source-id: 8169ebcadf1248e555a18bbffa99eef6cac1ba85
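For readers unfamiliar with the borrowing mechanism referenced above, here is a minimal illustrative sketch (not code from this PR) of the `c10::MaybeOwnedTraits<c10::IValue>` helpers that the file below aliases as `createBorrowedIValue`; the function name `borrow_demo` is made up for illustration:

```cpp
#include <ATen/core/ivalue.h>

// Sketch only: a "borrowed" IValue aliases `source` without bumping its
// refcount, so producing an op output this way avoids atomic refcount traffic.
void borrow_demo(const c10::IValue& source) {
  c10::IValue borrowed =
      c10::MaybeOwnedTraits<c10::IValue>::createBorrow(source);
  // ... read through `borrowed`; `source` must stay alive the whole time ...
  // A borrow is torn down with destroyBorrow instead of the normal destructor
  // path (Static Runtime does this when it cleans up borrowed outputs).
  c10::MaybeOwnedTraits<c10::IValue>::destroyBorrow(borrowed);
}
```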
566 lines
19 KiB
C++
#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

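// createBorrowedIValue is shorthand for creating an IValue that aliases an
// existing one without taking ownership (no refcount bump). The surrounding
// runtime is responsible for ensuring the borrow does not outlive its source.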
namespace {
constexpr auto createBorrowedIValue =
    c10::MaybeOwnedTraits<c10::IValue>::createBorrow;
} // namespace
namespace torch {
namespace jit {

C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);

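// Registry query helpers: nativeOpIsRegistered reports whether Static Runtime
// has a native implementation for an op symbol, and getNativeOperation builds
// the execution functor for a node, returning nullptr when no native
// implementation is registered so callers can fall back to the JIT operator.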
bool nativeOpIsRegistered(const c10::Symbol& op_name) {
  const std::string name(op_name.toQualString());
  return SRNativeOperatorRegistry()->Has(name);
}

std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
  auto op_name = n->kind().toQualString();
  if (SRNativeOperatorRegistry()->Has(op_name)) {
    return SRNativeOperatorRegistry()->Create(op_name)->Generate(n);
  }
  return nullptr;
}

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleConstruct,
    prim_TupleConstruct,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        // prepare inputs
        std::vector<IValue> stack;
        const size_t size = p_node->num_inputs();
        stack.reserve(size);
        for (const auto i : c10::irange(size)) {
          stack.emplace_back(p_node->Input(i));
        }
        // run op
        auto* node = p_node->node();
        const auto& type = node->output()->type()->expect<TupleType>();
        if (type->name().has_value()) {
          namedTupleConstruct(stack, type, node->inputs().size());
        } else {
          tupleConstruct(stack, node->inputs().size());
        }
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleUnpack,
    prim_TupleUnpack,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        const auto& elems = p_node->Input(0).toTupleRef().elements();
        const size_t num_outputs = p_node->outputs().size();
        TORCH_CHECK(
            num_outputs == elems.size(),
            "Number of outputs must match number of tuple elements.")
        for (size_t i = 0; i < num_outputs; ++i) {
          p_node->Output(i) = elems[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::DictConstruct,
    prim_DictConstruct,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        // prepare inputs
        std::vector<IValue> stack;
        const size_t size = p_node->num_inputs();
        stack.reserve(size);
        for (const auto i : c10::irange(size)) {
          stack.emplace_back(p_node->Input(i));
        }
        // run op
        auto* node = p_node->node();
        dictConstruct(
            stack,
            node->output()->type()->expectRef<DictType>(),
            node->inputs().size());
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

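// static_runtime::dict_unpack extracts the given keys from the input dict.
// Each output is a borrowed IValue that aliases the value stored in the dict
// rather than an owning copy.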
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::dict_unpack,
    static_runtime_dict_unpack,
    [](Node*) -> SROperator {
      return [](ProcessedNode* p_node) {
        DCHECK(p_node->num_inputs() - 1 == p_node->outputs().size());
        auto dict = p_node->Input(0).toGenericDict();
        for (size_t i = 1; i < p_node->num_inputs(); ++i) {
          const auto& key = p_node->Input(i);
          auto value = dict.find(key);
          TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
          p_node->Output(i - 1) = createBorrowedIValue(value->value());
        }
      };
    });

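// aten::__getitem__ is specialized when the functor is generated, based on the
// statically known container type of the first input: dict lookup or list
// indexing. Other container types fall through and get no native op.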
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__getitem__,
    aten_getitem,
    [](Node* n) -> SROperator {
      if (n->inputs().size() != 2) {
        return nullptr;
      }

      if (n->input(0)->type()->castRaw<DictType>()) {
        return [](ProcessedNode* p_node) {
          auto dict = p_node->Input(0).toGenericDict();
          const auto& key = p_node->Input(1);
          auto value = dict.find(key);
          TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
          p_node->Output(0) = value->value();
        };
      } else if (n->input(0)->type()->castRaw<ListType>()) {
        return [](ProcessedNode* p_node) {
          const auto& list = p_node->Input(0).toList();
          auto idx = p_node->Input(1).toInt();
          p_node->Output(0) = getItem(list, idx);
        };
      }

      // TODO(T98581096): make __getitem__ work for other container types
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListConstruct,
    prim_ListConstruct,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        // prepare inputs
        std::vector<IValue> stack;
        const size_t size = p_node->num_inputs();
        stack.reserve(size);
        for (const auto i : c10::irange(size)) {
          stack.emplace_back(p_node->Input(i));
        }
        // run op
        listConstruct(
            stack,
            p_node->node()->output()->type()->expectRef<ListType>(),
            p_node->num_inputs());
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListUnpack,
    prim_ListUnpack,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        // prepare inputs
        std::vector<IValue> stack;
        const size_t size = p_node->num_inputs();
        stack.reserve(size);
        for (const auto i : c10::irange(size)) {
          stack.emplace_back(p_node->Input(i));
        }
        // run op
        size_t num_outputs = p_node->outputs().size();
        listUnpack(stack, num_outputs);
        // put output back
        DCHECK_EQ(stack.size(), num_outputs);
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = std::move(stack[i]);
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::append,
    aten_append,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        auto list = p_node->Input(0).toList();
        list.push_back(p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::GetAttr,
    prim_GetAttr,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        auto module = p_node->Input(0).toObject();
        Node* node = p_node->node();
        const auto& type = node->input()->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        p_node->Output(0) = module->getSlot(slot);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::SetAttr,
    prim_SetAttr,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        auto module = p_node->Input(0).toObject();
        Node* node = p_node->node();
        const auto& type = node->inputs()[0]->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        module->setSlot(slot, p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::transpose,
    aten_transpose,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_i = p_node->Input(1).toInt();
        const auto in2_i = p_node->Input(2).toInt();
        p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toInt();
    p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::permute,
    aten_permute,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toIntVector();
        p_node->Output(0) = at::native::permute(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape,
    aten_reshape,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toIntVector();
        p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toInt();
    const auto in3_i = p_node->Input(3).toInt();
    const auto in4_i = p_node->Input(4).toInt();
    p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
  };
});

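// aten::narrow is computed directly as a slice view over `self` after
// validating/wrapping `start` against the size of `dim`; `start` may be given
// either as an int or as a 0-dim tensor.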
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
      !n->matches(torch::schema(
          "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& self = p_node->Input(0).toTensor(); // self
    const auto dim = p_node->Input(1).toInt(); // dim
    int64_t start = 0;
    if (p_node->Input(2).isScalar()) {
      start = p_node->Input(2).toInt();
    } else {
      auto& t = p_node->Input(2).toTensor();
      start = t.item<int64_t>();
    }
    const auto length = p_node->Input(3).toInt(); // length
    TORCH_CHECK(
        self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
    auto cur_size = self.sizes()[dim];
    if (start != cur_size && start < 0) { // start being the end is valid, but
                                          // not a valid dim specification.
      start = at::maybe_wrap_dim(start, cur_size);
    }
    TORCH_CHECK(
        length >= 0 && start <= cur_size - length,
        "start (",
        start,
        ") + length (",
        length,
        ") exceeds dimension size (",
        cur_size,
        ").");
    p_node->Output(0) = at::native::slice(self, dim, start, start + length, 1);
  };
});

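// aten::to supports three schemas here (to.other, to.dtype, to.prim_dtype).
// Note that to.prim_dtype may return `self` unchanged when neither dtype nor
// copy is set, so the output can alias the input.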
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto& in1_t = p_node->Input(1).toTensor();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toScalarType();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toOptional<at::ScalarType>();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      // To mimic the behavior of the JIT interpreter, if both dtype
      // and copy are not set, we return self. Otherwise, we assume
      // that dtype is set.
      if (!in1_i && !in3_i) {
        p_node->Output(0) = in0_t;
      } else {
        TORCH_CHECK(
            in1_i,
            "dtype cannot be None when copy is True for aten::to.prim_dtype");
        p_node->Output(0) = at::native::to(in0_t, *in1_i, in2_i, in3_i);
      }
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::detach,
    aten_detach,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::alias(in0_t);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::expand_as,
    aten_expand_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& other = p_node->Input(1).toTensor();
        p_node->Output(0) = self.expand(other.sizes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::isinstance,
    prim_isinstance,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("prim::isinstance(Any to_check) -> bool"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto input_type = p_node->Input(0).type();

        auto* node = p_node->node();
        const std::vector<TypePtr>& candidates = node->tys(attr::types);
        for (const auto& candidate_type : candidates) {
          if (input_type->isSubtypeOf(*candidate_type)) {
            p_node->Output(0) = true;
            return;
          }
        }

        p_node->Output(0) = false;
      };
    });

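// prim::TypeCheck copies its N inputs through to outputs 0..N-1 and writes a
// trailing bool output indicating whether every defined input tensor matches
// the expected TensorType recorded on the node.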
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TypeCheck,
    prim_TypeCheck,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        auto* node = p_node->node();
        const size_t num_inputs = node->inputs().size();
        TORCH_INTERNAL_ASSERT(
            num_inputs && num_inputs + 1 == node->outputs().size());

        const auto& expected_types = node->tys(attr::types);

        for (size_t i = 0; i < num_inputs; i++) {
          p_node->Output(i) = p_node->Input(i);
        }

        for (size_t i = 0; i < num_inputs; i++) {
          auto& input_tensor = p_node->Input(i).toTensor();
          auto* expected_type = expected_types[i]->castRaw<TensorType>();
          if (input_tensor.defined() &&
              !expected_type->matchTensor(input_tensor)) {
            p_node->Output(num_inputs) = false;
            return;
          }
        }

        p_node->Output(num_inputs) = true;
      };
    });

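// static_runtime::VarTupleUnpack flattens the elements of all input tuples
// into consecutive outputs; like dict_unpack, the outputs are borrowed
// IValues rather than owning copies.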
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::VarTupleUnpack,
    static_runtime_VarTupleUnpack,
    [](Node*) -> SROperator {
      return [](ProcessedNode* pnode) {
        size_t output_idx = 0;
        for (const auto idx : c10::irange(pnode->num_inputs())) {
          const auto& tuple = pnode->Input(idx);
          for (auto& elem : tuple.toTupleRef().elements()) {
            pnode->Output(output_idx) = createBorrowedIValue(elem);
            ++output_idx;
          }
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::view,
    aten_view,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& input = p_node->Input(0).toTensor();
        const auto size = p_node->Input(1).toIntList();
        p_node->Output(0) = at::native::view(input, size.vec());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::size,
    aten_size,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::size(Tensor self, int dim) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& input = p_node->Input(0).toTensor();
        auto dim = p_node->Input(1).toInt();
        const auto ndim = input.dim();

        if (dim < 0 || dim >= ndim) {
          dim = c10::maybe_wrap_dim(dim, ndim);
        }
        p_node->Output(0) = input.sizes()[dim];
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::squeeze,
    aten_squeeze,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }

      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto dim = p_node->Input(1).toInt();
        p_node->Output(0) = at::native::squeeze(self, dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::split, aten_split, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::split(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }

  return [](ProcessedNode* p_node) {
    const auto& self = p_node->Input(0).toTensor();
    const auto split_size = p_node->Input(1).toInt();
    const auto dim = p_node->Input(2).toInt();
    p_node->Output(0) = at::native::split(self, split_size, dim);
  };
});

} // namespace jit
} // namespace torch