Fixed print tensor.type() issue. (#96381)

Fixes #95954
Updating the cpp printing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96381
Approved by: https://github.com/albanD
This commit is contained in:
dujinhang
2023-03-17 20:26:43 +00:00
committed by PyTorch MergeBot
parent 57bb5b159d
commit 1983b31711
2 changed files with 23 additions and 1 deletions

View File

@ -115,5 +115,22 @@ class TestCppExtensionOpenRgistration(common.TestCase):
"Only can register a generator to the PrivateUse1 dispatch key once"):
module.register_genertor()
# check whether print tensor.type() meets the expectation
torch.utils.rename_privateuse1_backend('foo')
dtypes = {
torch.bool: 'torch.foo.BoolTensor',
torch.double: 'torch.foo.DoubleTensor',
torch.float32: 'torch.foo.FloatTensor',
torch.half: 'torch.foo.HalfTensor',
torch.int32: 'torch.foo.IntTensor',
torch.int64: 'torch.foo.LongTensor',
torch.int8: 'torch.foo.CharTensor',
torch.short: 'torch.foo.ShortTensor',
torch.uint8: 'torch.foo.ByteTensor',
}
for tt, dt in dtypes.items():
test_tensor = torch.empty(4, 4, dtype=tt, device=device)
self.assertTrue(test_tensor.type() == dt)
if __name__ == "__main__":
common.run_tests()

View File

@ -19,6 +19,11 @@ using namespace at;
namespace torch {
namespace utils {
const char* parse_privateuseone_backend() {
static std::string backend_name = "torch." + get_privateuse1_backend();
return backend_name.c_str();
}
static const char* backend_to_string(const at::Backend& backend) {
switch (backend) {
case at::Backend::CPU:
@ -42,7 +47,7 @@ static const char* backend_to_string(const at::Backend& backend) {
case at::Backend::MPS:
return "torch.mps";
case at::Backend::PrivateUse1:
return "torch.privateuseone";
return parse_privateuseone_backend();
case at::Backend::Lazy:
return "torch.lazy";
case at::Backend::XLA: