enable cat for cuda bits types (#115044)

It was already working on CPU, so this brings CUDA to parity.
Also, slightly reduce number of compiled kernels by using OpaqueType.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115044
Approved by: https://github.com/malfet
This commit is contained in:
Natalia Gimelshein
2023-12-06 00:05:18 +00:00
committed by PyTorch MergeBot
parent b9c4fb68c5
commit b8ce05456c
3 changed files with 71 additions and 18 deletions

View File

@ -153,7 +153,14 @@ except ImportError as e:
logging.warning(e)
# Experimental functionality
from quantization.core.experimental.test_bits import TestBits # noqa: F401
try:
from quantization.core.experimental.test_bits import TestBitsCPU # noqa: F401
except ImportError as e:
logging.warning(e)
try:
from quantization.core.experimental.test_bits import TestBitsCUDA # noqa: F401
except ImportError as e:
logging.warning(e)
try:
from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU # noqa: F401
except ImportError as e: