Fix out_tensor device in diag_test.py (#134020)

This benchmark fails if device='cuda' but out_tensor is on cpu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134020
Approved by: https://github.com/soulitzer
This commit is contained in:
Pavel Belevich
2024-08-21 20:43:36 +00:00
committed by PyTorch MergeBot
parent 6c1e2d2462
commit a3e1416c05

View File

@ -31,6 +31,7 @@ class DiagBenchmark(op_bench.TorchBenchmarkBase):
"out": out,
"out_tensor": torch.tensor(
(),
device=device,
),
}
self.set_module_name("diag")