Fix crash in cpu_fallback for aten::triu_indices on custom devices (#121306)

Fixes #121289

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121306
Approved by: https://github.com/ezyang
This commit is contained in:
chentianyi16
2024-03-26 01:29:45 +00:00
committed by PyTorch MergeBot
parent 5e66bf5f42
commit 83ad8e01b1
3 changed files with 15 additions and 2 deletions

View File

@ -89,6 +89,7 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
std::vector<c10::List<at::Tensor>> tensorlist_args;
std::vector<int> tensorlist_args_indices;
c10::optional<c10::Device> tgt_device = c10::nullopt;
// save converted cpu tensor for TensorList
std::vector<c10::IValue> tensorlist_cpu_args;
@ -124,6 +125,9 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
opt_tensors[idx] = cpu_tensors[i];
}
(*stack)[arguments_begin + idx] = c10::IValue(opt_tensors);
} else if (ivalue.isDevice()) {
tgt_device = ivalue.toDevice();
(*stack)[arguments_begin + idx] = c10::IValue(c10::Device(kCPU));
}
}
// XLA requires all of the tensor arguments to be gathered up and converted to CPU together.
@ -184,8 +188,9 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
auto returns = torch::jit::last(stack, num_returns);
const auto returns_begin = stack->size() - num_returns;
c10::optional<c10::Device> tgt_device =
compute_target_device(tensor_args, tensorlist_args);
if (tgt_device == c10::nullopt) {
tgt_device = compute_target_device(tensor_args, tensorlist_args);
}
for (const auto idx : c10::irange(returns.size())) {
const AliasInfo* alias_info = schema_returns[idx].alias_info();

View File

@ -492,6 +492,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
// This basic implementation doesn't bother dealing with different device indices

View File

@ -469,6 +469,12 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.assertEqual(out_ref, out_test)
self.assertEqual(x_ref.grad, x_test.grad)
def test_open_device_scalar_type_fallback():
    # Regression test for pytorch#121289: a boxed CPU fallback crashed for ops
    # that take a Device-typed argument (e.g. aten::triu_indices) on a custom
    # (PrivateUse1) device. NOTE(review): `self` is a free variable captured
    # from the enclosing test method's scope — this function is not standalone.
    torch.utils.rename_privateuse1_backend('foo')
    # Expected output of triu_indices(3, 3), spelled out as a CPU int64 tensor.
    z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
    # Dispatch on the 'foo' device routes through the registered boxed CPU
    # fallback; before the fix the device argument was not redirected to CPU
    # and this call crashed.
    z = torch.triu_indices(3, 3, device='foo')
    self.assertEqual(z_cpu, z)
def test_open_device_tensor_type_fallback():
torch.utils.rename_privateuse1_backend('foo')
# create tensors located in custom device
@ -528,6 +534,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
test_compile_autograd_function_returns_self()
test_compile_autograd_function_aliasing()
test_open_device_scalar_type_fallback()
test_open_device_tensor_type_fallback()
test_open_device_tensorlist_type_fallback()