Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix undefined tensor error in _copy_from_and_resize when fallback to cpu. (#130237)
1) Skip undefined tensors in the CPU fallback when calling _copy_from_and_resize; 2) modify the to_cpu helper to support optional tensors; 3) copy results back to the original optional tensors when the argument's alias_info isWrite is true. @ezyang @bdhirsh Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/130237 Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent 13283fb4bc
commit d57af32e63
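Before the diff itself, here is a minimal, self-contained sketch of the reworked helper described in point 2 of the commit message: a to_cpu templated over at::Tensor and std::optional<at::Tensor> that routes only defined tensors through at::_to_cpu and leaves undefined (or empty optional) entries untouched. The function name to_cpu_sketch and the standalone consolidation are illustrative assumptions; the real helper is the static to_cpu in PyTorch's CPU fallback, shown interleaved with the old version in the first hunk below.

#include <ATen/ATen.h>
#include <c10/util/irange.h>
#include <optional>
#include <type_traits>
#include <vector>

// Templated over at::Tensor and std::optional<at::Tensor> so one helper can
// translate both plain tensor lists and lists of optional tensors to CPU.
template <typename T,
          std::enable_if_t<std::is_same_v<T, at::Tensor> ||
                           std::is_same_v<T, std::optional<at::Tensor>>, int> = 1>
static std::vector<T> to_cpu_sketch(const std::vector<T>& tensors) {
  const int num = tensors.size();
  std::vector<T> cpu_tensors(num);
  std::vector<at::Tensor> valid_tensors;  // only defined tensors are sent to at::_to_cpu
  std::vector<bool> to_translate(num);
  for (const auto i : c10::irange(num)) {
    if constexpr (std::is_same_v<T, std::optional<at::Tensor>>) {
      if (tensors[i].has_value() && tensors[i].value().defined()) {
        to_translate[i] = true;
        valid_tensors.push_back(tensors[i].value());
      } else {
        cpu_tensors[i] = tensors[i];  // keep nullopt / undefined entries as they are
      }
    } else {
      if (tensors[i].defined()) {
        to_translate[i] = true;
        valid_tensors.push_back(tensors[i]);
      } else {
        cpu_tensors[i] = tensors[i];  // keep undefined tensors as they are
      }
    }
  }
  // Convert the defined tensors in one batch, then scatter them back into place.
  auto cpu_valid_tensors = at::_to_cpu(valid_tensors);
  for (int i = 0, defined_pos = 0; i < num; ++i) {
    if (to_translate[i]) {
      cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]);
    }
  }
  return cpu_tensors;
}

In the fallback itself this helper is applied to TensorList arguments via toTensorVector() and to optional TensorList arguments via toOptionalTensorVector(); for mutable (alias_info isWrite) arguments, only defined results are copied back with at::_copy_from_and_resize, which is what points 1 and 3 of the commit message refer to.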
@@ -21,29 +21,39 @@ namespace at::native {
// convenience helper for converting tensors to cpu
static std::vector<at::Tensor> to_cpu(const at::TensorList& tensors) {
template<typename T, std::enable_if_t<std::is_same_v<T, at::Tensor> || std::is_same_v<T, std::optional<at::Tensor>>, int> = 1>
static std::vector<T> to_cpu(const std::vector<T>& tensors) {
// We can't just call at::to_cpu() on the entire list of Tensors
// Because it will break on undefined tensors. Separate out undefined tensors first.
std::vector<at::Tensor> cpu_tensors(tensors.size());
const int num = tensors.size();
std::vector<T> cpu_tensors(num);
std::vector<at::Tensor> valid_tensors;
std::vector<bool> to_translate(tensors.size());
for (const auto i : c10::irange(tensors.size())) {
const at::Tensor& tensor = tensors[i];
// Explicitly handling undefined tensors here instead of letting `at::_to_cpu` handle it.
// Otherwise, we'd need to require all backends with their own implementation of _to_cpu
// to properly handle undefined tensors.
if (tensor.defined()) {
to_translate[i] = true;
valid_tensors.push_back(tensor);
std::vector<bool> to_translate(num);
for (const auto i : c10::irange(num)) {
// Explicitly handling undefined tensors here instead of letting `at::_to_cpu` handle it.
// Otherwise, we'd need to require all backends with their own implementation of _to_cpu
// to properly handle undefined tensors.
if constexpr(std::is_same<T, std::optional<at::Tensor>>::value) {
if (tensors[i].has_value() && tensors[i].value().defined()) {
to_translate[i] = true;
valid_tensors.push_back(tensors[i].value());
} else {
cpu_tensors[i] = tensor;
cpu_tensors[i] = tensors[i];
}
} else {
if (tensors[i].defined()) {
to_translate[i] = true;
valid_tensors.push_back(tensors[i]);
} else {
cpu_tensors[i] = tensors[i];
}
}
}
auto cpu_valid_tensors = at::_to_cpu(valid_tensors);
for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
if (to_translate[i]) {
cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]);
}
for (int i = 0, defined_pos = 0; i < num; ++i) {
if (to_translate[i]) {
cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]);
}
}
return cpu_tensors;
}
@@ -89,9 +99,13 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
std::vector<c10::List<at::Tensor>> tensorlist_args;
std::vector<int> tensorlist_args_indices;
std::vector<c10::List<std::optional<at::Tensor>>> optional_tensorlist_args;
std::vector<int> optional_tensorlist_args_indices;
std::optional<c10::Device> tgt_device = std::nullopt;
// save converted cpu tensor for TensorList
// save converted cpu tensor for TensorList and optional TensorList
std::vector<c10::IValue> tensorlist_cpu_args;
std::vector<c10::IValue> optional_tensorlist_cpu_args;
// Step 1: Convert all non-CPU tensor inputs into CPU tensors
// and put them on the stack at the correct indices.
@@ -106,25 +120,15 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
// We can improve this if we need better perf for XLA's CPU fallbacks.
tensorlist_args.push_back(ivalue.toTensorList());
tensorlist_args_indices.push_back(idx);
auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorList().vec())));
auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorVector())));
tensorlist_cpu_args.push_back(cpu_ivalue);
(*stack)[arguments_begin + idx] = std::move(cpu_ivalue);
tensorlist_args.push_back(ivalue.toTensorList());
} else if (ivalue.isOptionalTensorList()) {
auto opt_tensors = ivalue.toOptionalTensorList().vec();
std::vector<at::Tensor> need_convert_tensors;
std::vector<int> need_convert_tensors_index;
for (auto i : c10::irange(opt_tensors.size())) {
if (!opt_tensors[i].has_value() || !opt_tensors[i]->defined()) continue;
need_convert_tensors.push_back(opt_tensors[i].value());
need_convert_tensors_index.push_back(i);
}
auto cpu_tensors = to_cpu(need_convert_tensors);
for (const auto i : c10::irange(need_convert_tensors_index.size())) {
auto idx = need_convert_tensors_index[i];
opt_tensors[idx] = cpu_tensors[i];
}
(*stack)[arguments_begin + idx] = c10::IValue(opt_tensors);
optional_tensorlist_args.push_back(ivalue.toOptionalTensorList());
optional_tensorlist_args_indices.push_back(idx);
auto cpu_ivalue = c10::IValue(c10::List<std::optional<at::Tensor>>(to_cpu(ivalue.toOptionalTensorVector())));
optional_tensorlist_cpu_args.push_back(cpu_ivalue);
(*stack)[arguments_begin + idx] = c10::IValue(cpu_ivalue);
} else if (ivalue.isDevice()) {
tgt_device = ivalue.toDevice();
(*stack)[arguments_begin + idx] = c10::IValue(c10::Device(kCPU));
@@ -148,6 +152,7 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
auto tensor_idx = tensor_args_indices[i];
const AliasInfo* alias_info = schema_args[tensor_idx].alias_info();
if (alias_info != nullptr && alias_info->isWrite()) {
if (!tensor_args[i].defined()) continue;
at::_copy_from_and_resize(cpu_tensors[i], tensor_args[i]);
}
}
@@ -158,13 +163,30 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
auto tensorlist_idx = tensorlist_args_indices[i];
const AliasInfo* alias_info = schema_args[tensorlist_idx].alias_info();
if (alias_info != nullptr && alias_info->isWrite()) {
const auto& cpu_tensors = tensorlist_cpu_args[i].toTensorList().vec();
const auto& cpu_tensors = tensorlist_cpu_args[i].toTensorVector();
for (const auto idx : c10::irange(tensorlist_args[i].size())) {
if (!cpu_tensors[idx].defined()) continue;
at::_copy_from_and_resize(cpu_tensors[idx], tensorlist_args[i][idx]);
}
}
}
// We also need to explicit reapply input mutations to inputs that are lists
// of optional tensors
for (const auto i : c10::irange(optional_tensorlist_args_indices.size())) {
auto tensorlist_idx = optional_tensorlist_args_indices[i];
const AliasInfo* alias_info = schema_args[tensorlist_idx].alias_info();
if (alias_info != nullptr && alias_info->isWrite()) {
const auto& cpu_tensors = optional_tensorlist_cpu_args[i].toOptionalTensorList();
for (const auto idx : c10::irange(optional_tensorlist_args[i].size())) {
if (cpu_tensors[idx].has_value() && cpu_tensors[idx].value().defined()) {
const std::optional<at::Tensor>& optional_tensor = optional_tensorlist_args[i][idx];
at::_copy_from_and_resize(cpu_tensors[idx].value(), optional_tensor.value());
}
}
}
}
// Step 4: Convert any CPU output tensors back to the original input device.
// For mutable alias'd outputs, we also need to take special care
// to move the ORIGINAL input tensor back onto the stack, in place of
@@ -338,8 +338,8 @@ Tensor _to_copy(
options);
}
bool pin_out = (non_blocking && self.is_cuda() && options.device().is_cpu() &&
(options.layout() == c10::kStrided));
bool pin_out = (non_blocking && (self.is_cuda() || self.is_privateuseone())
&& options.device().is_cpu() && (options.layout() == c10::kStrided));
if (memory_format == MemoryFormat::Preserve) {
if (options.device().supports_as_strided()) {
@@ -362,8 +362,15 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool
self.storage().nbytes());
} else {
// Using cpu tensor to accomplishment stride copy.
at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self);
at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst);
auto convert_to_cpu_tensor = [](const at::Tensor& src) -> at::Tensor {
if (src.device().type() == c10::DeviceType::PrivateUse1) {
return unsafe_create_cpu_tensor_from_dummy_tensor(src);
} else {
return src;
}
};
at::Tensor cpu_self = convert_to_cpu_tensor(self);
at::Tensor cpu_dst = convert_to_cpu_tensor(dst);
cpu_dst.copy_(cpu_self);
}
@@ -555,6 +562,7 @@ void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
@@ -604,7 +612,10 @@ void set_custom_device_index(c10::DeviceIndex device_index) {
custom_device_index = device_index;
}
struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
struct FooHooksInterface : public at::PrivateUse1HooksInterface {
FooHooksInterface(FooHooksArgs) {}
~FooHooksInterface() override = default;
const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) override {
static auto device_gen = make_generator_privateuse1(device_index);
@@ -612,34 +623,49 @@ struct FooHooksInterface : public at::PrivateUse1HooksInterface {
}
};
struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
#define REGISTER_PRIVATEUSE1_HOOKS(clsname) \
C10_REGISTER_CLASS(PrivateUse1HooksRegistry, clsname, clsname)
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs)
// Using Create function to get PrivateUse1HooksInterface point from PrivateUse1HooksRegistry class.
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "FooHooks", FooHooksInterface)
static at::PrivateUse1HooksInterface* privateuse1_hooks_local = nullptr;
static at::PrivateUse1HooksInterface* get_private_hooks() {
static at::PrivateUse1HooksInterface* privateuse1_hooks;
static c10::once_flag once;
c10::call_once(once, [] {
privateuse1_hooks = PrivateUse1HooksRegistry()->Create("PrivateUse1Hooks", {}).release();
if (!privateuse1_hooks) {
privateuse1_hooks = new FooHooksInterface();
privateuse1_hooks_local = PrivateUse1HooksRegistry()->Create("FooHooks", {}).release();
if (!privateuse1_hooks_local) {
privateuse1_hooks_local = new FooHooksInterface(FooHooksArgs{});
}
});
return privateuse1_hooks;
return privateuse1_hooks_local;
}
void register_hook() {
at::RegisterPrivateUse1HooksInterface(get_private_hooks());
}
bool is_register_hook() {
return privateuse1_hooks_local != nullptr;
}
const at::Generator& default_generator(c10::DeviceIndex device_index) {
return at::globalContext().defaultGenerator(at::Device(c10::DeviceType::PrivateUse1, device_index));;
}
void fallback_with_undefined_tensor() {
at::Tensor first = at::empty((2,3)).to(at::DeviceType::PrivateUse1);
at::Tensor second = at::Tensor();
at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
at::TensorList tensors = {first, first};
at::TensorList undefined_tensors = {first, second};
at::TensorList steps = {step, step};
return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
grad_scale, found_inf);
}
struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
@@ -686,7 +712,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization correctly");
m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
m.def("register_hook", &register_hook, "register_hook for privateuse1");
m.def("is_register_hook", &is_register_hook, "is_register_hook for privateuse1");
m.def("default_generator", &default_generator, "default_generator for privateuse1");
m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
// Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
@@ -191,7 +191,8 @@ class TestCppExtensionOpenRgistration(common.TestCase):
):
self.module.register_generator_second()
self.module.register_hook()
if self.module.is_register_hook() is False:
self.module.register_hook()
default_gen = self.module.default_generator(0)
self.assertTrue(
default_gen.device.type == torch._C._get_privateuse1_backend_name()
@@ -577,6 +578,9 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.assertEqual(z_cpu, z[0])
self.assertEqual(z_cpu, z[1])
# call _fused_adamw_ with undefined tensor.
self.module.fallback_with_undefined_tensor()
def test_open_device_numpy_serialization_map_location(self):
torch.utils.rename_privateuse1_backend("foo")
device = self.module.custom_device()