mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[export] Add device and dtype fields to assert_tensor_metadata (#141071)
Differential Revision: [D66321128](https://our.internmc.facebook.com/intern/diff/D66321128) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141071 Approved by: https://github.com/yushangdi, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
45d62d6fc5
commit
0fbc0830ba
@ -19,15 +19,27 @@ void _assert_match(const O& original, const C& compared, const std::string& name
|
||||
if (!equal) {
|
||||
std::stringstream msg;
|
||||
msg << "Tensor " << name << " mismatch!";
|
||||
AT_ASSERT(equal, msg.str());
|
||||
if (!equal) {
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype) {
|
||||
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
|
||||
_assert_match(tensor.sym_sizes(), sizes, "sizes");
|
||||
_assert_match(tensor.sym_strides(), strides, "strides");
|
||||
_assert_match(tensor.dtype(), dtype, "dtype");
|
||||
_assert_match(tensor.device(), device, "device");
|
||||
_assert_match(tensor.layout(), layout, "layout");
|
||||
}
|
||||
|
||||
void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
|
||||
_assert_match(tensor.sizes(), sizes, "sizes");
|
||||
_assert_match(tensor.strides(), strides, "strides");
|
||||
_assert_match(tensor.dtype(), dtype, "dtype");
|
||||
_assert_match(tensor.device(), device, "device");
|
||||
_assert_match(tensor.layout(), layout, "layout");
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -187,7 +187,10 @@
|
||||
dispatch:
|
||||
CPU: _functional_assert_async_msg_cpu
|
||||
|
||||
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
|
||||
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> ()
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _assert_tensor_metadata
|
||||
Meta: _assert_tensor_metadata_meta_symint
|
||||
|
||||
- func: _print(str s) -> ()
|
||||
dispatch:
|
||||
|
@ -30,6 +30,7 @@ aten::_amp_update_scale.out
|
||||
aten::_amp_update_scale_
|
||||
aten::_assert_async
|
||||
aten::_assert_async.msg
|
||||
aten::_assert_tensor_metadata
|
||||
aten::_batch_norm_no_update.out
|
||||
aten::_batch_norm_with_update.out
|
||||
aten::_cdist_backward
|
||||
|
@ -1,6 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# Owner(s): ["module: internals"]
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
@ -32,6 +34,19 @@ class TestComparisonUtils(TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch._assert_tensor_metadata(t, [3], [1], torch.float)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
|
||||
def test_assert_device(self):
|
||||
t = torch.tensor([0.5], device="cpu")
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch._assert_tensor_metadata(t, device="cuda")
|
||||
|
||||
def test_assert_layout(self):
|
||||
t = torch.tensor([0.5])
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch._assert_tensor_metadata(t, layout=torch.sparse_coo)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -80,6 +80,7 @@ _side_effectful_functions: Set[Callable] = {
|
||||
torch._assert_async,
|
||||
_ops.aten._assert_async.msg,
|
||||
_ops.aten._assert_scalar.default,
|
||||
_ops.aten._assert_tensor_metadata.default,
|
||||
_ops.aten.sym_constrain_range.default,
|
||||
_ops.aten.sym_constrain_range_for_size.default,
|
||||
_ops.profiler._record_function_enter,
|
||||
|
@ -53,6 +53,7 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
|
||||
FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
|
||||
"_assert_async", # no return
|
||||
"_assert_async.msg", # no return
|
||||
"_assert_tensor_metadata", # no return
|
||||
"_cslt_sparse_mm_search", # returns an int
|
||||
"_assert_scalar", # no return
|
||||
"_dimI", # returns an int
|
||||
|
Reference in New Issue
Block a user