mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Add int1 to int7 dtypes (#136301)"
This reverts commit bfa16a161d5089a9ba008f5e665f29b58dc16526. Reverted https://github.com/pytorch/pytorch/pull/136301 on behalf of https://github.com/PaliC due to causing internal failures ([comment](https://github.com/pytorch/pytorch/pull/136301#issuecomment-2384119600))
This commit is contained in:
@ -192,31 +192,30 @@ class TestUtils(TestCase):
|
||||
assert quantized_tensor.int_repr().max().item() == q8_max
|
||||
assert quantized_tensor.int_repr().min().item() == q8_min
|
||||
|
||||
def test_uint4_int4_dtype(self):
|
||||
def test_uint1_7_dtype(self):
|
||||
|
||||
def up_size(size):
|
||||
return (*size[:-1], size[-1] * 2)
|
||||
|
||||
for dtype in [torch.uint4, torch.int4]:
|
||||
class UInt4OrInt4Tensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem, **kwargs):
|
||||
assert elem.dtype is torch.uint8
|
||||
assert not kwargs.get("requires_grad", False)
|
||||
kwargs["requires_grad"] = False
|
||||
return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=dtype, **kwargs)
|
||||
class UInt4Tensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem, **kwargs):
|
||||
assert elem.dtype is torch.uint8
|
||||
assert not kwargs.get("requires_grad", False)
|
||||
kwargs["requires_grad"] = False
|
||||
return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs)
|
||||
|
||||
def __init__(self, elem):
|
||||
self.elem = elem
|
||||
def __init__(self, elem):
|
||||
self.elem = elem
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs=None):
|
||||
pass
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs=None):
|
||||
pass
|
||||
|
||||
# make sure it runs
|
||||
x = UInt4OrInt4Tensor(torch.tensor([
|
||||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
|
||||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
|
||||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
|
||||
], dtype=torch.uint8))
|
||||
assert x.dtype == dtype
|
||||
# make sure it runs
|
||||
x = UInt4Tensor(torch.tensor([
|
||||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
|
||||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
|
||||
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
|
||||
], dtype=torch.uint8))
|
||||
assert x.dtype == torch.uint4
|
||||
|
Reference in New Issue
Block a user