[ghstack-poisoned]
This commit is contained in:
Pearu Peterson
2025-09-29 13:53:20 +03:00
2 changed files with 30 additions and 4 deletions

View File

@ -512,9 +512,36 @@ if not IS_WINDOWS:
# tests tensor accessor
import libtorch_agnostic
for dtype in [torch.float16, torch.float32, torch.float64]:
for shape in [(3,), (3, 4)]:
t = torch.empty(shape, device=device, dtype=dtype)
for dtype in get_supported_dtypes():
for shape in [(5,), (3, 4)]:
if dtype in {
torch.float16,
torch.float32,
torch.float64,
torch.complex32,
torch.complex64,
torch.complex128,
}:
t = torch.randn(shape, device=device, dtype=dtype)
elif dtype in {
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e5m2fnuz,
torch.float8_e4m3fnuz,
}:
t = torch.randn(shape, device=device, dtype=torch.float16).to(
dtype=dtype
)
elif dtype is torch.bool:
t = (
torch.randint(
0, 127, shape, device=device, dtype=torch.int8
)
% 2
== 0
)
else:
t = torch.randint(0, 127, shape, device=device, dtype=dtype)
result = libtorch_agnostic.ops.my_element_wise_clone(t)
expected = t.clone()
self.assertEqual(result, expected)

View File

@ -1,7 +1,6 @@
#pragma once
#include <torch/csrc/stable/TensorAccessor.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <array>
#include <cstdint>