mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch] Remove device transfers from JNI (#105583)
Summary: If a model was exported for Vulkan backend without (automatic or manual) device transfers, then the export is incorrect, and the JNI need not correct that. (If this assumption is incorrect, please give feedback.) Undo the changes from - D23763771: automatic device transfers in JNI - D39519168: `"requires_backend_transfers"` logic in JNI Test Plan: Verify CUNET+ hybrid model from D47488843 works. Reviewed By: SS-JIA Differential Revision: D47527244 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105583 Approved by: https://github.com/SS-JIA
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b524343be
commit
93e6fc54fa
@ -348,7 +348,7 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
|
||||
const auto& tensor = ivalue.toTensor();
|
||||
return jMethodTensor(
|
||||
JIValue::javaClassStatic(),
|
||||
TensorHybrid::newJTensorFromAtTensor(tensor.cpu()));
|
||||
TensorHybrid::newJTensorFromAtTensor(tensor));
|
||||
} else if (ivalue.isBool()) {
|
||||
static auto jMethodBool =
|
||||
JIValue::javaClassStatic()
|
||||
|
@ -195,18 +195,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
auto output = [&]() {
|
||||
JITCallGuard guard;
|
||||
@ -225,18 +216,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
if (auto method = module_.find_method(methodName)) {
|
||||
auto output = [&]() {
|
||||
|
@ -158,18 +158,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
|
||||
auto output = [&]() {
|
||||
@ -189,18 +180,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
if (auto method = module_.find_method(methodName)) {
|
||||
auto output = [&]() {
|
||||
|
Reference in New Issue
Block a user