Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-13 11:15:20 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67097

All delegated models have `is_nonzero` ops by default; making the op native and consumable without dispatch eases the portability of such models.

ghstack-source-id: 141375082

Test Plan: `buck test caffe2/test/cpp/jit:jit -- BackendTest.TestComposite`

```
[~/fbsource/fbcode] cd ~/fbsource/fbcode/ && buck test caffe2/test:jit -- test_trace_arange
Parsing buck files: finished in 0.5 sec
Building: finished in 9.4 sec (100%) 16035/16035 jobs, 0/16035 updated
Total time: 10.0 sec
More details at https://www.internalfb.com/intern/buck/build/1e55eea5-2adb-41d1-96ae-cbf4b446d6c6
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: 46eedba2-ae17-4e88-b205-93bd1332665d
Trace available for this run at /tmp/tpx-20211015-113905.235421/trace.log
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/1970324912349177
  ✓ ListingSuccess: caffe2/test:jit - main (12.372)
  ✓ Pass: caffe2/test:jit - test_trace_arange (jit.test_tracer.TestTracer) (13.748)
  ✓ Pass: caffe2/test:jit - test_trace_arange_with_grad (jit.test_tracer.TestTracer) (13.892)
Summary
  Pass: 2
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/1970324912349177
```

Reviewed By: iseeyuan

Differential Revision: D31656842

fbshipit-source-id: c0e6c798478a2783c0e17e6e9100ba5ce044da78
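To make "consumable without dispatch" concrete: the `boolTensor` handler in this file calls `at::native::is_nonzero` directly instead of routing through the dispatcher. A minimal standalone sketch of the difference (the header paths are assumptions and vary across releases):

```cpp
#include <ATen/ATen.h>            // at::is_nonzero, at::ones
#include <ATen/NativeFunctions.h> // at::native::is_nonzero (assumed location)

int main() {
  at::Tensor t = at::ones({1});
  // Dispatched call: resolved through the dispatcher to a backend kernel.
  bool via_dispatcher = at::is_nonzero(t);
  // Direct native call: binds to the CPU kernel itself, with no
  // dispatch-table lookup, which is what eases portability for
  // delegated mobile models.
  bool via_native = at::native::is_nonzero(t);
  return via_dispatcher == via_native ? 0 : 1; // expect equal results
}
```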
202 lines
5.9 KiB
C++
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>

namespace torch {
namespace jit {
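// Calling convention shared by every handler in this file: operands are
// popped from the interpreter Stack in reverse schema order (the last
// argument sits on top of the stack), and results are pushed back onto
// the same Stack.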
void tupleIndex(Stack& stack) {
  int64_t index = pop(stack).toInt();
  auto tuple = pop(stack).toTuple();
  auto norm_index = normalizeIndex(index, tuple->elements().size());
  // Use >= here: norm_index == size() is already one past the last element.
  if (norm_index < 0 ||
      norm_index >= static_cast<int64_t>(tuple->elements().size())) {
    throw std::out_of_range("Tuple list index out of range");
  }
  stack.emplace_back(tuple->elements()[norm_index]);
}

void raiseException(Stack& stack) {
  throw JITException(pop(stack).toStringRef());
}

void is(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, self.is(obj));
}

void unInitialized(Stack& stack) {
  push(stack, IValue::uninitialized());
}

void isNot(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, !self.is(obj));
}

void aten_format(Stack& stack) {
  size_t num_inputs = pop(stack).toInt();
  format(stack, num_inputs);
}
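// Note: size() is defined here but currently not registered in op_reg
// below; aten::size is overloaded (int[] vs. Tensor), see the TODO by
// the registration table.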
void size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sizes().vec());
}

void device(Stack& stack) {
  push(stack, pop(stack).toTensor().device());
}

void dtype(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, static_cast<int64_t>(a.scalar_type()));
}
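// Handles the aten::to overload that takes a dtype but no device
// (to.prim_dtype); since the schema has no device argument, device is
// forwarded to to_dispatch as c10::nullopt.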
void toPrimDType(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool non_blocking;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool copy;
  pop(stack, non_blocking, copy);
  c10::optional<at::ScalarType> scalarType =
      pop(stack).toOptional<at::ScalarType>();
  c10::optional<c10::Device> device = c10::nullopt;
  at::Tensor self = pop(stack).toTensor();
  push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
}

void dim(Stack& stack) {
  at::Tensor arg = pop(stack).toTensor();
  push(stack, arg.dim());
}

void _not(Stack& stack) {
  push(stack, !pop(stack).toBool());
}

void boolTensor(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, at::native::is_nonzero(a));
}
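// Converts a Tensor into a (possibly nested) list. elem_ty_val encodes the
// annotated element type (0 = int, 1 = float, 2 = bool, 3 = complex) and
// dim_val is the annotated nesting depth; both are validated against the
// runtime tensor before the recursive conversion.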
void toList(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int elem_ty_val;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int dim_val;
  at::Tensor t;

  pop(stack, elem_ty_val);
  pop(stack, dim_val);
  pop(stack, t);

  // If the Tensor is not on the CPU, transfer it.
  if (!t.device().is_cpu()) {
    t = t.cpu();
  }

  // Rebuild the output type using elem_ty_val and dim_val. Start
  // with the element type corresponding to elem_ty_val.
  TypePtr out_ty;
  if (elem_ty_val == 0) {
    out_ty = IntType::get();
  } else if (elem_ty_val == 1) {
    out_ty = FloatType::get();
  } else if (elem_ty_val == 2) {
    out_ty = BoolType::get();
  } else if (elem_ty_val == 3) {
    out_ty = ComplexType::get();
  } else {
    TORCH_CHECK(
        false,
        "Unsupported element type for tolist; only int, float, complex and bool are supported");
  }

  // Check that type of the Tensor matches that of the annotation.
  // Make an exception for the case in which the annotated type is
  // float/complex and the Tensor data type is also float/complex;
  // the elements will be casted to double/c10::complex<double>
  // later.
  TORCH_CHECK(
      (out_ty == FloatType::get() && t.is_floating_point()) ||
          (out_ty == ComplexType::get() && t.is_complex()) ||
          tryScalarTypeFromJitType(*out_ty) == t.scalar_type(),
      "Output annotation element type and runtime tensor element type must match for tolist()");

  // Check that the dimension of the Tensor matches that of the
  // annotation.
  TORCH_CHECK(
      dim_val == t.dim(),
      "Output annotation list dimension and runtime tensor dimension must match for tolist()");

  // Wrap out_ty in a ListType dim times.
  for (const auto i : c10::irange(dim_val)) {
    (void)i; // Suppress unused variable warning
    out_ty = ListType::create(out_ty);
  }

  int64_t dim = t.dim();
  auto sizes = t.sizes();
  auto strides = t.strides();
  size_t element_size = t.element_size();
  char* data = static_cast<char*>(t.data_ptr());
  auto result = tensorToListRecursive(
      data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size);
  push(stack, std::move(result));
}

void numToTensorScalar(Stack& stack) {
  at::Scalar s;
  pop(stack, s);
  push(stack, at::scalar_to_tensor(s));
}

void isCuda(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, a.is_cuda());
}

void numToTensorBool(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool b;
  pop(stack, b);
  push(stack, at::scalar_to_tensor(b));
}

void dictIndex(Stack& stack) {
  auto key = pop(stack);
  auto dict = pop(stack).toGenericDict();
  auto value = dict.find(key);
  if (value == dict.end()) {
    AT_ERROR("KeyError: ", key);
  }
  push(stack, value->value());
}
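// Registers each promoted op with the mobile interpreter at static
// initialization time: constructing a prim_op_fn_register entry maps the
// schema name on the left to the handler function on the right.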
static const C10_UNUSED std::array<mobile::prim_op_fn_register, 15> op_reg = {
    mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
    mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
    mobile::prim_op_fn_register("aten::format", aten_format),
    mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
    mobile::prim_op_fn_register("prim::RaiseException", raiseException),
    mobile::prim_op_fn_register("prim::device", device),
    mobile::prim_op_fn_register("prim::dtype", dtype),
    mobile::prim_op_fn_register("aten::__not__", _not),
    mobile::prim_op_fn_register("aten::__is__", is),
    mobile::prim_op_fn_register("aten::__isnot__", isNot),
    mobile::prim_op_fn_register("aten::dim", dim),
    mobile::prim_op_fn_register("prim::Uninitialized", unInitialized),
    mobile::prim_op_fn_register("prim::is_cuda", isCuda),
    mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex),
    mobile::prim_op_fn_register("prim::unchecked_cast", noop),
    // TODO: (@pavithran) size is overloaded with int[] and Tensor
    // so this throws error expecting int not Tensor
    // mobile::prim_op_fn_register("aten::size", size)
};

} // namespace jit
} // namespace torch