[export] Add device and dtype fields to assert_tensor_metadata (#141071)

Differential Revision: [D66321128](https://our.internmc.facebook.com/intern/diff/D66321128)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141071
Approved by: https://github.com/yushangdi, https://github.com/zou3519
This commit is contained in:
angelayi
2024-11-21 16:15:09 -08:00
committed by PyTorch MergeBot
parent 45d62d6fc5
commit 0fbc0830ba
6 changed files with 36 additions and 3 deletions

View File

@ -19,15 +19,27 @@ void _assert_match(const O& original, const C& compared, const std::string& name
if (!equal) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch!";
AT_ASSERT(equal, msg.str());
if (!equal) {
throw std::runtime_error(msg.str());
}
}
}
}
void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype) {
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
_assert_match(tensor.sym_sizes(), sizes, "sizes");
_assert_match(tensor.sym_strides(), strides, "strides");
_assert_match(tensor.dtype(), dtype, "dtype");
_assert_match(tensor.device(), device, "device");
_assert_match(tensor.layout(), layout, "layout");
}
void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
_assert_match(tensor.sizes(), sizes, "sizes");
_assert_match(tensor.strides(), strides, "strides");
_assert_match(tensor.dtype(), dtype, "dtype");
_assert_match(tensor.device(), device, "device");
_assert_match(tensor.layout(), layout, "layout");
}
}

View File

@ -187,7 +187,10 @@
dispatch:
CPU: _functional_assert_async_msg_cpu
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> ()
dispatch:
CompositeExplicitAutograd: _assert_tensor_metadata
Meta: _assert_tensor_metadata_meta_symint
- func: _print(str s) -> ()
dispatch:

View File

@ -30,6 +30,7 @@ aten::_amp_update_scale.out
aten::_amp_update_scale_
aten::_assert_async
aten::_assert_async.msg
aten::_assert_tensor_metadata
aten::_batch_norm_no_update.out
aten::_batch_norm_with_update.out
aten::_cdist_backward

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3
# Owner(s): ["module: internals"]
import unittest
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
@ -32,6 +34,19 @@ class TestComparisonUtils(TestCase):
with self.assertRaises(RuntimeError):
torch._assert_tensor_metadata(t, [3], [1], torch.float)
@unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
def test_assert_device(self):
t = torch.tensor([0.5], device="cpu")
with self.assertRaises(RuntimeError):
torch._assert_tensor_metadata(t, device="cuda")
def test_assert_layout(self):
t = torch.tensor([0.5])
with self.assertRaises(RuntimeError):
torch._assert_tensor_metadata(t, layout=torch.sparse_coo)
if __name__ == "__main__":
run_tests()

View File

@ -80,6 +80,7 @@ _side_effectful_functions: Set[Callable] = {
torch._assert_async,
_ops.aten._assert_async.msg,
_ops.aten._assert_scalar.default,
_ops.aten._assert_tensor_metadata.default,
_ops.aten.sym_constrain_range.default,
_ops.aten.sym_constrain_range_for_size.default,
_ops.profiler._record_function_enter,

View File

@ -53,6 +53,7 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"_assert_async", # no return
"_assert_async.msg", # no return
"_assert_tensor_metadata", # no return
"_cslt_sparse_mm_search", # returns an int
"_assert_scalar", # no return
"_dimI", # returns an int