I am trying to give some test files better owner labels than `module: unknown`. I am not sure about them, but they seem pretty reasonable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163203 Approved by: https://github.com/jcaip
181 lines
6.6 KiB
Python
# Owner(s): ["module: sparse"]

import copy
import warnings

import torch
from torch import nn
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
from torch.testing._internal.common_utils import raise_on_run_directly, TestCase


class ImplementedDataScheduler(BaseDataScheduler):
    def __init__(self, sparsifier, sparsifier_hyperparam, last_epoch=-1, verbose=False):
        super().__init__(sparsifier, sparsifier_hyperparam, last_epoch, verbose)

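    # Toy schedule used by the tests below: once the first epoch has passed,
    # every data group's sparsity_level is halved; before that, the base
    # values recorded at construction time are returned unchanged.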
    def get_schedule_param(self):
        if self.last_epoch > 0:
            return {
                name: config["sparsity_level"] * 0.5
                for name, config in self.data_sparsifier.data_groups.items()
            }
        else:
            return self.base_param


class TestBaseDataScheduler(TestCase):
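    # Test fixtures: three named tensors registered with shared defaults
    # (data_list) plus one tensor with a per-tensor config override
    # (data_with_config), wrapped in a DataNormSparsifier and the toy scheduler.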
    def _get_data(self):
        tensor1, param1, emb1 = (
            torch.randn(5, 5),
            nn.Parameter(torch.randn(10, 10)),
            nn.Embedding(50, 5),
        )
        data_list = [("tensor1", tensor1), ("param1", param1), ("emb1", emb1)]
        defaults = {
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 2,
        }
        data_with_config = [
            {
                "name": "tensor2",
                "data": torch.randn(4, 4),
                "config": {"sparsity_level": 0.3},
            }
        ]
        return data_list, data_with_config, defaults

    def _get_sparsifier(self, data_list, data_with_config, defaults):
        sparsifier = DataNormSparsifier(data_list, **defaults)
        for data_config_dict in data_with_config:
            name, data, config = (
                data_config_dict["name"],
                data_config_dict["data"],
                data_config_dict["config"],
            )
            sparsifier.add_data(name=name, data=data, **config)
        return sparsifier

    def _get_scheduler(self, sparsifier, schedule_param):
        scheduler = ImplementedDataScheduler(sparsifier, schedule_param)
        return scheduler

    def _get_schedule_param(self):
        return "sparsity_level"

    def _get_name_data_config(self, some_data, defaults):
        config = copy.deepcopy(defaults)
        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
        else:
            # dealing with data_with_config
            name, data, new_config = (
                some_data["name"],
                some_data["data"],
                some_data["config"],
            )
            config.update(new_config)
        return name, data, config

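    # The tests below cover constructor bookkeeping, the required
    # sparsifier-before-scheduler step order, the schedule itself, and
    # state_dict round-tripping.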
    def test_constructor(self):
        """Checks that the scheduler attaches to the sparsifier and records
        each data group's initial schedule_param value in base_param"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        assert scheduler.data_sparsifier == sparsifier
        assert scheduler._step_count == 1

        for name, config in sparsifier.data_groups.items():
            assert scheduler.base_param[name] == config.get(schedule_param, None)

    def test_order_of_steps(self):
        """Checks that a warning is thrown if the scheduler step is called
        before the sparsifier step, and that the correct order is silent"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: This will trigger if other warnings are present.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()
            # Make sure there is no warning related to the base_data_scheduler
            for warning in w:
                fname = warning.filename
                fname = "/".join(fname.split("/")[-5:])
                assert (
                    fname
                    != "torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py"
                )

    def test_step(self):
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        all_data = data_list + data_with_config

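        # Before any step, each data group still holds its configured
        # sparsity_level; one sparsifier + scheduler step should halve it
        # via ImplementedDataScheduler.get_schedule_param above.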
        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert (
                sparsifier.data_groups[name][schedule_param] == config[schedule_param]
            )

        sparsifier.step()
        scheduler.step()

        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert (
                sparsifier.data_groups[name][schedule_param]
                == config[schedule_param] * 0.5
            )

        # checking step count
        step_cnt = 5
        for _ in range(0, step_cnt):
            sparsifier.step()
            scheduler.step()

        assert (
            scheduler._step_count == step_cnt + 2
        )  # step_cnt + step above + 1 step in constructor

    def test_state_dict(self):
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler1 = self._get_scheduler(sparsifier, schedule_param)

        sparsifier.step()
        scheduler1.step()

        scheduler2 = self._get_scheduler(sparsifier, schedule_param)
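        # scheduler2 is constructed after one step, so its base_param picks up
        # the already-halved sparsity levels: it differs from scheduler1's
        # base_param but matches scheduler1's last applied values.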
        all_data = data_list + data_with_config
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] != scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2.base_param[name]

        scheduler1_state = scheduler1.state_dict()
        scheduler2.load_state_dict(scheduler1_state)

        # After loading scheduler1's state, both schedulers should agree
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] == scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2._last_param[name]


if __name__ == "__main__":
    raise_on_run_directly("test/test_ao_sparsity.py")