[PyTorch] Remove device transfers from JNI (#105583)

Summary: If a model was exported for the Vulkan backend without (automatic or manual) device transfers, then the export is incorrect, and the JNI need not correct that. (If this assumption is incorrect, please give feedback.) Undo the changes from:
- D23763771: automatic device transfers in JNI
- D39519168: `"requires_backend_transfers"` logic in JNI

Test Plan: Verify that the CUNET+ hybrid model from D47488843 works.

Reviewed By: SS-JIA

Differential Revision: D47527244

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105583
Approved by: https://github.com/SS-JIA
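With the automatic transfers gone, any CPU-to-Vulkan movement for inputs (and the reverse for outputs) must be baked into the exported model or performed explicitly by the caller. A minimal LibTorch sketch of the caller-side equivalent of the removed behavior, assuming a build with Vulkan enabled; run_on_vulkan is an illustrative helper, not part of this change:

#include <torch/script.h>

// Hypothetical helper reproducing the transfers the JNI used to insert:
// inputs are moved to Vulkan before forward(), tensor outputs back to CPU.
at::IValue run_on_vulkan(torch::jit::Module& module, const at::Tensor& cpu_input) {
  at::IValue output = module.forward({cpu_input.vulkan()});
  return output.isTensor() ? at::IValue{output.toTensor().cpu()} : output;
}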
committed by PyTorch MergeBot
parent 0b524343be
commit 93e6fc54fa
@@ -348,7 +348,7 @@ facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
     const auto& tensor = ivalue.toTensor();
     return jMethodTensor(
         JIValue::javaClassStatic(),
-        TensorHybrid::newJTensorFromAtTensor(tensor.cpu()));
+        TensorHybrid::newJTensorFromAtTensor(tensor));
   } else if (ivalue.isBool()) {
     static auto jMethodBool =
         JIValue::javaClassStatic()
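In the output path above, only Vulkan-resident tensors are affected: cpu() is equivalent to to(at::kCPU) and leaves CPU tensors untouched. A minimal sketch of the before/after semantics, assuming a Vulkan-enabled build; the function names are illustrative:

#include <ATen/ATen.h>

// Old path: every returned tensor was normalized to CPU before being wrapped
// for Java (a no-op for tensors already on the CPU).
at::Tensor old_output_path(const at::Tensor& t) { return t.cpu(); }

// New path: the tensor crosses into Java on whatever device the model left it,
// so a Vulkan-resident output must be transferred by the model itself.
at::Tensor new_output_path(const at::Tensor& t) { return t; }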
@@ -195,18 +195,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
-    for (size_t i = 0; i < n; i++) {
+    for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
     auto output = [&]() {
       JITCallGuard guard;
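The removed gating read a module attribute with a fallback, so a model that never set requires_backend_transfers behaved as if it were true and had its inputs transferred automatically. A sketch of that lookup as a standalone predicate, assuming a torch::jit::Module as in this path; should_auto_transfer is an illustrative name:

#include <torch/script.h>

// attr(name, or_else) returns the module attribute if the model defines it,
// else the fallback, so automatic transfers defaulted to on for Vulkan modules.
bool should_auto_transfer(const torch::jit::Module& module, c10::DeviceType device) {
  return device == at::kVulkan &&
      module.attr("requires_backend_transfers", at::IValue(true)).toBool();
}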
@@ -225,18 +216,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
-    for (size_t i = 0; i < n; i++) {
+    for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
     if (auto method = module_.find_method(methodName)) {
       auto output = [&]() {
@@ -158,18 +158,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
 
     auto output = [&]() {
@@ -189,18 +180,9 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
-    const bool requires_backend_transfers =
-        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
-        inputs.push_back(
-            atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
-                                : std::move(atIValue));
-      } else {
-        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
-        inputs.push_back(std::move(atIValue));
-      }
+      inputs.push_back(std::move(atIValue));
     }
     if (auto method = module_.find_method(methodName)) {
       auto output = [&]() {