Files
pytorch/torch/csrc/jit/mobile/promoted_prim_ops.cpp
Pavithran Ramachandran 8d164a36fb Use at::native::is_nonzero in promoted ops to improve portability (#67097)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67097

All delegated models include the `is_nonzero` op by default; making the op native, so it can be consumed without going through the dispatcher, eases the portability of such models.
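
Concretely, the `aten::Bool.Tensor` handler now calls the native kernel directly. A minimal before/after sketch (the dispatched "before" form is an assumption about the prior code, not a quote of it):

```
void boolTensor(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  // Before (assumed): routed through the dispatcher, which the runtime
  // must then carry in full.
  // push(stack, at::is_nonzero(a));
  // After: calls the kernel directly, so a stripped-down mobile runtime
  // can execute the op without dispatcher support.
  push(stack, at::native::is_nonzero(a));
}
```
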
ghstack-source-id: 141375082

Test Plan:
`buck test caffe2/test/cpp/jit:jit -- BackendTest.TestComposite`

```
~/fbsource/fbcode] cd ~/fbsource/fbcode/ && buck test caffe2/test:jit -- test_trace_arange
Parsing buck files: finished in 0.5 sec
Building: finished in 9.4 sec (100%) 16035/16035 jobs, 0/16035 updated
  Total time: 10.0 sec
More details at https://www.internalfb.com/intern/buck/build/1e55eea5-2adb-41d1-96ae-cbf4b446d6c6
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: 46eedba2-ae17-4e88-b205-93bd1332665d
Trace available for this run at /tmp/tpx-20211015-113905.235421/trace.log
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/1970324912349177
    ✓ ListingSuccess: caffe2/test:jit - main (12.372)
    ✓ Pass: caffe2/test:jit - test_trace_arange (jit.test_tracer.TestTracer) (13.748)
    ✓ Pass: caffe2/test:jit - test_trace_arange_with_grad (jit.test_tracer.TestTracer) (13.892)
Summary
  Pass: 2
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/1970324912349177
```

Reviewed By: iseeyuan

Differential Revision: D31656842

fbshipit-source-id: c0e6c798478a2783c0e17e6e9100ba5ce044da78
2021-10-25 10:18:31 -07:00


#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
namespace torch {
namespace jit {
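// Each op handler below follows the mobile interpreter's calling convention:
// inputs are popped off the Stack and results are pushed back onto it.
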
void tupleIndex(Stack& stack) {
  int64_t index = pop(stack).toInt();
  auto tuple = pop(stack).toTuple();
  auto norm_index = normalizeIndex(index, tuple->elements().size());
  if (norm_index < 0 ||
      norm_index >= static_cast<int64_t>(tuple->elements().size())) {
    throw std::out_of_range("Tuple list index out of range");
  }
  stack.emplace_back(tuple->elements()[norm_index]);
}

void raiseException(Stack& stack) {
  throw JITException(pop(stack).toStringRef());
}

void is(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, self.is(obj));
}

void unInitialized(Stack& stack) {
  push(stack, IValue::uninitialized());
}

void isNot(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, !self.is(obj));
}

void aten_format(Stack& stack) {
  size_t num_inputs = pop(stack).toInt();
  format(stack, num_inputs);
}

void size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sizes().vec());
}

void device(Stack& stack) {
  push(stack, pop(stack).toTensor().device());
}

void dtype(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, static_cast<int64_t>(a.scalar_type()));
}

void toPrimDType(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool non_blocking;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool copy;
  pop(stack, non_blocking, copy);
  c10::optional<at::ScalarType> scalarType =
      pop(stack).toOptional<at::ScalarType>();
  c10::optional<c10::Device> device = c10::nullopt;
  at::Tensor self = pop(stack).toTensor();
  push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
}

void dim(Stack& stack) {
  at::Tensor arg = pop(stack).toTensor();
  push(stack, arg.dim());
}

void _not(Stack& stack) {
  push(stack, !pop(stack).toBool());
}

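// aten::Bool.Tensor. Calls at::native::is_nonzero directly rather than going
// through the dispatcher, so delegated models that contain this op remain
// runnable on builds without dispatch support (#67097).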
void boolTensor(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, at::native::is_nonzero(a));
}

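// Implements tolist(): pops the annotated element type (elem_ty_val), the
// annotated number of dimensions (dim_val), and the tensor itself, checks
// that the annotations match the runtime tensor, and pushes the resulting
// (possibly nested) list.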
void toList(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int elem_ty_val;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int dim_val;
  at::Tensor t;

  pop(stack, elem_ty_val);
  pop(stack, dim_val);
  pop(stack, t);

  // If the Tensor is not on the CPU, transfer it.
  if (!t.device().is_cpu()) {
    t = t.cpu();
  }

  // Rebuild the output type using elem_ty_val and dim_val. Start
  // with the element type corresponding to elem_ty_val.
  TypePtr out_ty;
  if (elem_ty_val == 0) {
    out_ty = IntType::get();
  } else if (elem_ty_val == 1) {
    out_ty = FloatType::get();
  } else if (elem_ty_val == 2) {
    out_ty = BoolType::get();
  } else if (elem_ty_val == 3) {
    out_ty = ComplexType::get();
  } else {
    TORCH_CHECK(
        false,
        "Unsupported element type for tolist; only int, float, complex and bool are supported");
  }

  // Check that type of the Tensor matches that of the annotation.
  // Make an exception for the case in which the annotated type is
  // float/complex and the Tensor data type is also float/complex;
  // the elements will be casted to double/c10::complex<double>
  // later.
  TORCH_CHECK(
      (out_ty == FloatType::get() && t.is_floating_point()) ||
          (out_ty == ComplexType::get() && t.is_complex()) ||
          tryScalarTypeFromJitType(*out_ty) == t.scalar_type(),
      "Output annotation element type and runtime tensor element type must match for tolist()");

  // Check that the dimension of the Tensor matches that of the
  // annotation.
  TORCH_CHECK(
      dim_val == t.dim(),
      "Output annotation list dimension and runtime tensor dimension must match for tolist()");

  // Wrap out_ty in a ListType dim times.
  for (const auto i : c10::irange(dim_val)) {
    (void)i; // Suppress unused variable warning
    out_ty = ListType::create(out_ty);
  }

  int64_t dim = t.dim();
  auto sizes = t.sizes();
  auto strides = t.strides();
  size_t element_size = t.element_size();
  char* data = static_cast<char*>(t.data_ptr());
  auto result = tensorToListRecursive(
      data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size);
  push(stack, std::move(result));
}

void numToTensorScalar(Stack& stack) {
  at::Scalar s;
  pop(stack, s);
  push(stack, at::scalar_to_tensor(s));
}

void isCuda(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, a.is_cuda());
}

void numToTensorBool(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool b;
  pop(stack, b);
  push(stack, at::scalar_to_tensor(b));
}

void dictIndex(Stack& stack) {
  auto key = pop(stack);
  auto dict = pop(stack).toGenericDict();
  auto value = dict.find(key);
  if (value == dict.end()) {
    AT_ERROR("KeyError: ", key);
  }
  push(stack, value->value());
}

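// Statically registers the handlers above under their operator names for the
// mobile runtime.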
static const C10_UNUSED std::array<mobile::prim_op_fn_register, 15> op_reg = {
    mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
    mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
    mobile::prim_op_fn_register("aten::format", aten_format),
    mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
    mobile::prim_op_fn_register("prim::RaiseException", raiseException),
    mobile::prim_op_fn_register("prim::device", device),
    mobile::prim_op_fn_register("prim::dtype", dtype),
    mobile::prim_op_fn_register("aten::__not__", _not),
    mobile::prim_op_fn_register("aten::__is__", is),
    mobile::prim_op_fn_register("aten::__isnot__", isNot),
    mobile::prim_op_fn_register("aten::dim", dim),
    mobile::prim_op_fn_register("prim::Uninitialized", unInitialized),
    mobile::prim_op_fn_register("prim::is_cuda", isCuda),
    mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex),
    mobile::prim_op_fn_register("prim::unchecked_cast", noop),
    // TODO: (@pavithran) size is overloaded with int[] and Tensor
    // so this throws error expecting int not Tensor
    // mobile::prim_op_fn_register("aten::size", size)
};

} // namespace jit
} // namespace torch