mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix the problem that cpu_fallback for aten::triu_indices on custom devices crashed (#121306)
Fixes #121289 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121306 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
5e66bf5f42
commit
83ad8e01b1
@ -89,6 +89,7 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
|
||||
std::vector<c10::List<at::Tensor>> tensorlist_args;
|
||||
std::vector<int> tensorlist_args_indices;
|
||||
|
||||
c10::optional<c10::Device> tgt_device = c10::nullopt;
|
||||
// save converted cpu tensor for TensorList
|
||||
std::vector<c10::IValue> tensorlist_cpu_args;
|
||||
|
||||
@ -124,6 +125,9 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
|
||||
opt_tensors[idx] = cpu_tensors[i];
|
||||
}
|
||||
(*stack)[arguments_begin + idx] = c10::IValue(opt_tensors);
|
||||
} else if (ivalue.isDevice()) {
|
||||
tgt_device = ivalue.toDevice();
|
||||
(*stack)[arguments_begin + idx] = c10::IValue(c10::Device(kCPU));
|
||||
}
|
||||
}
|
||||
// XLA requires all of the tensor arguments to be gathered up and converted to CPU together.
|
||||
@ -184,8 +188,9 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
|
||||
auto returns = torch::jit::last(stack, num_returns);
|
||||
const auto returns_begin = stack->size() - num_returns;
|
||||
|
||||
c10::optional<c10::Device> tgt_device =
|
||||
compute_target_device(tensor_args, tensorlist_args);
|
||||
if (tgt_device == c10::nullopt) {
|
||||
tgt_device = compute_target_device(tensor_args, tensorlist_args);
|
||||
}
|
||||
|
||||
for (const auto idx : c10::irange(returns.size())) {
|
||||
const AliasInfo* alias_info = schema_returns[idx].alias_info();
|
||||
|
||||
@ -492,6 +492,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
|
||||
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
|
||||
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
|
||||
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
|
||||
m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
|
||||
}
|
||||
|
||||
// This basic implementation doesn't bother dealing with different device indices
|
||||
|
||||
@ -469,6 +469,12 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
self.assertEqual(out_ref, out_test)
|
||||
self.assertEqual(x_ref.grad, x_test.grad)
|
||||
|
||||
def test_open_device_scalar_type_fallback():
    # Route PrivateUse1 ops to the custom backend named 'foo'; with no
    # native kernel registered, triu_indices takes the cpu_fallback path.
    torch.utils.rename_privateuse1_backend('foo')
    # Expected output of triu_indices(3, 3), built directly on CPU.
    expected = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
    # Run on the custom device; a crash here is the regression from #121289.
    result = torch.triu_indices(3, 3, device='foo')
    # NOTE: `self` is captured from the enclosing test method's scope.
    self.assertEqual(expected, result)
|
||||
|
||||
def test_open_device_tensor_type_fallback():
|
||||
torch.utils.rename_privateuse1_backend('foo')
|
||||
# create tensors located in custom device
|
||||
@ -528,6 +534,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
test_compile_autograd_function_returns_self()
|
||||
test_compile_autograd_function_aliasing()
|
||||
|
||||
test_open_device_scalar_type_fallback()
|
||||
test_open_device_tensor_type_fallback()
|
||||
test_open_device_tensorlist_type_fallback()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user