[PyTorch] Remove device transfers from JNI (#105583)

Summary:
If a model was exported for Vulkan backend without (automatic or manual) device transfers, then the export is incorrect, and the JNI need not correct that.
(If this assumption is incorrect, please give feedback.)

Undo the changes from
- D23763771: automatic device transfers in JNI
- D39519168: `"requires_backend_transfers"` logic in JNI

Test Plan: Verify CUNET+ hybrid model from D47488843 works.

Reviewed By: SS-JIA

Differential Revision: D47527244

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105583
Approved by: https://github.com/SS-JIA
This commit is contained in:
Jorge Pineda
2023-07-20 00:26:21 +00:00
committed by PyTorch MergeBot
parent 0b524343be
commit 93e6fc54fa
3 changed files with 7 additions and 43 deletions

View File

@ -348,7 +348,7 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
const auto& tensor = ivalue.toTensor(); const auto& tensor = ivalue.toTensor();
return jMethodTensor( return jMethodTensor(
JIValue::javaClassStatic(), JIValue::javaClassStatic(),
TensorHybrid::newJTensorFromAtTensor(tensor.cpu())); TensorHybrid::newJTensorFromAtTensor(tensor));
} else if (ivalue.isBool()) { } else if (ivalue.isBool()) {
static auto jMethodBool = static auto jMethodBool =
JIValue::javaClassStatic() JIValue::javaClassStatic()

View File

@ -195,18 +195,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
std::vector<at::IValue> inputs{}; std::vector<at::IValue> inputs{};
size_t n = jinputs->size(); size_t n = jinputs->size();
inputs.reserve(n); inputs.reserve(n);
const bool requires_backend_transfers = for (const auto i : c10::irange(n)) {
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back(std::move(atIValue));
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
inputs.push_back(std::move(atIValue));
}
} }
auto output = [&]() { auto output = [&]() {
JITCallGuard guard; JITCallGuard guard;
@ -225,18 +216,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
std::vector<at::IValue> inputs{}; std::vector<at::IValue> inputs{};
size_t n = jinputs->size(); size_t n = jinputs->size();
inputs.reserve(n); inputs.reserve(n);
const bool requires_backend_transfers = for (const auto i : c10::irange(n)) {
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
for (size_t i = 0; i < n; i++) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back(std::move(atIValue));
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
inputs.push_back(std::move(atIValue));
}
} }
if (auto method = module_.find_method(methodName)) { if (auto method = module_.find_method(methodName)) {
auto output = [&]() { auto output = [&]() {

View File

@ -158,18 +158,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
std::vector<at::IValue> inputs{}; std::vector<at::IValue> inputs{};
size_t n = jinputs->size(); size_t n = jinputs->size();
inputs.reserve(n); inputs.reserve(n);
const bool requires_backend_transfers =
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
for (const auto i : c10::irange(n)) { for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back(std::move(atIValue));
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
inputs.push_back(std::move(atIValue));
}
} }
auto output = [&]() { auto output = [&]() {
@ -189,18 +180,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
std::vector<at::IValue> inputs{}; std::vector<at::IValue> inputs{};
size_t n = jinputs->size(); size_t n = jinputs->size();
inputs.reserve(n); inputs.reserve(n);
const bool requires_backend_transfers =
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
for (const auto i : c10::irange(n)) { for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back(std::move(atIValue));
inputs.push_back(
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
: std::move(atIValue));
} else {
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
inputs.push_back(std::move(atIValue));
}
} }
if (auto method = module_.find_method(methodName)) { if (auto method = module_.find_method(methodName)) {
auto output = [&]() { auto output = [&]() {