mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-16 23:44:53 +08:00
Update
[ghstack-poisoned]
This commit is contained in:
@ -512,9 +512,36 @@ if not IS_WINDOWS:
|
||||
# tests tensor accessor
|
||||
import libtorch_agnostic
|
||||
|
||||
for dtype in [torch.float16, torch.float32, torch.float64]:
|
||||
for shape in [(3,), (3, 4)]:
|
||||
t = torch.empty(shape, device=device, dtype=dtype)
|
||||
for dtype in get_supported_dtypes():
|
||||
for shape in [(5,), (3, 4)]:
|
||||
if dtype in {
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
torch.float64,
|
||||
torch.complex32,
|
||||
torch.complex64,
|
||||
torch.complex128,
|
||||
}:
|
||||
t = torch.randn(shape, device=device, dtype=dtype)
|
||||
elif dtype in {
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2fnuz,
|
||||
torch.float8_e4m3fnuz,
|
||||
}:
|
||||
t = torch.randn(shape, device=device, dtype=torch.float16).to(
|
||||
dtype=dtype
|
||||
)
|
||||
elif dtype is torch.bool:
|
||||
t = (
|
||||
torch.randint(
|
||||
0, 127, shape, device=device, dtype=torch.int8
|
||||
)
|
||||
% 2
|
||||
== 0
|
||||
)
|
||||
else:
|
||||
t = torch.randint(0, 127, shape, device=device, dtype=dtype)
|
||||
result = libtorch_agnostic.ops.my_element_wise_clone(t)
|
||||
expected = t.clone()
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/TensorAccessor.h>
|
||||
|
||||
#include <torch/csrc/stable/stableivalue_conversions.h>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
Reference in New Issue
Block a user