# Owner(s): ["module: unknown"]

import warnings

from torch import nn
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier

from torch.testing._internal.common_utils import TestCase


class ImplementedScheduler(BaseScheduler):
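    """Minimal concrete BaseScheduler used by the tests below.

    BaseScheduler subclasses provide get_sl(), which returns the next
    sparsity level for every group in the attached sparsifier. This one
    halves each group's current level on every epoch after the first.
    """
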
    def get_sl(self):
        if self.last_epoch > 0:
            return [group["sparsity_level"] * 0.5 for group in self.sparsifier.groups]
        else:
            return list(self.base_sl)


class TestScheduler(TestCase):
    def test_constructor(self):
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        assert scheduler.sparsifier is sparsifier
        assert scheduler._step_count == 1
        assert scheduler.base_sl == [sparsifier.groups[0]["sparsity_level"]]

    def test_order_of_steps(self):
        """Checks that a warning is raised if the scheduler step is called
        before the sparsifier step."""

        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: This will trigger if other warnings are present.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()
            # Make sure there is no warning related to the base_scheduler
            for warning in w:
                fname = warning.filename
                fname = "/".join(fname.split("/")[-5:])
                assert fname != "torch/ao/sparsity/scheduler/base_scheduler.py"
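
    # ImplementedScheduler.get_sl halves every group's level once
    # last_epoch > 0, so a single sparsifier.step()/scheduler.step()
    # pair should take the default 0.5 level down to 0.25.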
    def test_step(self):
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        assert sparsifier.groups[0]["sparsity_level"] == 0.5
        scheduler = ImplementedScheduler(sparsifier)
        assert sparsifier.groups[0]["sparsity_level"] == 0.5

        sparsifier.step()
        scheduler.step()
        assert sparsifier.groups[0]["sparsity_level"] == 0.25
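
    # LambdaSL scales each base sparsity level by the value of the given
    # lambda at the current epoch (sl = base_sl * lambda(epoch)), which is
    # consistent with the asserts below: 0.5 * (0 * 10) = 0.0 at epoch 0
    # and 0.5 * (1 * 10) = 5.0 at epoch 1.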
    def test_lambda_scheduler(self):
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        assert sparsifier.groups[0]["sparsity_level"] == 0.5
        scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
        assert sparsifier.groups[0]["sparsity_level"] == 0.0  # Epoch 0
        scheduler.step()
        assert sparsifier.groups[0]["sparsity_level"] == 5.0  # Epoch 1


class TestCubicScheduler(TestCase):
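    """Tests for the cubic sparsity level scheduler.

    Assumption for context (not asserted verbatim by these tests): CubicSL
    follows the cubic schedule of Zhu & Gupta (2017), ramping each group
    from s_0 at step t_0 to its target s_f over total_t increments spaced
    delta_t steps apart, i.e.
    s_t = s_f + (s_0 - s_f) * (1 - (t - t_0) / (delta_t * total_t))**3.
    """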

    def setUp(self):
        self.model_sparse_config = [
            {"tensor_fqn": "0.weight", "sparsity_level": 0.8},
            {"tensor_fqn": "2.weight", "sparsity_level": 0.4},
        ]
        self.sorted_sparse_levels = [
            conf["sparsity_level"] for conf in self.model_sparse_config
        ]
        self.initial_sparsity = 0.1
        self.initial_step = 3

    def _make_model(self, **kwargs):
        model = nn.Sequential(
            nn.Linear(13, 17),
            nn.Dropout(0.5),
            nn.Linear(17, 3),
        )
        return model

    def _make_scheduler(self, model, **kwargs):
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=self.model_sparse_config)

        scheduler_args = {
            "init_sl": self.initial_sparsity,
            "init_t": self.initial_step,
        }
        # Extra kwargs (e.g. initially_zero, delta_t, total_t) override or
        # extend the defaults above.
        scheduler_args.update(kwargs)

        scheduler = CubicSL(sparsifier, **scheduler_args)
        return sparsifier, scheduler

    @staticmethod
    def _get_sparsity_levels(sparsifier, precision=32):
        r"""Gets the current levels of sparsity in a sparsifier."""
        return [
            round(group["sparsity_level"], precision) for group in sparsifier.groups
        ]

    def test_constructor(self):
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
        self.assertIs(
            scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached"
        )
        self.assertEqual(
            scheduler._step_count,
            1,
            msg="Scheduler is initialized with incorrect step count",
        )
        self.assertEqual(
            scheduler.base_sl,
            self.sorted_sparse_levels,
            msg="Scheduler did not store the target sparsity levels correctly",
        )

        # Value before t_0 is 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(0.0),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
        )

        # Value before t_0 is s_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
        )
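
    # With init_t=3, delta_t=2, total_t=5 the level should stay at its
    # initial value until step t_0 = 3, then ramp cubically, reaching the
    # target levels delta_t * total_t = 10 steps after t_0 (i.e. step 13).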
    def test_step(self):
        # For n=5, dt=2, there will be a total of 10 steps between s_0 and s_f,
        # starting from t_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(
            model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5
        )

        scheduler.step()
        scheduler.step()
        self.assertEqual(
            scheduler._step_count,
            3,
            msg="Scheduler step_count is expected to increment",
        )
        # Value before t_0 is supposed to be 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(0.0),
            msg="Scheduler step updating the sparsity level before t_0",
        )

        scheduler.step()  # Step = 3  =>  sparsity = initial_sparsity
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset to initial sparsity at the first step",
        )

        scheduler.step()  # Step = 4  =>  sparsity ~ [0.3, 0.2]
        self.assertEqual(
            self._get_sparsity_levels(sparsifier, 1),
            [0.3, 0.2],
            msg="Sparsity level is not set correctly after the first step",
        )

        current_step = scheduler._step_count - scheduler.init_t[0] - 1
        more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
        for _ in range(more_steps_needed):  # More steps needed to final sparsity level
            scheduler.step()
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            self.sorted_sparse_levels,
            msg="Sparsity level is not reaching the target level after delta_t * n steps",
        )