[PyTorch] Remove device transfers from JNI (#105583)

Summary:
If a model was exported for the Vulkan backend without (automatic or manual) device transfers, then the export itself is incorrect, and the JNI layer need not compensate for that.
(If this assumption is incorrect, please give feedback.)

Undo the changes from:
- D23763771: automatic device transfers in JNI
- D39519168: `"requires_backend_transfers"` logic in JNI
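
To illustrate the assumption above, here is a minimal, hypothetical C++ sketch (not part of this change; the function name and structure are invented) of the transfers the removed JNI logic used to perform, and which an exported Vulkan model or its caller is now responsible for:

```cpp
#include <torch/script.h>

// Hypothetical caller-side helper: the JNI layer no longer calls .vulkan() on
// inputs or .cpu() on outputs, so the exported model (or the code invoking it)
// must perform the device transfers explicitly.
at::IValue run_on_vulkan(torch::jit::Module& module, const at::Tensor& cpu_input) {
  // Move the input tensor to the Vulkan backend (formerly done in JNI).
  std::vector<at::IValue> inputs{at::IValue(cpu_input.vulkan())};
  at::IValue output = module.forward(std::move(inputs));
  // Move a tensor result back to CPU (formerly tensor.cpu() in the JNI binding).
  return output.isTensor() ? at::IValue(output.toTensor().cpu()) : output;
}
```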

Test Plan: Verify that the CUNET+ hybrid model from D47488843 works.

Reviewed By: SS-JIA

Differential Revision: D47527244

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105583
Approved by: https://github.com/SS-JIA
Jorge Pineda
2023-07-20 00:26:21 +00:00
committed by PyTorch MergeBot
parent 0b524343be
commit 93e6fc54fa
3 changed files with 7 additions and 43 deletions


@@ -348,7 +348,7 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
     const auto& tensor = ivalue.toTensor();
     return jMethodTensor(
         JIValue::javaClassStatic(),
-        TensorHybrid::newJTensorFromAtTensor(tensor.cpu()));
+        TensorHybrid::newJTensorFromAtTensor(tensor));
   } else if (ivalue.isBool()) {
     static auto jMethodBool =
         JIValue::javaClassStatic()


@@ -195,19 +195,10 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
-    for (size_t i = 0; i < n; i++) {
+    for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
     auto output = [&]() {
       JITCallGuard guard;
       return module_.forward(std::move(inputs));
@@ -225,19 +216,10 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
-    for (size_t i = 0; i < n; i++) {
+    for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
     if (auto method = module_.find_method(methodName)) {
       auto output = [&]() {
         JITCallGuard guard;


@@ -158,19 +158,10 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
     auto output = [&]() {
       LiteJITCallGuard guard;
@@ -189,19 +180,10 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
     if (auto method = module_.find_method(methodName)) {
       auto output = [&]() {
         LiteJITCallGuard guard;
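
For reference, a sketch of the input-marshalling loop that remains after these deletions, reconstructed from the hunks above (enclosing class and method omitted; the lite-interpreter variant uses `LiteJITCallGuard`, the full-JIT one `JITCallGuard`):

```cpp
// After this change, inputs are forwarded unchanged; no device transfer
// happens in the JNI layer regardless of deviceType_.
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
  at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
  inputs.push_back(std::move(atIValue));
}
```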