mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
update test_quantization tests to run weekly (#163077)
Fixes #162854 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163077 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
141fc7276e
commit
3b73841f43
@ -4,6 +4,8 @@ import copy
|
||||
import unittest
|
||||
from collections import Counter
|
||||
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
from torch.ao.quantization import (
|
||||
compare_results,
|
||||
@ -29,6 +31,10 @@ from torch.testing._internal.common_utils import (
|
||||
)
|
||||
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
|
||||
class TestNumericDebugger(TestCase):
|
||||
def _assert_each_node_has_debug_handle(self, model) -> None:
|
||||
|
@ -2121,14 +2121,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
m(*example_inputs)
|
||||
|
||||
def test_observer_callback(self):
|
||||
from torch.library import impl, Library
|
||||
from torch.library import custom_op
|
||||
|
||||
test_lib = Library("test_int4", "DEF") # noqa: TOR901
|
||||
test_lib.define(
|
||||
"quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
|
||||
)
|
||||
|
||||
@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
|
||||
@custom_op("test_int4::quantize_per_tensor_int4", mutates_args=())
|
||||
def quantize_per_tensor_int4(
|
||||
input: torch.Tensor,
|
||||
scale: float,
|
||||
@ -2141,11 +2136,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
.view(torch.bits8)
|
||||
)
|
||||
|
||||
test_lib.define(
|
||||
"dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
|
||||
)
|
||||
|
||||
@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
|
||||
@custom_op("test_int4::dequantize_per_tensor_int4", mutates_args=())
|
||||
def dequantize_per_tensor_int4(
|
||||
input: torch.Tensor,
|
||||
scale: float,
|
||||
|
Reference in New Issue
Block a user