[pytorch-vulkan] fix zero-dim test (#113116)

Summary:
Fix zero-dim test. Use `at::zeros` instead of `at::empty` as the init value inside a `at::empty` tensor is undefined. Likely to be the cause of test flakiness.

 {F1142344469}

Test Plan:
Run on devserver

```
$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck2 run fbcode/mode/dev-nosan    //xplat/caffe2:pt_vulkan_api_test_bin
...
[       OK ] VulkanAPITest.linear_4d_large (2 ms)
[ RUN      ] VulkanAPITest.lstm_success
[       OK ] VulkanAPITest.lstm_success (4 ms)
[ RUN      ] VulkanAPITest.lstm_mclareninputs_success
[       OK ] VulkanAPITest.lstm_mclareninputs_success (45 ms)
[ RUN      ] VulkanAPITest.lstm_prepack_success
[       OK ] VulkanAPITest.lstm_prepack_success (2 ms)
[ RUN      ] VulkanAPITest.querypool_flushed_shader_log
xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp:7773: Skipped
QueryPool is not available

[  SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms)
[----------] 402 tests from VulkanAPITest (24598 ms total)

[----------] Global test environment tear-down
[==========] 402 tests from 1 test suite ran. (24598 ms total)
[  PASSED  ] 399 tests.
[  SKIPPED ] 1 test, listed below:
[  SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
[  FAILED  ] 2 tests, listed below:
[  FAILED  ] VulkanAPITest.conv2d_pw_prepack
[  FAILED  ] VulkanAPITest.conv2d_pw_prepack_bc

 2 FAILED TESTS
  YOU HAVE 7 DISABLED TESTS

```

Last two are known failures on devserver.

Full output: P875058890

Differential Revision: D51055623

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113116
Approved by: https://github.com/manuelcandales
This commit is contained in:
Justin Yip
2023-11-07 21:32:03 +00:00
committed by PyTorch MergeBot
parent ff1ae35205
commit 2da062da51

View File

@ -313,8 +313,8 @@ TEST_F(VulkanAPITest, zero_dim_tensor_1) {
TEST_F(VulkanAPITest, zero_dim_tensor_2) {
float v = 3.14f;
auto cpu = at::empty({}, at::device(at::kCPU).dtype(at::kFloat)) + v;
auto vk = at::empty({}, at::device(at::kVulkan).dtype(at::kFloat)) + v;
auto cpu = at::zeros({}, at::device(at::kCPU).dtype(at::kFloat)) + v;
auto vk = at::zeros({}, at::device(at::kVulkan).dtype(at::kFloat)) + v;
ASSERT_TRUE(almostEqual(cpu, vk.cpu()));
}