Revert "Add int1 to int7 dtypes (#136301)"

This reverts commit bfa16a161d5089a9ba008f5e665f29b58dc16526.

Reverted https://github.com/pytorch/pytorch/pull/136301 on behalf of https://github.com/PaliC due to causing internal failures ([comment](https://github.com/pytorch/pytorch/pull/136301#issuecomment-2384119600))
This commit is contained in:
PyTorch MergeBot
2024-09-30 20:50:49 +00:00
parent 0ccd39a64b
commit 2ef1454189
5 changed files with 26 additions and 84 deletions

View File

@ -192,31 +192,30 @@ class TestUtils(TestCase):
assert quantized_tensor.int_repr().max().item() == q8_max
assert quantized_tensor.int_repr().min().item() == q8_min
def test_uint4_int4_dtype(self):
def test_uint1_7_dtype(self):
def up_size(size):
return (*size[:-1], size[-1] * 2)
for dtype in [torch.uint4, torch.int4]:
class UInt4OrInt4Tensor(torch.Tensor):
@staticmethod
def __new__(cls, elem, **kwargs):
assert elem.dtype is torch.uint8
assert not kwargs.get("requires_grad", False)
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=dtype, **kwargs)
class UInt4Tensor(torch.Tensor):
@staticmethod
def __new__(cls, elem, **kwargs):
assert elem.dtype is torch.uint8
assert not kwargs.get("requires_grad", False)
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs)
def __init__(self, elem):
self.elem = elem
def __init__(self, elem):
self.elem = elem
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
pass
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
pass
# make sure it runs
x = UInt4OrInt4Tensor(torch.tensor([
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
], dtype=torch.uint8))
assert x.dtype == dtype
# make sure it runs
x = UInt4Tensor(torch.tensor([
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
], dtype=torch.uint8))
assert x.dtype == torch.uint4