Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
UFMT formatting on test/autograd test/ao test/cpp test/backends (#123369)

Partially addresses #123062. Ran lintrunner on:
- test/_test_bazel.py
- test/ao
- test/autograd
- test/backends
- test/benchmark_utils
- test/conftest.py
- test/bottleneck_test
- test/cpp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123369
Approved by: https://github.com/huydhn
Committed by: PyTorch MergeBot
Parent: de7edeea25
Commit: f71e368969
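UFMT runs usort (import sorting) followed by black (code formatting), so the hunks below are mechanical: imports get grouped and alphabetized, string literals are normalized to double quotes, trailing commas are added, and long lines are wrapped to black's default 88-column limit. A representative before/after, illustrative rather than taken from this diff:

    # before UFMT: out-of-order imports, single quotes, one long line
    import torch
    import copy

    config = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4, 'norm': 'L1'}

    # after UFMT: stdlib imports first, double quotes, wrapped with trailing commas
    import copy

    import torch

    config = {
        "sparsity_level": 0.5,
        "sparse_block_shape": (1, 4),
        "zeros_per_block": 4,
        "norm": "L1",
    }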
diff --git a/.lintrunner.toml b/.lintrunner.toml
@@ -1014,33 +1014,6 @@ exclude_patterns = [
     'test/_nvfuser/test_dynamo.py',
     'test/_nvfuser/test_python_frontend.py',
     'test/_nvfuser/test_torchscript.py',
-    'test/_test_bazel.py',
-    'test/ao/sparsity/test_activation_sparsifier.py',
-    'test/ao/sparsity/test_composability.py',
-    'test/ao/sparsity/test_data_scheduler.py',
-    'test/ao/sparsity/test_data_sparsifier.py',
-    'test/ao/sparsity/test_kernels.py',
-    'test/ao/sparsity/test_parametrization.py',
-    'test/ao/sparsity/test_qlinear_packed_params.py',
-    'test/ao/sparsity/test_scheduler.py',
-    'test/ao/sparsity/test_sparsifier.py',
-    'test/ao/sparsity/test_sparsity_utils.py',
-    'test/ao/sparsity/test_structured_sparsifier.py',
-    'test/autograd/test_complex.py',
-    'test/autograd/test_fallback.py',
-    'test/autograd/test_functional.py',
-    'test/backends/xeon/test_launch.py',
-    'test/benchmark_utils/test_benchmark_utils.py',
-    'test/bottleneck_test/test.py',
-    'test/bottleneck_test/test_args.py',
-    'test/bottleneck_test/test_cuda.py',
-    'test/conftest.py',
-    'test/cpp/__init__.py',
-    'test/cpp/aot_inductor/test.py',
-    'test/cpp/api/init_baseline.py',
-    'test/cpp/api/optim_baseline.py',
-    'test/cpp/jit/__init__.py',
-    'test/cpp/jit/tests_setup.py',
     'test/cpp_api_parity/__init__.py',
     'test/cpp_api_parity/functional_impl_check.py',
     'test/cpp_api_parity/module_impl_check.py',
diff --git a/test/_test_bazel.py b/test/_test_bazel.py
@@ -11,11 +11,14 @@ The name uses underscore `_test_bazel.py` to avoid globbing into other non-bazel
 import torch


 def test_sum() -> None:
-    assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
+    assert torch.eq(
+        torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])
+    ).all()


 def test_simple_compile_eager() -> None:
     def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         a = torch.sin(x)
         b = torch.cos(y)
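The hunk is truncated here. For orientation, `test_simple_compile_eager` presumably goes on to compile `foo` with the eager backend and compare it against the uncompiled function; a sketch of that continuation (everything past `b = torch.cos(y)` is an assumption, not the file's verbatim contents):

    import torch


    def test_simple_compile_eager() -> None:
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b  # assumed return value

        # backend="eager" runs the captured graph with ordinary eager kernels,
        # so the compiled function should match the uncompiled one exactly
        opt_foo = torch.compile(foo, backend="eager")
        x, y = torch.randn(10), torch.randn(10)
        assert torch.allclose(opt_foo(x, y), foo(x, y))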
diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py
@@ -1,16 +1,21 @@
 # Owner(s): ["module: unknown"]

 import copy
-from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
 import logging
-import torch
-from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import ActivationSparsifier
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.ao.pruning.sparsifier.utils import module_to_fqn
 from typing import List

-logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import (
+    ActivationSparsifier,
+)
+from torch.ao.pruning.sparsifier.utils import module_to_fqn
+from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)


 class Model(nn.Module):
@@ -45,7 +50,7 @@ class TestActivationSparsifier(TestCase):
         in the activation sparsifier
         """
         sparsifier_defaults = activation_sparsifier.defaults
-        combined_defaults = {**defaults, 'sparse_config': sparse_config}
+        combined_defaults = {**defaults, "sparse_config": sparse_config}

         # more keys are populated in activation sparsifier (eventhough they may be None)
         assert len(combined_defaults) <= len(activation_sparsifier.defaults)
@@ -54,7 +59,9 @@ class TestActivationSparsifier(TestCase):
             # all the keys in combined_defaults should be present in sparsifier defaults
             assert config == combined_defaults.get(key, None)

-    def _check_register_layer(self, activation_sparsifier, defaults, sparse_config, layer_args_list):
+    def _check_register_layer(
+        self, activation_sparsifier, defaults, sparse_config, layer_args_list
+    ):
         """Checks if layers in the model are correctly mapped to it's arguments.

         Args:
@@ -82,14 +89,14 @@ class TestActivationSparsifier(TestCase):
             sparse_config_actual = copy.deepcopy(sparse_config)
             sparse_config_actual.update(sparse_config_layer)

-            name = module_to_fqn(activation_sparsifier.model, layer_arg['layer'])
+            name = module_to_fqn(activation_sparsifier.model, layer_arg["layer"])

-            assert data_groups[name]['sparse_config'] == sparse_config_actual
+            assert data_groups[name]["sparse_config"] == sparse_config_actual

             # assert the rest
             other_config_actual = copy.deepcopy(defaults)
             other_config_actual.update(layer_arg)
-            other_config_actual.pop('layer')
+            other_config_actual.pop("layer")

             for key, value in other_config_actual.items():
                 assert key in data_groups[name]
@@ -119,13 +126,15 @@ class TestActivationSparsifier(TestCase):
         data_agg_actual = data_list[0]
         model = activation_sparsifier.model
         layer_name = module_to_fqn(model, model.conv1)
-        agg_fn = activation_sparsifier.data_groups[layer_name]['aggregate_fn']
+        agg_fn = activation_sparsifier.data_groups[layer_name]["aggregate_fn"]

         for i in range(1, len(data_list)):
             data_agg_actual = agg_fn(data_agg_actual, data_list[i])

-        assert 'data' in activation_sparsifier.data_groups[layer_name]
-        assert torch.all(activation_sparsifier.data_groups[layer_name]['data'] == data_agg_actual)
+        assert "data" in activation_sparsifier.data_groups[layer_name]
+        assert torch.all(
+            activation_sparsifier.data_groups[layer_name]["data"] == data_agg_actual
+        )

         return data_agg_actual
@@ -144,11 +153,11 @@ class TestActivationSparsifier(TestCase):
         layer_name = module_to_fqn(model, model.conv1)
         assert layer_name is not None

-        reduce_fn = activation_sparsifier.data_groups[layer_name]['reduce_fn']
+        reduce_fn = activation_sparsifier.data_groups[layer_name]["reduce_fn"]

         data_reduce_actual = reduce_fn(data_agg_actual)
-        mask_fn = activation_sparsifier.data_groups[layer_name]['mask_fn']
-        sparse_config = activation_sparsifier.data_groups[layer_name]['sparse_config']
+        mask_fn = activation_sparsifier.data_groups[layer_name]["mask_fn"]
+        sparse_config = activation_sparsifier.data_groups[layer_name]["sparse_config"]
         mask_actual = mask_fn(data_reduce_actual, **sparse_config)

         mask_model = activation_sparsifier.get_mask(layer_name)
@@ -156,8 +165,7 @@ class TestActivationSparsifier(TestCase):
         assert torch.all(mask_model == mask_actual)

         for config in activation_sparsifier.data_groups.values():
-            assert 'data' not in config
-
+            assert "data" not in config

     def _check_squash_mask(self, activation_sparsifier, data):
         """Makes sure that squash_mask() works as usual. Specifically, checks
@@ -172,11 +180,12 @@ class TestActivationSparsifier(TestCase):
             data (torch tensor)
                 dummy batched data
         """
+
         # create a forward hook for checking output == layer(input * mask)
         def check_output(name):
             mask = activation_sparsifier.get_mask(name)
-            features = activation_sparsifier.data_groups[name].get('features')
-            feature_dim = activation_sparsifier.data_groups[name].get('feature_dim')
+            features = activation_sparsifier.data_groups[name].get("features")
+            feature_dim = activation_sparsifier.data_groups[name].get("feature_dim")

             def hook(module, input, output):
                 input_data = input[0]
@@ -184,20 +193,28 @@ class TestActivationSparsifier(TestCase):
                     assert torch.all(mask * input_data == output)
                 else:
                     for feature_idx in range(0, len(features)):
-                        feature = torch.Tensor([features[feature_idx]], device=input_data.device).long()
-                        inp_data_feature = torch.index_select(input_data, feature_dim, feature)
-                        out_data_feature = torch.index_select(output, feature_dim, feature)
+                        feature = torch.Tensor(
+                            [features[feature_idx]], device=input_data.device
+                        ).long()
+                        inp_data_feature = torch.index_select(
+                            input_data, feature_dim, feature
+                        )
+                        out_data_feature = torch.index_select(
+                            output, feature_dim, feature
+                        )

-                        assert torch.all(mask[feature_idx] * inp_data_feature == out_data_feature)
+                        assert torch.all(
+                            mask[feature_idx] * inp_data_feature == out_data_feature
+                        )

             return hook

         for name, config in activation_sparsifier.data_groups.items():
-            if 'identity' in name:
-                config['layer'].register_forward_hook(check_output(name))
+            if "identity" in name:
+                config["layer"].register_forward_hook(check_output(name))

         activation_sparsifier.model(data)

     def _check_state_dict(self, sparsifier1):
         """Checks if loading and restoring of state_dict() works as expected.
         Basically, dumps the state of the sparsifier and loads it in the other sparsifier
@@ -222,8 +239,8 @@ class TestActivationSparsifier(TestCase):

         for name, state in sparsifier2.state.items():
             assert name in sparsifier1.state
-            mask1 = sparsifier1.state[name]['mask']
-            mask2 = state['mask']
+            mask1 = sparsifier1.state[name]["mask"]
+            mask2 = state["mask"]

             if mask1 is None:
                 assert mask2 is None
@@ -237,8 +254,8 @@ class TestActivationSparsifier(TestCase):
                 assert torch.all(mask1 == mask2)

         # make sure that the state dict is stored as torch sparse
-        for state in state_dict['state'].values():
-            mask = state['mask']
+        for state in state_dict["state"].values():
+            mask = state["mask"]
             if mask is not None:
                 if isinstance(mask, List):
                     for idx in range(len(mask)):
@@ -252,8 +269,16 @@ class TestActivationSparsifier(TestCase):
             assert layer_name in dg2

             # exclude hook and layer
-            config1 = {key: value for key, value in config.items() if key not in ['hook', 'layer']}
-            config2 = {key: value for key, value in dg2[layer_name].items() if key not in ['hook', 'layer']}
+            config1 = {
+                key: value
+                for key, value in config.items()
+                if key not in ["hook", "layer"]
+            }
+            config2 = {
+                key: value
+                for key, value in dg2[layer_name].items()
+                if key not in ["hook", "layer"]
+            }

             assert config1 == config2
@@ -263,6 +288,7 @@ class TestActivationSparsifier(TestCase):
         till squash_mask().
        The idea is to check that everything works as expected while in the workflow.
        """
+
        # defining aggregate, reduce and mask functions
        def agg_fn(x, y):
            return x + y
@@ -287,14 +313,9 @@ class TestActivationSparsifier(TestCase):

         # Creating default function and sparse configs
         # default sparse_config
-        sparse_config = {
-            'sparsity_level': 0.5
-        }
+        sparse_config = {"sparsity_level": 0.5}

-        defaults = {
-            'aggregate_fn': agg_fn,
-            'reduce_fn': reduce_fn
-        }
+        defaults = {"aggregate_fn": agg_fn, "reduce_fn": reduce_fn}

         # simulate the workflow
         # STEP 1: make data and activation sparsifier object
@@ -306,35 +327,41 @@ class TestActivationSparsifier(TestCase):

         # STEP 2: Register some layers
         register_layer1_args = {
-            'layer': model.conv1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.conv1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer1 = {'sparsity_level': 0.3}
+        sparse_config_layer1 = {"sparsity_level": 0.3}

         register_layer2_args = {
-            'layer': model.linear1,
-            'features': [0, 10, 234],
-            'feature_dim': 1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.linear1,
+            "features": [0, 10, 234],
+            "feature_dim": 1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer2 = {'sparsity_level': 0.1}
+        sparse_config_layer2 = {"sparsity_level": 0.1}

         register_layer3_args = {
-            'layer': model.identity1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.identity1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer3 = {'sparsity_level': 0.3}
+        sparse_config_layer3 = {"sparsity_level": 0.3}

         register_layer4_args = {
-            'layer': model.identity2,
-            'features': [0, 10, 20],
-            'feature_dim': 1,
-            'mask_fn': _vanilla_norm_sparsifier
+            "layer": model.identity2,
+            "features": [0, 10, 20],
+            "feature_dim": 1,
+            "mask_fn": _vanilla_norm_sparsifier,
         }
-        sparse_config_layer4 = {'sparsity_level': 0.1}
+        sparse_config_layer4 = {"sparsity_level": 0.1}

-        layer_args_list = [(register_layer1_args, sparse_config_layer1), (register_layer2_args, sparse_config_layer2)]
-        layer_args_list += [(register_layer3_args, sparse_config_layer3), (register_layer4_args, sparse_config_layer4)]
+        layer_args_list = [
+            (register_layer1_args, sparse_config_layer1),
+            (register_layer2_args, sparse_config_layer2),
+        ]
+        layer_args_list += [
+            (register_layer3_args, sparse_config_layer3),
+            (register_layer4_args, sparse_config_layer4),
+        ]

         # Registering..
         for layer_args in layer_args_list:
@@ -342,7 +369,9 @@ class TestActivationSparsifier(TestCase):
             activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer)

         # check if things are registered correctly
-        self._check_register_layer(activation_sparsifier, defaults, sparse_config, layer_args_list)
+        self._check_register_layer(
+            activation_sparsifier, defaults, sparse_config, layer_args_list
+        )

         # check state_dict after registering and before model forward
         self._check_state_dict(activation_sparsifier)
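Taken together, the workflow test above walks the whole ActivationSparsifier lifecycle: register layers with per-layer mask functions, run forward passes to collect activations, step to build masks, then squash. A condensed standalone sketch of that flow; the constructor call and the toy model are assumptions pieced together from the test, not verbatim API documentation:

    import torch
    from torch import nn
    from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import (
        ActivationSparsifier,
    )


    def agg_fn(x, y):  # accumulate activations across forward passes
        return x + y


    def reduce_fn(x):  # collapse the aggregate (identity here)
        return x


    def mask_fn(data, sparsity_level):  # hypothetical mask: keep the largest half
        return (data.abs() >= data.abs().median()).float()


    model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.Identity())
    # assumed signature: model plus default fns shared by all registered layers
    sparsifier = ActivationSparsifier(model, aggregate_fn=agg_fn, reduce_fn=reduce_fn)
    sparsifier.register_layer(layer=model[0], mask_fn=mask_fn, sparsity_level=0.5)

    model(torch.randn(1, 1, 8, 8))  # forward passes feed aggregate_fn
    sparsifier.step()               # reduce_fn + mask_fn -> per-layer masks
    sparsifier.squash_mask()        # attach masks as forward hooks, drop stored data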
diff --git a/test/ao/sparsity/test_composability.py b/test/ao/sparsity/test_composability.py
@@ -7,9 +7,14 @@ import torch
 import torch.ao.quantization as tq
 from torch import nn
 from torch.ao import pruning
-from torch.testing._internal.common_utils import TestCase
-from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, convert_to_reference_fx, prepare_qat_fx
 from torch.ao.pruning import fqn_to_module
+from torch.ao.quantization.quantize_fx import (
+    convert_fx,
+    convert_to_reference_fx,
+    prepare_fx,
+    prepare_qat_fx,
+)
+from torch.testing._internal.common_utils import TestCase

 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -21,6 +26,7 @@ sparse_defaults = {
     "zeros_per_block": 4,
 }

+
 def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
     model = nn.Sequential(
         nn.Linear(4, 4),  # 0
@@ -40,7 +46,7 @@ def _get_model_and_sparsifier_and_sparse_config(qconfig=None):

     sparse_config = [
         {
-            "tensor_fqn": '5.weight',
+            "tensor_fqn": "5.weight",
             "sparsity_level": 0.7,
             "sparse_block_shape": (1, 4),
             "zeros_per_block": 4,
@@ -49,15 +55,18 @@ def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
     ]
     return model, sparsifier, sparse_config

+
 def _squash_mask_calibrate_and_convert(model, sparsifier, input):
     sparsifier.step()
     sparsifier.squash_mask()
     model(input)
     tq.convert(model, inplace=True)

+
 def _calculate_sparsity(tensor):
     return ((tensor == 0).sum() / tensor.numel()).item()

+
 # This series of tests are to check the composability goals for sparsity and quantization. Namely
 # that performing quantization and sparsity model manipulations in various orderings
 # does not cause problems
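`_calculate_sparsity` is simply the fraction of exact zeros in a tensor; a worked instance:

    import torch


    def _calculate_sparsity(tensor):
        return ((tensor == 0).sum() / tensor.numel()).item()


    t = torch.tensor([[0.0, 1.0], [0.0, 2.0]])  # 2 zeros out of 4 entries
    assert _calculate_sparsity(t) == 0.5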
@@ -70,7 +79,9 @@ class TestComposability(TestCase):
             mod,
             sparsifier,
             sparse_config,
-        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
+        ) = _get_model_and_sparsifier_and_sparse_config(
+            tq.get_default_qconfig("fbgemm")
+        )

         tq.prepare(mod, inplace=True)
         sparsifier.prepare(mod, config=sparse_config)
@@ -81,9 +92,7 @@ class TestComposability(TestCase):
         # check that correct observers were inserted
         self.assertTrue(hasattr(mod[5], "activation_post_process"))

-        _squash_mask_calibrate_and_convert(
-            mod, sparsifier, torch.randn(1, 4, 4, 4)
-        )
+        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

         # check that final module is the expected quantized module and that the model runs
         self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
@@ -99,7 +108,9 @@ class TestComposability(TestCase):
             mod,
             sparsifier,
             sparse_config,
-        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
+        ) = _get_model_and_sparsifier_and_sparse_config(
+            tq.get_default_qconfig("fbgemm")
+        )

         sparsifier.prepare(mod, config=sparse_config)
         tq.prepare(mod, inplace=True)
@@ -113,9 +124,7 @@ class TestComposability(TestCase):
         # occurred successfully
         self.assertTrue(hasattr(mod[5], "activation_post_process"))

-        _squash_mask_calibrate_and_convert(
-            mod, sparsifier, torch.randn(1, 4, 4, 4)
-        )
+        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

         # check that final module is the expected quantized module and that the model runs
         self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
@@ -130,7 +139,9 @@ class TestComposability(TestCase):
             mod,
             sparsifier,
             sparse_config,
-        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
+        ) = _get_model_and_sparsifier_and_sparse_config(
+            tq.get_default_qconfig("fbgemm")
+        )

         sparsifier.prepare(mod, config=sparse_config)
         tq.prepare(mod, inplace=True)
@@ -168,7 +179,9 @@ class TestComposability(TestCase):
             mod,
             sparsifier,
             sparse_config,
-        ) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
+        ) = _get_model_and_sparsifier_and_sparse_config(
+            tq.get_default_qconfig("fbgemm")
+        )
         sparsifier.prepare(mod, config=sparse_config)
         tq.fuse_modules(mod, [["5", "6"]], inplace=True)
         mod[5].qconfig = tq.get_default_qconfig("fbgemm")
@@ -182,9 +195,7 @@ class TestComposability(TestCase):
         # check that correct observers were inserted and that matching
         # occurred successfully
         self.assertTrue(hasattr(mod[5], "activation_post_process"))
-        _squash_mask_calibrate_and_convert(
-            mod, sparsifier, torch.randn(1, 4, 4, 4)
-        )
+        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

         # check that final module is the expected quantized module and that the model runs
         self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
|
||||
mod,
|
||||
sparsifier,
|
||||
_,
|
||||
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
|
||||
) = _get_model_and_sparsifier_and_sparse_config(
|
||||
tq.get_default_qconfig("fbgemm")
|
||||
)
|
||||
tq.fuse_modules(mod, [["5", "6"]], inplace=True)
|
||||
|
||||
# its absolutely broken by fusion but will still work if you put the correct fqn in
|
||||
@ -261,9 +274,7 @@ class TestComposability(TestCase):
|
||||
# occurred successfully
|
||||
self.assertTrue(hasattr(mod[5], "activation_post_process"))
|
||||
self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
|
||||
_squash_mask_calibrate_and_convert(
|
||||
mod, sparsifier, torch.randn(1, 4, 4, 4)
|
||||
)
|
||||
_squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
|
||||
# check that final module is the expected quantized module and that the model runs
|
||||
self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
|
||||
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
|
||||
@@ -301,9 +312,7 @@ class TestComposability(TestCase):
         self.assertTrue(hasattr(mod[5], "activation_post_process"))
         self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))

-        _squash_mask_calibrate_and_convert(
-            mod, sparsifier, torch.randn(1, 4, 4, 4)
-        )
+        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

         # check that final module is the expected quantized module and that the model runs
         self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
@@ -313,6 +322,7 @@ class TestComposability(TestCase):
         cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
         self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

+
 def _module_has_activation_post_process(model, fqn_of_module):
     for node in model.graph.nodes:
         # look for an observer whose arg is the target module
@@ -321,10 +331,12 @@ def _module_has_activation_post_process(model, fqn_of_module):
             return True
     return False

+
 class TestFxComposability(TestCase):
     r"""This series of tests checks that various steps of the quantization and sparsity flow
     compose cleanly despite variation in sequencing.
     """
+
     def test_q_prep_fx_before_s_prep(self):
         r"""
         This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
@@ -340,10 +352,11 @@ class TestFxComposability(TestCase):

         example = torch.randn(1, 4, 4, 4)
         qconfig = tq.get_default_qconfig("fbgemm")
-        qconfig_mapping = tq.QConfigMapping() \
-            .set_module_name("4", qconfig) \
-            .set_module_name("5", qconfig)
-
+        qconfig_mapping = (
+            tq.QConfigMapping()
+            .set_module_name("4", qconfig)
+            .set_module_name("5", qconfig)
+        )

         mod = prepare_fx(mod, qconfig_mapping, (example,))

@@ -374,7 +387,11 @@ class TestFxComposability(TestCase):
         mod = convert_fx(mod)

         # check that final module is the expected quantized module and that the model runs
-        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
+        self.assertTrue(
+            isinstance(
+                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
+            )
+        )
         self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

         # check that module was actually sparsified
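Pulled out of the test body for readability, the sequencing under test is: FX prepare, then sparse prepare, then step/squash, calibrate, and convert. A sketch using the same calls the test makes (`_get_model_and_sparsifier_and_sparse_config` is the helper defined earlier in this file):

    import torch
    import torch.ao.quantization as tq
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    mod, sparsifier, sparse_config = _get_model_and_sparsifier_and_sparse_config()

    example = torch.randn(1, 4, 4, 4)
    qconfig = tq.get_default_qconfig("fbgemm")
    qconfig_mapping = (
        tq.QConfigMapping().set_module_name("4", qconfig).set_module_name("5", qconfig)
    )

    mod = prepare_fx(mod, qconfig_mapping, (example,))  # insert observers (FX graph mode)
    sparsifier.prepare(mod, config=sparse_config)       # attach sparsity parametrizations
    sparsifier.step()                                   # compute masks
    sparsifier.squash_mask()                            # fold masks into the weights
    mod(example)                                        # calibrate observers
    mod = convert_fx(mod)                               # lower to quantized modules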
@@ -399,9 +416,11 @@ class TestFxComposability(TestCase):

         example = torch.randn(1, 4, 4, 4)
         qconfig = tq.get_default_qconfig("fbgemm")
-        qconfig_mapping = tq.QConfigMapping() \
-            .set_module_name("4", qconfig) \
-            .set_module_name("5", qconfig)
+        qconfig_mapping = (
+            tq.QConfigMapping()
+            .set_module_name("4", qconfig)
+            .set_module_name("5", qconfig)
+        )

         mod = prepare_fx(mod, qconfig_mapping, (example,))

@@ -432,9 +451,15 @@ class TestFxComposability(TestCase):
         mod = convert_to_reference_fx(mod)

         # check that final module is the expected quantized module and that the model runs
-        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU))
+        self.assertTrue(
+            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
+        )
         self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
-        self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear))
+        self.assertTrue(
+            isinstance(
+                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
+            )
+        )

         # check that module was actually sparsified
         cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
@@ -459,9 +484,11 @@ class TestFxComposability(TestCase):

         example = torch.randn(1, 4, 4, 4)
         qconfig = tq.get_default_qconfig("fbgemm")
-        qconfig_mapping = tq.QConfigMapping() \
-            .set_module_name("4", qconfig) \
-            .set_module_name("5", qconfig)
+        qconfig_mapping = (
+            tq.QConfigMapping()
+            .set_module_name("4", qconfig)
+            .set_module_name("5", qconfig)
+        )
         mod = prepare_fx(mod, qconfig_mapping, (example,))

         # check that correct modules had parametrizations added and
@@ -478,7 +505,11 @@ class TestFxComposability(TestCase):
         mod = convert_fx(mod)

         # check that final module is the expected quantized module and that the model runs
-        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
+        self.assertTrue(
+            isinstance(
+                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
+            )
+        )
         self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

         # check that module was actually sparsified
@@ -504,16 +535,20 @@ class TestFxComposability(TestCase):

         example = torch.randn(1, 4, 4, 4)
         qconfig = tq.get_default_qat_qconfig("fbgemm")
-        qconfig_mapping = tq.QConfigMapping() \
-            .set_module_name("4", qconfig) \
-            .set_module_name("5", qconfig)
+        qconfig_mapping = (
+            tq.QConfigMapping()
+            .set_module_name("4", qconfig)
+            .set_module_name("5", qconfig)
+        )
         mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

         # check that correct modules had parametrizations added and
         # that none were lost during prepare
         self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
         self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
-        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU))
+        self.assertTrue(
+            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU)
+        )

         # check that correct observers were inserted and that matching
         # occurred successfully
@@ -524,7 +559,11 @@ class TestFxComposability(TestCase):
         mod = convert_fx(mod)

         # check that final module is the expected quantized module and that the model runs
-        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
+        self.assertTrue(
+            isinstance(
+                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
+            )
+        )
         self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

         # check that module was actually sparsified
@@ -550,9 +589,11 @@ class TestFxComposability(TestCase):

         example = torch.randn(1, 4, 4, 4)
         qconfig = tq.get_default_qconfig("fbgemm")
-        qconfig_mapping = tq.QConfigMapping() \
-            .set_module_name("4", qconfig) \
-            .set_module_name("5", qconfig)
+        qconfig_mapping = (
+            tq.QConfigMapping()
+            .set_module_name("4", qconfig)
+            .set_module_name("5", qconfig)
+        )
         mod = prepare_fx(mod, qconfig_mapping, (example,))

         # check that correct modules had parametrizations added and
@@ -569,9 +610,15 @@ class TestFxComposability(TestCase):
         mod = convert_to_reference_fx(mod)

         # check that final module is the expected quantized module and that the model runs
-        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU))
+        self.assertTrue(
+            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
+        )
         self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
-        self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear))
+        self.assertTrue(
+            isinstance(
+                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
+            )
+        )

         # check that module was actually sparsified
         cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
diff --git a/test/ao/sparsity/test_data_scheduler.py b/test/ao/sparsity/test_data_scheduler.py
@@ -1,17 +1,20 @@
 # Owner(s): ["module: unknown"]

+import copy
 import logging
 import warnings
-from torch.testing._internal.common_utils import TestCase
-from torch import nn
-import torch
 from typing import Tuple
-import copy

-from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
+import torch
+from torch import nn
+
 from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
-
-logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
+from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
+from torch.testing._internal.common_utils import TestCase
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)


 class ImplementedDataScheduler(BaseDataScheduler):
@@ -20,27 +23,32 @@ class ImplementedDataScheduler(BaseDataScheduler):

     def get_schedule_param(self):
         if self.last_epoch > 0:
-            return {name: config['sparsity_level'] * 0.5
-                    for name, config in self.data_sparsifier.data_groups.items()}
+            return {
+                name: config["sparsity_level"] * 0.5
+                for name, config in self.data_sparsifier.data_groups.items()
+            }
         else:
             return self.base_param


 class TestBaseDataScheduler(TestCase):
     def _get_data(self):
-        tensor1, param1, emb1 = torch.randn(5, 5), nn.Parameter(torch.randn(10, 10)), nn.Embedding(50, 5)
-        data_list = [
-            ('tensor1', tensor1), ('param1', param1), ('emb1', emb1)
-        ]
+        tensor1, param1, emb1 = (
+            torch.randn(5, 5),
+            nn.Parameter(torch.randn(10, 10)),
+            nn.Embedding(50, 5),
+        )
+        data_list = [("tensor1", tensor1), ("param1", param1), ("emb1", emb1)]
         defaults = {
-            'sparsity_level': 0.7,
-            'sparse_block_shape': (1, 4),
-            'zeros_per_block': 2
+            "sparsity_level": 0.7,
+            "sparse_block_shape": (1, 4),
+            "zeros_per_block": 2,
         }
         data_with_config = [
             {
-                'name': 'tensor2', 'data': torch.randn(4, 4),
-                'config': {'sparsity_level': 0.3}
+                "name": "tensor2",
+                "data": torch.randn(4, 4),
+                "config": {"sparsity_level": 0.3},
             }
         ]
         return data_list, data_with_config, defaults
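The scheduler is driven in lockstep with the sparsifier, one step() each per epoch, which `test_step` below verifies (after one scheduler step, every group's sparsity_level is halved by `get_schedule_param`). A standalone sketch; the `ImplementedDataScheduler(sparsifier, "sparsity_level")` construction is an assumption based on `_get_schedule_param` above, not a documented signature:

    import torch
    from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier

    sparsifier = DataNormSparsifier(
        [("tensor1", torch.randn(5, 5))],
        sparsity_level=0.7,
        sparse_block_shape=(1, 4),
        zeros_per_block=2,
    )
    scheduler = ImplementedDataScheduler(sparsifier, "sparsity_level")  # assumed args

    for epoch in range(5):
        sparsifier.step()  # apply masks at the current sparsity_level
        scheduler.step()   # then update sparsity_level for the next epoch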
@@ -48,7 +56,11 @@ class TestBaseDataScheduler(TestCase):
     def _get_sparsifier(self, data_list, data_with_config, defaults):
         sparsifier = DataNormSparsifier(data_list, **defaults)
         for data_config_dict in data_with_config:
-            name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config']
+            name, data, config = (
+                data_config_dict["name"],
+                data_config_dict["data"],
+                data_config_dict["config"],
+            )
             sparsifier.add_data(name=name, data=data, **config)
         return sparsifier

@@ -57,7 +69,7 @@ class TestBaseDataScheduler(TestCase):
         return scheduler

     def _get_schedule_param(self):
-        return 'sparsity_level'
+        return "sparsity_level"

     def _get_name_data_config(self, some_data, defaults):
         config = copy.deepcopy(defaults)
@@ -66,7 +78,11 @@ class TestBaseDataScheduler(TestCase):
             name, data = some_data
         else:
             # dealing with data_with_config
-            name, data, new_config = some_data['name'], some_data['data'], some_data['config']
+            name, data, new_config = (
+                some_data["name"],
+                some_data["data"],
+                some_data["config"],
+            )
             config.update(new_config)
         return name, data, config

@@ -102,8 +118,11 @@ class TestBaseDataScheduler(TestCase):
         # Make sure there is no warning related to the base_data_scheduler
         for warning in w:
             fname = warning.filename
-            fname = '/'.join(fname.split('/')[-5:])
-            assert fname != 'torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py'
+            fname = "/".join(fname.split("/")[-5:])
+            assert (
+                fname
+                != "torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py"
+            )

     def test_step(self):
         data_list, data_with_config, defaults = self._get_data()
@@ -115,14 +134,19 @@ class TestBaseDataScheduler(TestCase):

         for some_data in all_data:
             name, _, config = self._get_name_data_config(some_data, defaults)
-            assert sparsifier.data_groups[name][schedule_param] == config[schedule_param]
+            assert (
+                sparsifier.data_groups[name][schedule_param] == config[schedule_param]
+            )

         sparsifier.step()
         scheduler.step()

         for some_data in all_data:
             name, _, config = self._get_name_data_config(some_data, defaults)
-            assert sparsifier.data_groups[name][schedule_param] == config[schedule_param] * 0.5
+            assert (
+                sparsifier.data_groups[name][schedule_param]
+                == config[schedule_param] * 0.5
+            )

         # checking step count
         step_cnt = 5
@@ -130,7 +154,9 @@ class TestBaseDataScheduler(TestCase):
             sparsifier.step()
             scheduler.step()

-        assert scheduler._step_count == step_cnt + 2  # step_cnt + step above + 1 step in constructor
+        assert (
+            scheduler._step_count == step_cnt + 2
+        )  # step_cnt + step above + 1 step in constructor

     def test_state_dict(self):
         data_list, data_with_config, defaults = self._get_data()
diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py
@@ -1,20 +1,28 @@
 # Owner(s): ["module: unknown"]

+import copy
+import itertools
 import logging
+import math
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+from torch.ao.pruning._experimental.data_sparsifier import (
+    BaseDataSparsifier,
+    DataNormSparsifier,
+)
+from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
+    post_training_sparse_quantize,
+)
 from torch.nn.utils.parametrize import is_parametrized
 from torch.testing._internal.common_utils import TestCase

-from typing import Tuple
-from torch import nn
-import torch
-import itertools
-import math
-import copy
-
-from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier, DataNormSparsifier
-from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import post_training_sparse_quantize
-
-logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)


 class ImplementedSparsifier(BaseDataSparsifier):
@@ -25,17 +33,18 @@ class ImplementedSparsifier(BaseDataSparsifier):
         mask = self.get_mask(name)
         mask[0] = 0
         linear_state = self.state[name]
-        linear_state['step_count'] = linear_state.get('step_count', 0) + 1
+        linear_state["step_count"] = linear_state.get("step_count", 0) + 1


 class _BaseDataSparsiferTestCase(TestCase):
     r"""This helper test class takes in any supported type of and runs some tests.
-        The user is required to pass in the data that needs to sparsified and the
-        runner will run some tests that needs to be passed in order for the data
-        type to be supported.
-        TODO: Change the structure by creating a separate test case class for each
-        member function
+    The user is required to pass in the data that needs to sparsified and the
+    runner will run some tests that needs to be passed in order for the data
+    type to be supported.
+    TODO: Change the structure by creating a separate test case class for each
+    member function
     """

     def run_all_checks(self, data_list, data_with_config, defaults):
         self.check_constructor(data_list, data_with_config, defaults)
         self.check_squash_mask(data_list, data_with_config, defaults)
@@ -52,32 +61,49 @@ class _BaseDataSparsiferTestCase(TestCase):
             config = defaults
         else:
             # dealing with data_with_config
-            name, data, config = some_data['name'], some_data['data'], some_data['config']
+            name, data, config = (
+                some_data["name"],
+                some_data["data"],
+                some_data["config"],
+            )
         return name, data, config

     @staticmethod
-    def _make_sparsifier(data_list, data_with_config, defaults,
-                         sparsifier_type=None, sparsifier_kwargs=None):
+    def _make_sparsifier(
+        data_list,
+        data_with_config,
+        defaults,
+        sparsifier_type=None,
+        sparsifier_kwargs=None,
+    ):
         if sparsifier_type is None:
             sparsifier = ImplementedSparsifier(data_list=data_list, **defaults)
         else:
             kwargs = copy.deepcopy(defaults)
             kwargs.update(sparsifier_kwargs)
-            kwargs['data_list'] = data_list
+            kwargs["data_list"] = data_list
             sparsifier = sparsifier_type(**kwargs)
         assert len(sparsifier.data_groups) == len(data_list)
         for data_config_dict in data_with_config:
-            name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config']
+            name, data, config = (
+                data_config_dict["name"],
+                data_config_dict["data"],
+                data_config_dict["config"],
+            )
             sparsifier.add_data(name=name, data=data, **config)
         return sparsifier

     def check_constructor(self, data_list, data_with_config, defaults, **kwargs):
-        sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
-        self.assertEqual(len(sparsifier.data_groups),
-                         len(data_list) + len(data_with_config),
-                         msg="Sparsifier data groups don't match the input "
-                         f"({len(sparsifier.data_groups)} vs. "
-                         f"{len(data_list) + len(data_with_config)}).")
+        sparsifier = self._make_sparsifier(
+            data_list, data_with_config, defaults=defaults, **kwargs
+        )
+        self.assertEqual(
+            len(sparsifier.data_groups),
+            len(data_list) + len(data_with_config),
+            msg="Sparsifier data groups don't match the input "
+            f"({len(sparsifier.data_groups)} vs. "
+            f"{len(data_list) + len(data_with_config)}).",
+        )

         all_data = data_list + data_with_config
@@ -87,7 +113,9 @@ class _BaseDataSparsiferTestCase(TestCase):
             self.assertEqual(sparsifier.data_groups[name], config)

     def check_step(self, data_list, data_with_config, defaults, **kwargs):
-        sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
+        sparsifier = self._make_sparsifier(
+            data_list, data_with_config, defaults=defaults, **kwargs
+        )
         all_data = data_list + data_with_config

         # Check data and mask before doing the step
@@ -114,11 +142,13 @@ class _BaseDataSparsiferTestCase(TestCase):
             self.assertEqualBroadcasting(sparsified_data[0], 0)
             self.assertEqual(original_data, data)
             self.assertEqualBroadcasting(mask[0], 0)
-            assert 'step_count' in sparsifier.state[name]
-            assert sparsifier.state[name]['step_count'] == 3
+            assert "step_count" in sparsifier.state[name]
+            assert sparsifier.state[name]["step_count"] == 3

     def check_squash_mask(self, data_list, data_with_config, defaults, **kwargs):
-        sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
+        sparsifier = self._make_sparsifier(
+            data_list, data_with_config, defaults=defaults, **kwargs
+        )
         all_data = data_list + data_with_config
         for some_data in all_data:
             name, _, _ = self._get_name_data_config(some_data)
@@ -129,15 +159,21 @@ class _BaseDataSparsiferTestCase(TestCase):

         for some_data in all_data:
             name, _, _ = self._get_name_data_config(some_data)
-            assert not is_parametrized(sparsifier._container, name)  # not parametrized anymore
+            assert not is_parametrized(
+                sparsifier._container, name
+            )  # not parametrized anymore
             with self.assertRaises(ValueError):
                 sparsifier.get_data(name, return_original=True)

     def check_add_data(self, data_list, data_with_config, defaults, **kwargs):
-        sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
+        sparsifier = self._make_sparsifier(
+            data_list, data_with_config, defaults=defaults, **kwargs
+        )
         all_data = data_list + data_with_config
         for some_data in all_data:
-            name1, data1, config = self._get_name_data_config(some_data, defaults=defaults)
+            name1, data1, config = self._get_name_data_config(
+                some_data, defaults=defaults
+            )
             data1 = sparsifier._extract_weight(data1)
             data1_old = copy.deepcopy(data1)
             assert torch.all(data1 == sparsifier.get_data(name=name1))
@@ -145,18 +181,28 @@ class _BaseDataSparsiferTestCase(TestCase):
             sparsifier.step()
             mask = sparsifier.get_mask(name1)

-            data2 = torch.randn(data1.shape)  # add another data with the same shape as original data
+            data2 = torch.randn(
+                data1.shape
+            )  # add another data with the same shape as original data
             sparsifier.add_data(name=name1, data=data2)
             assert torch.all(data2 == sparsifier.get_data(name=name1))

-            assert torch.all(sparsifier.get_mask(name1) == mask)  # mask should not change
+            assert torch.all(
+                sparsifier.get_mask(name1) == mask
+            )  # mask should not change
             assert torch.all(data1_old == data1)

-            assert sparsifier.data_groups[name1] == config  # if replaced old_config should match new config
+            assert (
+                sparsifier.data_groups[name1] == config
+            )  # if replaced old_config should match new config

     def check_state_dict(self, data_list, data_with_config, defaults, **kwargs):
-        sparsifier1 = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
-        sparsifier2 = self._make_sparsifier(data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs)
+        sparsifier1 = self._make_sparsifier(
+            data_list, data_with_config, defaults=defaults, **kwargs
+        )
+        sparsifier2 = self._make_sparsifier(
+            data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs
+        )
         sparsifier1.step()

         state_dict1 = sparsifier1.state_dict()
@@ -169,15 +215,17 @@ class _BaseDataSparsiferTestCase(TestCase):
         assert len(sparsifier1.state) == len(sparsifier2.state)
         assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)

-        state1 = state_dict1['state']
+        state1 = state_dict1["state"]
         for name in state1.keys():
             # compare mask
             assert name in sparsifier2.state
-            assert 'mask' in sparsifier2.state[name]
-            assert 'mask' in sparsifier1.state[name]
-            mask1, mask2 = state1[name]['mask'], sparsifier2.state[name]['mask']
+            assert "mask" in sparsifier2.state[name]
+            assert "mask" in sparsifier1.state[name]
+            mask1, mask2 = state1[name]["mask"], sparsifier2.state[name]["mask"]
             assert mask1.is_sparse and not mask2.is_sparse
-            assert torch.all(mask1.to_dense() == mask2)  # mask1 is stored as sparse coo now
+            assert torch.all(
+                mask1.to_dense() == mask2
+            )  # mask1 is stored as sparse coo now

         # compare data_groups
         dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups
@@ -187,12 +235,14 @@ class _BaseDataSparsiferTestCase(TestCase):
         # compare container
         container1, container2 = sparsifier1._container, sparsifier2._container
         assert torch.all(getattr(container1, name) == getattr(container2, name))
-        assert is_parametrized(container1, name) == is_parametrized(container2, name)
+        assert is_parametrized(container1, name) == is_parametrized(
+            container2, name
+        )
         if is_parametrized(container1, name):
             param1 = getattr(container1.parametrizations, name)[0]
             param2 = getattr(container2.parametrizations, name)[0]
-            assert hasattr(param1, 'mask')
-            assert hasattr(param2, 'mask')
+            assert hasattr(param1, "mask")
+            assert hasattr(param2, "mask")
             self.assertEqual(param1.__dict__, param2.__dict__)

     def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs):
@@ -204,29 +254,35 @@ class _BaseDataSparsiferTestCase(TestCase):

         This test modifies the data and asserts that data in the sparsifier is changed as well
         """
-        sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
+        sparsifier = self._make_sparsifier(
+            data_list, data_with_config, defaults=defaults, **kwargs
+        )
         all_data = data_list + data_with_config
         for some_data in all_data:
             name, data, _ = self._get_name_data_config(some_data)
             weight = sparsifier._extract_weight(data)
             weight.data = weight + torch.randn(*weight.shape)
             contained_data = sparsifier.get_data(name=name)
-            assert weight.data.storage().data_ptr() == contained_data.data.storage().data_ptr()
+            assert (
+                weight.data.storage().data_ptr()
+                == contained_data.data.storage().data_ptr()
+            )
             assert torch.all(contained_data == weight)


 class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
     r"""This helper test class takes in any supported type of and runs some tests.
-        This inherits the TestBaseDataSparsifierRuner wherein some functions are
-        over-ridden to take accomodate the specific sparsifier.
-        TODO: Change the structure by creating a separate test case class for each
-        member function
+    This inherits the TestBaseDataSparsifierRuner wherein some functions are
+    over-ridden to take accomodate the specific sparsifier.
+    TODO: Change the structure by creating a separate test case class for each
+    member function
     """
-    def run_all_checks(self, data_list, defaults, data_with_config, norm_type='L1'):
-        assert norm_type in ['L1', 'L2']
+
+    def run_all_checks(self, data_list, defaults, data_with_config, norm_type="L1"):
+        assert norm_type in ["L1", "L2"]
         kwargs = {
-            'sparsifier_type': DataNormSparsifier,
-            'sparsifier_kwargs': {'norm': norm_type}
+            "sparsifier_type": DataNormSparsifier,
+            "sparsifier_kwargs": {"norm": norm_type},
         }
         self.check_constructor(data_list, data_with_config, defaults, **kwargs)
         self.check_squash_mask(data_list, data_with_config, defaults, **kwargs)
@@ -234,24 +290,28 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
         self.check_state_dict(data_list, data_with_config, defaults, **kwargs)
         self.check_step(data_list, data_with_config, defaults, norm_type=norm_type)
         self.check_step_2_of_4(norm_type=norm_type)
-        self.check_sparsity_level(data_list, data_with_config, defaults, norm_type=norm_type)
+        self.check_sparsity_level(
+            data_list, data_with_config, defaults, norm_type=norm_type
+        )
         self.check_memory_reference(data_list, data_with_config, defaults, **kwargs)

     @staticmethod
     def _get_bounds_on_actual_sparsity(config, tensor_shape):
         r"""This function gets the bounds on actual sparsity.
-            Note::
-                Although we specify the sparsity_level parameter, this does not mean that
-                the actual sparsity obtained after sparsification is the same as sparsity_level.
-                The actual sparsity depends largely on the shape and the data itself.
+        Note::
+            Although we specify the sparsity_level parameter, this does not mean that
+            the actual sparsity obtained after sparsification is the same as sparsity_level.
+            The actual sparsity depends largely on the shape and the data itself.
         """
-        sparsity_level = config['sparsity_level']
-        zeros_per_block = config['zeros_per_block']
-        sparse_block_shape = config['sparse_block_shape']
+        sparsity_level = config["sparsity_level"]
+        zeros_per_block = config["zeros_per_block"]
+        sparse_block_shape = config["sparse_block_shape"]

         height, width = tensor_shape[-2], tensor_shape[-1]
         block_height, block_width = sparse_block_shape
-        number_blocks = math.ceil(height / block_height) * math.ceil(width / block_width)
+        number_blocks = math.ceil(height / block_height) * math.ceil(
+            width / block_width
+        )
         values_per_block = block_height * block_width

         if zeros_per_block == 0:
@@ -260,17 +320,23 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
         # min value assumes zeros_per_block is 1
         min_values_sparsified = round(number_blocks * sparsity_level)
         # max value assumes actual zeros_per_block
-        max_values_sparsified = min_values_sparsified * min(values_per_block, zeros_per_block)
+        max_values_sparsified = min_values_sparsified * min(
+            values_per_block, zeros_per_block
+        )
         lower_bound = min_values_sparsified / (height * width)
         upper_bound = min(1.0, max_values_sparsified / (height * width))

         lower_bound, upper_bound = round(lower_bound, 3), round(upper_bound, 3)
         return lower_bound, upper_bound

-    def check_step(self, data_list, data_with_config, defaults, norm_type='L1'):
-        sparsifier = self._make_sparsifier(data_list, data_with_config, defaults,
-                                           sparsifier_type=DataNormSparsifier,
-                                           sparsifier_kwargs={'norm': norm_type})
+    def check_step(self, data_list, data_with_config, defaults, norm_type="L1"):
+        sparsifier = self._make_sparsifier(
+            data_list,
+            data_with_config,
+            defaults,
+            sparsifier_type=DataNormSparsifier,
+            sparsifier_kwargs={"norm": norm_type},
+        )
         all_data = data_list + data_with_config

         # mask before step() should not be sparsified
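A worked instance of `_get_bounds_on_actual_sparsity` above: for a 4x4 tensor with sparse_block_shape=(1, 4), sparsity_level=1.0, and zeros_per_block=2 (the config used by `check_step_2_of_4` below):

    import math

    height, width = 4, 4
    block_height, block_width = 1, 4
    sparsity_level, zeros_per_block = 1.0, 2

    number_blocks = math.ceil(height / block_height) * math.ceil(width / block_width)  # 4
    values_per_block = block_height * block_width  # 4

    min_values_sparsified = round(number_blocks * sparsity_level)  # 4
    max_values_sparsified = min_values_sparsified * min(values_per_block, zeros_per_block)  # 8

    lower_bound = min_values_sparsified / (height * width)            # 0.25
    upper_bound = min(1.0, max_values_sparsified / (height * width))  # 0.5

    # 2-of-4 sparsity lands exactly on the upper bound, matching the
    # assertAlmostEqual(1.0 - mask.mean().item(), 0.5, places=2) check below
    assert (lower_bound, upper_bound) == (0.25, 0.5)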
@@ -289,29 +355,39 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
             mask = mask.to(torch.float)
             actual_sparsity = round(1 - mask.mean().item(), 3)
             assert actual_sparsity >= lb and actual_sparsity <= ub
-            assert actual_sparsity > 0.0  # exact sparsity level cannot be achieved due to size of tensor
+            assert (
+                actual_sparsity > 0.0
+            )  # exact sparsity level cannot be achieved due to size of tensor

         iters_before_collapse = 100

-        test_sparsifier = DataNormSparsifier(sparsity_level=0.5,
-                                             sparse_block_shape=(1, 4),
-                                             zeros_per_block=4,
-                                             norm=norm_type)
+        test_sparsifier = DataNormSparsifier(
+            sparsity_level=0.5,
+            sparse_block_shape=(1, 4),
+            zeros_per_block=4,
+            norm=norm_type,
+        )

         for _ in range(iters_before_collapse):
             new_data = torch.randn(20, 20)
-            test_sparsifier.add_data(name='test_data', data=new_data)
+            test_sparsifier.add_data(name="test_data", data=new_data)
             test_sparsifier.step()
-            mask = test_sparsifier.get_mask(name='test_data')
+            mask = test_sparsifier.get_mask(name="test_data")
             mask = mask.to(torch.float)
             assert (1.0 - mask.mean().item()) > 0  # some sparsity achieved

     def check_step_2_of_4(self, norm_type):
         # overriding default config for test purposes
-        default_config = {'sparsity_level': 1.0, 'zeros_per_block': 2, 'sparse_block_shape': (1, 4)}
-        data_list = [('test_data', torch.randn(4, 4))]
+        default_config = {
+            "sparsity_level": 1.0,
+            "zeros_per_block": 2,
+            "sparse_block_shape": (1, 4),
+        }
+        data_list = [("test_data", torch.randn(4, 4))]

-        sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type, **default_config)
+        sparsifier = DataNormSparsifier(
+            data_list=data_list, norm=norm_type, **default_config
+        )
         sparsifier.step()

         for some_data in data_list:
@@ -321,29 +397,39 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
             self.assertAlmostEqual(1.0 - mask.mean().item(), 0.5, places=2)
             for row in mask:
                 for idx in range(0, len(row), 4):
-                    block = row[idx:idx + 4]
+                    block = row[idx : idx + 4]
                     block, _ = block.sort()
                     assert (block[:2] == 0).all()
                     assert (block[2:] != 0).all()

-    def check_sparsity_level(self, data_list, data_with_config, defaults, norm_type='L1'):
+    def check_sparsity_level(
+        self, data_list, data_with_config, defaults, norm_type="L1"
+    ):
         sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
         sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
         zeros_per_blocks = [0, 1, 2, 3, 4]
         sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type)

-        testcases = itertools.tee(itertools.product(sparsity_levels,
-                                                    sparse_block_shapes,
-                                                    zeros_per_blocks))
+        testcases = itertools.tee(
+            itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)
+        )

-        assert len(data_with_config) > 0 and 'name' in data_with_config[0] and 'data' in data_with_config[0]
+        assert (
+            len(data_with_config) > 0
+            and "name" in data_with_config[0]
+            and "data" in data_with_config[0]
+        )
         # get some data
-        name, data = data_with_config[0]['name'], data_with_config[0]['data']
+        name, data = data_with_config[0]["name"], data_with_config[0]["data"]
         for idx, (sl, sbs, zpb) in enumerate(testcases[0]):
-            new_name = f'{name}_{idx}'
+            new_name = f"{name}_{idx}"
             if zpb > sbs[0] * sbs[1]:
                 continue
-            current_config = {'sparsity_level': sl, 'sparse_block_shape': sbs, 'zeros_per_block': zpb}
+            current_config = {
+                "sparsity_level": sl,
+                "sparse_block_shape": sbs,
+                "zeros_per_block": zpb,
+            }
             sparsifier.add_data(name=new_name, data=data, **current_config)
             if zpb > sbs[0] * sbs[1]:
                 continue
@@ -351,7 +437,7 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
         sparsifier.step()
         sparsifier.squash_mask()
         for idx, (sl, sbs, zpb) in enumerate(testcases[0]):
-            new_name = f'{name}_{idx}'
+            new_name = f"{name}_{idx}"
             sparsified_data = sparsifier.get_data(name=new_name, original=False)
             # sparse mask
             sparse_mask = (sparsified_data == 0).float()
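The 2-of-4 check above doubles as a usage recipe for DataNormSparsifier. Condensed into a standalone sketch with the same calls and config as `check_step_2_of_4`:

    import torch
    from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier

    sparsifier = DataNormSparsifier(
        data_list=[("test_data", torch.randn(4, 4))],
        norm="L1",
        sparsity_level=1.0,
        zeros_per_block=2,
        sparse_block_shape=(1, 4),
    )
    sparsifier.step()  # compute masks: two zeros in every block of four

    mask = sparsifier.get_mask(name="test_data").to(torch.float)
    print(1.0 - mask.mean().item())  # ~0.5

    sparsifier.squash_mask()  # bake masks into the stored data
    sparse_data = sparsifier.get_data(name="test_data", original=False)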
@@ -372,130 +458,232 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):

     Once the above is done, create an instance of TestBaseDataSparsifierType and call all the run_tests()
     """

     def test_tensors(self):
-        tensor1, tensor2, tensor3 = torch.randn(3, 3), torch.randn(4, 4), torch.randn(5, 5)
+        tensor1, tensor2, tensor3 = (
+            torch.randn(3, 3),
+            torch.randn(4, 4),
+            torch.randn(5, 5),
+        )
         tensor4, tensor5 = torch.randn(1, 1), torch.randn(4, 4)
-        data_list = [('tensor1', tensor1), ('tensor2', tensor2), ('tensor3', tensor3)]
-        defaults = {'test': 3}
+        data_list = [("tensor1", tensor1), ("tensor2", tensor2), ("tensor3", tensor3)]
+        defaults = {"test": 3}

         data_with_config = [
-            {
-                'name': 'tensor4', 'data': tensor4, 'config': {'test': 7}
-            },
-            {
-                'name': 'tensor5', 'data': tensor5, 'config': {'test': 8}
-            },
+            {"name": "tensor4", "data": tensor4, "config": {"test": 7}},
+            {"name": "tensor5", "data": tensor5, "config": {"test": 8}},
         ]
-        self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config)
+        self.run_all_checks(
+            data_list=data_list, defaults=defaults, data_with_config=data_with_config
+        )

     def test_nn_parameters(self):
-        param1, param2, param3 = nn.Parameter(torch.randn(3, 3)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5))
-        param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(torch.randn(4, 4))
-        data_list = [('param1', param1), ('param2', param2), ('param3', param3)]
-        defaults = {'test': 3}
+        param1, param2, param3 = (
+            nn.Parameter(torch.randn(3, 3)),
+            nn.Parameter(torch.randn(4, 4)),
+            nn.Parameter(torch.randn(5, 5)),
+        )
+        param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(
+            torch.randn(4, 4)
+        )
+        data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
+        defaults = {"test": 3}

         data_with_config = [
-            {
-                'name': 'param4', 'data': param4, 'config': {'test': 7}
-            },
-            {
-                'name': 'param5', 'data': param5, 'config': {'test': 8}
-            },
+            {"name": "param4", "data": param4, "config": {"test": 7}},
+            {"name": "param5", "data": param5, "config": {"test": 8}},
         ]
-        self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config)
+        self.run_all_checks(
+            data_list=data_list, defaults=defaults, data_with_config=data_with_config
+        )

     def test_nn_embeddings(self):
-        emb1, emb2, = nn.Embedding(10, 3), nn.Embedding(20, 3)
+        (
+            emb1,
+            emb2,
+        ) = nn.Embedding(
+            10, 3
+        ), nn.Embedding(20, 3)
         emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)

         emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
-        data_list = [('emb1', emb1), ('emb1_bag', emb1_bag), ('emb2', emb2), ('emb2_bag', emb2_bag)]
-        defaults = {'test': 3}
+        data_list = [
+            ("emb1", emb1),
+            ("emb1_bag", emb1_bag),
+            ("emb2", emb2),
+            ("emb2_bag", emb2_bag),
+        ]
+        defaults = {"test": 3}

         data_with_config = [
-            {
-                'name': 'emb3', 'data': emb3, 'config': {'test': 7}
-            },
-            {
-                'name': 'emb3_bag', 'data': emb3_bag, 'config': {'test': 8}
-            },
+            {"name": "emb3", "data": emb3, "config": {"test": 7}},
+            {"name": "emb3_bag", "data": emb3_bag, "config": {"test": 8}},
         ]
-        self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config)
+        self.run_all_checks(
+            data_list=data_list, defaults=defaults, data_with_config=data_with_config
+        )


 class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
     """To add unit tests to support new data types for the NormDataSparsifier, create the following
-        data_list: List of tuples of name, data to be added to the constructor
-        defaults: default config for the above data in data_list
-        data_with_config: list of dictionaries defining name, data and config (look test_tensors())
+    data_list: List of tuples of name, data to be added to the constructor
|
||||
defaults: default config for the above data in data_list
|
||||
data_with_config: list of dictionaries defining name, data and config (look test_tensors())
|
||||
|
||||
Once the above is done, create an instance of _NormDataSparsifierTestRunner and call run_tests()
|
||||
Once the above is done, create an instance of _NormDataSparsifierTestRunner and call run_tests()
|
||||
"""
|
||||
|
||||
def test_tensors(self):
|
||||
tensor1, tensor2, tensor3 = torch.randn(1, 10), torch.randn(4, 4), torch.randn(1, 5)
|
||||
tensor1, tensor2, tensor3 = (
|
||||
torch.randn(1, 10),
|
||||
torch.randn(4, 4),
|
||||
torch.randn(1, 5),
|
||||
)
|
||||
tensor4, tensor5 = torch.randn(1, 2), torch.randn(4, 4)
|
||||
data_list = [('tensor1', tensor1), ('tensor2', tensor2), ('tensor3', tensor3)]
|
||||
defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4}
|
||||
data_list = [("tensor1", tensor1), ("tensor2", tensor2), ("tensor3", tensor3)]
|
||||
defaults = {
|
||||
"sparsity_level": 0.5,
|
||||
"sparse_block_shape": (1, 4),
|
||||
"zeros_per_block": 4,
|
||||
}
|
||||
|
||||
data_with_config = [
|
||||
{
|
||||
'name': 'tensor4', 'data': tensor4,
|
||||
'config': {'sparsity_level': 0.7, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
|
||||
"name": "tensor4",
|
||||
"data": tensor4,
|
||||
"config": {
|
||||
"sparsity_level": 0.7,
|
||||
"sparse_block_shape": (2, 3),
|
||||
"zeros_per_block": 6,
|
||||
},
|
||||
},
|
||||
{
|
||||
'name': 'tensor5', 'data': tensor5,
|
||||
'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
|
||||
"name": "tensor5",
|
||||
"data": tensor5,
|
||||
"config": {
|
||||
"sparsity_level": 0.3,
|
||||
"sparse_block_shape": (2, 3),
|
||||
"zeros_per_block": 6,
|
||||
},
|
||||
},
|
||||
]
|
||||
self.run_all_checks(data_list=data_list, defaults=defaults,
|
||||
data_with_config=data_with_config, norm_type='L1')
|
||||
self.run_all_checks(data_list=data_list, defaults=defaults,
|
||||
data_with_config=data_with_config, norm_type='L2')
|
||||
self.run_all_checks(
|
||||
data_list=data_list,
|
||||
defaults=defaults,
|
||||
data_with_config=data_with_config,
|
||||
norm_type="L1",
|
||||
)
|
||||
self.run_all_checks(
|
||||
data_list=data_list,
|
||||
defaults=defaults,
|
||||
data_with_config=data_with_config,
|
||||
norm_type="L2",
|
||||
)
|
||||
|
||||
def test_nn_parameters(self):
|
||||
param1, param2, param3 = nn.Parameter(torch.randn(1, 8)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5))
|
||||
param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(torch.randn(4, 4))
|
||||
data_list = [('param1', param1), ('param2', param2), ('param3', param3)]
|
||||
defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4}
|
||||
param1, param2, param3 = (
|
||||
nn.Parameter(torch.randn(1, 8)),
|
||||
nn.Parameter(torch.randn(4, 4)),
|
||||
nn.Parameter(torch.randn(5, 5)),
|
||||
)
|
||||
param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(
|
||||
torch.randn(4, 4)
|
||||
)
|
||||
data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
|
||||
defaults = {
|
||||
"sparsity_level": 0.5,
|
||||
"sparse_block_shape": (1, 4),
|
||||
"zeros_per_block": 4,
|
||||
}
|
||||
|
||||
data_with_config = [
|
||||
{
|
||||
'name': 'param4', 'data': param4,
|
||||
'config': {'sparsity_level': 0.7, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
|
||||
"name": "param4",
|
||||
"data": param4,
|
||||
"config": {
|
||||
"sparsity_level": 0.7,
|
||||
"sparse_block_shape": (2, 3),
|
||||
"zeros_per_block": 6,
|
||||
},
|
||||
},
|
||||
{
|
||||
'name': 'param5', 'data': param5,
|
||||
'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
|
||||
"name": "param5",
|
||||
"data": param5,
|
||||
"config": {
|
||||
"sparsity_level": 0.3,
|
||||
"sparse_block_shape": (2, 3),
|
||||
"zeros_per_block": 6,
|
||||
},
|
||||
},
|
||||
]
|
||||
self.run_all_checks(data_list=data_list, defaults=defaults,
|
||||
data_with_config=data_with_config, norm_type='L1')
|
||||
self.run_all_checks(data_list=data_list, defaults=defaults,
|
||||
data_with_config=data_with_config, norm_type='L2')
|
||||
self.run_all_checks(
|
||||
data_list=data_list,
|
||||
defaults=defaults,
|
||||
data_with_config=data_with_config,
|
||||
norm_type="L1",
|
||||
)
|
||||
self.run_all_checks(
|
||||
data_list=data_list,
|
||||
defaults=defaults,
|
||||
data_with_config=data_with_config,
|
||||
norm_type="L2",
|
||||
)
|
||||
|
||||
def test_nn_embeddings(self):
|
||||
emb1, emb2, = nn.Embedding(10, 3), nn.Embedding(20, 3)
|
||||
(
|
||||
emb1,
|
||||
emb2,
|
||||
) = nn.Embedding(
|
||||
10, 3
|
||||
), nn.Embedding(20, 3)
|
||||
emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
|
||||
|
||||
emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
|
||||
data_list = [('emb1', emb1), ('emb1_bag', emb1_bag), ('emb2', emb2), ('emb2_bag', emb2_bag)]
|
||||
defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4}
|
||||
data_list = [
|
||||
("emb1", emb1),
|
||||
("emb1_bag", emb1_bag),
|
||||
("emb2", emb2),
|
||||
("emb2_bag", emb2_bag),
|
||||
]
|
||||
defaults = {
|
||||
"sparsity_level": 0.5,
|
||||
"sparse_block_shape": (1, 4),
|
||||
"zeros_per_block": 4,
|
||||
}
|
||||
|
||||
data_with_config = [
|
||||
{
|
||||
'name': 'emb3', 'data': emb3,
|
||||
'config': {'sparsity_level': 0.7, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
|
||||
"name": "emb3",
|
||||
"data": emb3,
|
||||
"config": {
|
||||
"sparsity_level": 0.7,
|
||||
"sparse_block_shape": (2, 3),
|
||||
"zeros_per_block": 6,
|
||||
},
|
||||
},
|
||||
{
|
||||
'name': 'emb3_bag', 'data': emb3_bag,
|
||||
'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
|
||||
"name": "emb3_bag",
|
||||
"data": emb3_bag,
|
||||
"config": {
|
||||
"sparsity_level": 0.3,
|
||||
"sparse_block_shape": (2, 3),
|
||||
"zeros_per_block": 6,
|
||||
},
|
||||
},
|
||||
]
|
||||
self.run_all_checks(data_list=data_list, defaults=defaults,
|
||||
data_with_config=data_with_config, norm_type='L1')
|
||||
self.run_all_checks(
|
||||
data_list=data_list,
|
||||
defaults=defaults,
|
||||
data_with_config=data_with_config,
|
||||
norm_type="L1",
|
||||
)
|
||||
|
||||
self.run_all_checks(data_list=data_list, defaults=defaults,
|
||||
data_with_config=data_with_config, norm_type='L2')
|
||||
self.run_all_checks(
|
||||
data_list=data_list,
|
||||
defaults=defaults,
|
||||
data_with_config=data_with_config,
|
||||
norm_type="L2",
|
||||
)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
@ -522,16 +710,21 @@ class TestQuantizationUtils(TestCase):
|
||||
"""
|
||||
model = Model()
|
||||
|
||||
sparse_config = {'sparsity_level': 0.80, 'sparse_block_shape': (1, 1)}
|
||||
sparse_config = {"sparsity_level": 0.80, "sparse_block_shape": (1, 1)}
|
||||
select_embeddings = [model.embbag1, model.emb1]
|
||||
post_training_sparse_quantize(model,
|
||||
data_sparsifier_class=DataNormSparsifier,
|
||||
sparsify_first=True,
|
||||
select_embeddings=select_embeddings,
|
||||
**sparse_config)
|
||||
post_training_sparse_quantize(
|
||||
model,
|
||||
data_sparsifier_class=DataNormSparsifier,
|
||||
sparsify_first=True,
|
||||
select_embeddings=select_embeddings,
|
||||
**sparse_config,
|
||||
)
|
||||
|
||||
assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert type(model.embbag1) == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
assert (
type(model.embbag1)
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert type(model.emb_seq[0]) == nn.Embedding
assert type(model.emb_seq[1]) == nn.EmbeddingBag
assert type(model.linear1) == nn.Linear
@ -561,24 +754,34 @@ class TestQuantizationUtils(TestCase):
"""
model = Model()

sparse_config = {'sparsity_level': 0.8, 'sparse_block_shape': (1, 1)}
post_training_sparse_quantize(model, DataNormSparsifier, sparsify_first=False, **sparse_config)
sparse_config = {"sparsity_level": 0.8, "sparse_block_shape": (1, 1)}
post_training_sparse_quantize(
model, DataNormSparsifier, sparsify_first=False, **sparse_config
)

assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert type(model.embbag1) == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
assert type(model.emb_seq[0] == torch.ao.nn.quantized.modules.embedding_ops.Embedding)
assert type(model.emb_seq[1] == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag)
assert (
type(model.embbag1)
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert (
type(model.emb_seq[0])
== torch.ao.nn.quantized.modules.embedding_ops.Embedding
)
assert (
type(model.emb_seq[1])
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert type(model.linear1) == nn.Linear  # not quantized
assert type(model.linear2) == nn.Linear  # not quantized

dequant_emb1 = torch.dequantize(model.emb1.weight())
dequant_embbag1 = torch.dequantize(model.embbag1.weight())
dequant_emb_seq_0 = torch.dequantize(model.emb_seq[0].weight())
dequant_emb_seq_1 = torch.dequantize(model.emb_seq[1].weight())

# higher threshold as quantization occurs before sparsity
threshold = 1  # zero points seem to have higher magnitude with sparsity occurring after
threshold = (
1  # zero points seem to have higher magnitude with sparsity occurring after
)

sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean()
sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean()
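A note on the threshold logic above: once a pruned weight has been quantized, its dequantized value lands near the zero point rather than at exactly 0.0, so sparsity is measured with a tolerance instead of an equality test. A hedged sketch of the same measurement on a hypothetical tensor (not part of the test):

import torch

# Hypothetical: fraction of near-zero entries in a dequantized quantized tensor.
qw = torch.quantize_per_tensor(torch.randn(4, 4), scale=0.1, zero_point=2, dtype=torch.qint8)
sparsity_level = (torch.abs(torch.dequantize(qw)) < 1).float().mean()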
@ -1,32 +1,34 @@
# Owner(s): ["module: unknown"]

from torch.testing._internal.common_utils import run_tests

import copy
import numpy as np
import io
import logging
from itertools import product

import numpy as np

import torch
import torch.ao.quantization as tq

from torch import nn
from torch.ao.pruning.sparsifier.utils import fqn_to_module

from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
from torch.testing._internal.common_quantized import (
override_cpu_allocator_for_qnnpack,
override_qengines,
qengine_is_qnnpack,
qengine_is_fbgemm,
qengine_is_onednn,
qengine_is_qnnpack,
qengine_is_x86,
)

from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase

# TODO: Once more test files are created, move the contents to a ao folder.

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestQuantizedSparseKernels(TestCase):
@skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
@ -84,29 +86,42 @@ class TestQuantizedSparseKernels(TestCase):
continue
if use_channelwise:
W_q = torch.quantize_per_channel(
W_fp32, scales=W_scales, zero_points=W_zps, axis=0, dtype=torch.qint8
W_fp32,
scales=W_scales,
zero_points=W_zps,
axis=0,
dtype=torch.qint8,
)
else:
W_q = torch.quantize_per_tensor(
W_fp32, scale=W_scales[0], zero_point=W_zps[0], dtype=torch.qint8
W_fp32,
scale=W_scales[0],
zero_point=W_zps[0],
dtype=torch.qint8,
)

Y_scale = 1.1234
Y_zp = 5
W_prepack_dense = dense_prepack(W_q, float_bias)
W_prepack_sparse = sparse_prepack(W_q, float_bias, row_block_size, col_block_size)
W_prepack_sparse = sparse_prepack(
W_q, float_bias, row_block_size, col_block_size
)

if dynamic_mode:
Y = sparse_qlinear_dynamic(X_fp32, W_prepack_sparse)
Y_ref = dense_qlinear_dynamic(X_fp32, W_prepack_dense)

np.testing.assert_array_almost_equal(Y_ref.numpy(), Y.numpy(), decimal=decimal_val)
np.testing.assert_array_almost_equal(
Y_ref.numpy(), Y.numpy(), decimal=decimal_val
)
else:
Y_q = sparse_qlinear(X_q, W_prepack_sparse, Y_scale, Y_zp)
Y_q_ref = dense_qlinear(X_q, W_prepack_dense, Y_scale, Y_zp)

np.testing.assert_array_almost_equal(
Y_q_ref.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=decimal_val
Y_q_ref.int_repr().numpy(),
Y_q.int_repr().numpy(),
decimal=decimal_val,
)
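The two branches above exercise different entry points: the dynamic ops take fp32 activations and quantize them internally, so outputs are compared as fp32 arrays, while the static ops consume an already-quantized X_q and outputs are compared through int_repr(). A rough module-level illustration of the dynamic flavor (dense, not sparse; assumes a quantized engine such as fbgemm is available):

import torch

# fp32 in -> weights quantized, activations quantized on the fly -> fp32 out
dyn_linear = torch.ao.nn.quantized.dynamic.Linear(16, 8)
y = dyn_linear(torch.randn(2, 16))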
@ -235,6 +250,7 @@ def _sparse_layer_test_helper(
Y_hat = sqmodel(X_q)
test_class.assertEqual(Y_ref.dequantize(), Y_hat.dequantize())


class SparseQuantizedModel(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
@ -243,6 +259,7 @@ class SparseQuantizedModel(nn.Module):
def forward(self, x):
return self.linear(x)


class TestQuantizedSparseLayers(TestCase):
@override_qengines
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
@ -3,22 +3,24 @@

import logging

import torch

from torch import nn
from torch.ao.pruning.sparsifier import utils
from torch.nn.utils import parametrize

import torch
from torch.testing._internal.common_utils import TestCase

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class ModelUnderTest(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.linear = nn.Linear(16, 16, bias=bias)
self.seq = nn.Sequential(
nn.Linear(16, 16, bias=bias),
nn.Linear(16, 16, bias=bias)
nn.Linear(16, 16, bias=bias), nn.Linear(16, 16, bias=bias)
)

# Make sure the weights are not random
@ -45,7 +47,7 @@ class TestFakeSparsity(TestCase):

mask = torch.zeros(16, 16)
sparsity = utils.FakeSparsity(mask)
parametrize.register_parametrization(model, 'weight', sparsity)
parametrize.register_parametrization(model, "weight", sparsity)

x = torch.randn(3, 16)
self.assertEqual(torch.zeros(3, 16), model(x))
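The assertion above (an all-zero mask forces an all-zero output) is consistent with FakeSparsity being an elementwise mask multiplication. A minimal sketch of such a parametrization, written here as an illustration rather than the torch.ao source:

import torch
from torch import nn

class MaskedWeight(nn.Module):  # hypothetical stand-in for utils.FakeSparsity
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return self.mask * weight  # zeroed mask entries zero the weight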
@ -53,95 +55,119 @@ class TestFakeSparsity(TestCase):
def test_weights_parametrized(self):
model = ModelUnderTest(bias=False)

assert not hasattr(model.linear, 'parametrizations')
assert not hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.seq[1], 'parametrizations')
assert not hasattr(model.linear, "parametrizations")
assert not hasattr(model.seq[0], "parametrizations")
assert not hasattr(model.seq[1], "parametrizations")
mask = torch.eye(16)
parametrize.register_parametrization(model.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[1], "weight", utils.FakeSparsity(mask)
)

assert hasattr(model.linear, 'parametrizations')
assert parametrize.is_parametrized(model.linear, 'weight')
assert hasattr(model.seq[0], 'parametrizations')
assert parametrize.is_parametrized(model.linear, 'weight')
assert hasattr(model.seq[1], 'parametrizations')
assert parametrize.is_parametrized(model.linear, 'weight')
assert hasattr(model.linear, "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")
assert hasattr(model.seq[0], "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")
assert hasattr(model.seq[1], "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")

def test_state_dict_preserved(self):
model_save = ModelUnderTest(bias=False)

mask = torch.eye(16)
parametrize.register_parametrization(model_save.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_save.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model_save.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_save.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model_save.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_save.seq[1], "weight", utils.FakeSparsity(mask)
)
state_dict = model_save.state_dict()

model_load = ModelUnderTest(bias=False)
mask = torch.zeros(model_load.linear.weight.shape)
parametrize.register_parametrization(model_load.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_load.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.zeros(model_load.seq[0].weight.shape)
parametrize.register_parametrization(model_load.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_load.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.zeros(model_load.seq[1].weight.shape)
parametrize.register_parametrization(model_load.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_load.seq[1], "weight", utils.FakeSparsity(mask)
)
# strict=False here, as the 'mask' buffers are not part of the saved state_dict
model_load.load_state_dict(state_dict, strict=False)

# Check the parametrizations are preserved
assert hasattr(model_load.linear, 'parametrizations')
assert parametrize.is_parametrized(model_load.linear, 'weight')
assert hasattr(model_load.seq[0], 'parametrizations')
assert parametrize.is_parametrized(model_load.linear, 'weight')
assert hasattr(model_load.seq[1], 'parametrizations')
assert parametrize.is_parametrized(model_load.linear, 'weight')
assert hasattr(model_load.linear, "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")
assert hasattr(model_load.seq[0], "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")
assert hasattr(model_load.seq[1], "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")

# Check the weights are preserved
self.assertEqual(model_save.linear.parametrizations['weight'].original,
model_load.linear.parametrizations['weight'].original)
self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
model_load.seq[0].parametrizations['weight'].original)
self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
model_load.seq[1].parametrizations['weight'].original)
self.assertEqual(
model_save.linear.parametrizations["weight"].original,
model_load.linear.parametrizations["weight"].original,
)
self.assertEqual(
model_save.seq[0].parametrizations["weight"].original,
model_load.seq[0].parametrizations["weight"].original,
)
self.assertEqual(
model_save.seq[1].parametrizations["weight"].original,
model_load.seq[1].parametrizations["weight"].original,
)

# Check the masks are not preserved in the state_dict
# We store the state_dicts in the sparsifier, not in the model itself.
# TODO: Need to find a clean way of exporting the parametrized model
self.assertNotEqual(model_save.linear.parametrizations['weight'][0].mask,
model_load.linear.parametrizations['weight'][0].mask)
self.assertNotEqual(model_save.seq[0].parametrizations['weight'][0].mask,
model_load.seq[0].parametrizations['weight'][0].mask)
self.assertNotEqual(model_save.seq[1].parametrizations['weight'][0].mask,
model_load.seq[1].parametrizations['weight'][0].mask)
self.assertNotEqual(
model_save.linear.parametrizations["weight"][0].mask,
model_load.linear.parametrizations["weight"][0].mask,
)
self.assertNotEqual(
model_save.seq[0].parametrizations["weight"][0].mask,
model_load.seq[0].parametrizations["weight"][0].mask,
)
self.assertNotEqual(
model_save.seq[1].parametrizations["weight"][0].mask,
model_load.seq[1].parametrizations["weight"][0].mask,
)

def test_jit_trace(self):
model = ModelUnderTest(bias=False)

mask = torch.eye(16)
parametrize.register_parametrization(model.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[1], "weight", utils.FakeSparsity(mask)
)

# Tracing
example_x = torch.ones(3, 16)
model_trace = torch.jit.trace_module(model, {'forward': example_x})
model_trace = torch.jit.trace_module(model, {"forward": example_x})

x = torch.randn(3, 16)
y = model(x)
@ -2,29 +2,30 @@
# Owner(s): ["oncall: mobile"]

import tempfile

import torch
from torch.ao.nn.sparse.quantized.dynamic.linear import Linear
from torch.testing._internal.common_quantization import (
skipIfNoFBGEMM,
skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK
from torch.testing._internal.common_quantized import (
qengine_is_qnnpack,
override_cpu_allocator_for_qnnpack,
override_quantized_engine,
override_cpu_allocator_for_qnnpack
qengine_is_qnnpack,
)
from torch.testing._internal.common_utils import TestCase


class TestQlinearPackedParams(TestCase):
def qlinear_packed_params_test(self, allow_non_zero_zero_points=False):
# copied from https://pytorch.org/docs/stable/sparse.html#csr-tensor-operations,
# so row/col block indices match that example, but with blocks and
# scaled rows
weight_fp32 = torch.Tensor([
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
])
weight_fp32 = torch.Tensor(
[
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
)

row_block_size = 1
col_block_size = 4
@ -43,18 +44,14 @@ class TestQlinearPackedParams(TestCase):
wide_weight_fp32[1][0] = 8

per_tensor_small = (
torch.quantize_per_tensor(
weight_fp32,
scales[0],
zero_points[0],
dtype
),
torch.quantize_per_tensor(weight_fp32, scales[0], zero_points[0], dtype),
True,
[0, 1, 3, 3],
[2, 0, 1],
[x + (1 if allow_non_zero_zero_points else 0) for x in [
1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6
]],
[
x + (1 if allow_non_zero_zero_points else 0)
for x in [1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6]
],
)

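The expected values in per_tensor_small read as a BSR (block-CSR) encoding of the 3x16 weight with 1x4 blocks: row 0 owns one nonzero block (block-column 2), row 1 owns two (block-columns 0 and 1), row 2 none, which yields the row pointer [0, 1, 3, 3] and column indices [2, 0, 1]. This can be cross-checked with the public sparse API (a sanity check, not part of the test):

import torch

w = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
    [6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
], dtype=torch.float)
bsr = w.to_sparse_bsr((1, 4))
print(bsr.crow_indices())  # expected: tensor([0, 1, 3, 3])
print(bsr.col_indices())   # expected: tensor([2, 0, 1])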
per_channel_small = (
@ -68,9 +65,10 @@ class TestQlinearPackedParams(TestCase):
False,
[0, 1, 3, 3],
[2, 0, 1],
[x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0) for (i, x) in enumerate([
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2
])],
[
x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0)
for (i, x) in enumerate([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2])
],
)

per_tensor_large = (
@ -83,14 +81,19 @@ class TestQlinearPackedParams(TestCase):
True,
[0, 2, 3, 3],
[0, 1001, 0],
[x + (1 if allow_non_zero_zero_points else 0) for x in [
2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0
]],
[
x + (1 if allow_non_zero_zero_points else 0)
for x in [2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]
],
)

for (weight, is_per_tensor_quantized, expected_row_block_indices, expected_col_block_indices, expected_weights) in [
per_tensor_small, per_channel_small, per_tensor_large
]:
for (
weight,
is_per_tensor_quantized,
expected_row_block_indices,
expected_col_block_indices,
expected_weights,
) in [per_tensor_small, per_channel_small, per_tensor_large]:
lin = Linear(
out_features=weight.shape[0],
in_features=weight.shape[1],
@ -118,24 +121,36 @@ class TestQlinearPackedParams(TestCase):
col_block_indices_,
weights_,
output_channels_,
input_channels_
input_channels_,
) = serialized[0]

# Test Serialization
self.assertEqual(bias_, bias)
self.assertEqual(out_features_block_size_, row_block_size)
self.assertEqual(in_features_block_size_, col_block_size)
self.assertEqual(weight_scales_, [scales[0]] if is_per_tensor_quantized else scales)
self.assertEqual(weight_zero_points_, [zero_points[0]] if is_per_tensor_quantized else zero_points)
self.assertEqual(
weight_scales_, [scales[0]] if is_per_tensor_quantized else scales
)
self.assertEqual(
weight_zero_points_,
[zero_points[0]] if is_per_tensor_quantized else zero_points,
)
self.assertEqual(quantization_scheme_, is_per_tensor_quantized)
self.assertEqual(row_block_indices_, expected_row_block_indices)
self.assertEqual(col_block_indices_, expected_col_block_indices)
self.assertEqual(weights_.tolist(), [v + 128 for v in expected_weights])  # weights are serialized as +128
self.assertEqual(
weights_.tolist(), [v + 128 for v in expected_weights]
)  # weights are serialized as +128
self.assertEqual(output_channels_, weight.shape[0])
self.assertEqual(input_channels_, weight.shape[1])

# Test Unpacking
(weights_, bias_, out_features_block_size_, in_features_block_size_) = lin._weight_bias()
(
weights_,
bias_,
out_features_block_size_,
in_features_block_size_,
) = lin._weight_bias()
self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight))
self.assertEqual(bias_, bias)
self.assertEqual(out_features_block_size_, row_block_size)
@ -148,7 +163,9 @@ class TestQlinearPackedParams(TestCase):
lin2 = torch.load(file_buff)
self.assertEqual(lin._weight_bias(), lin2._weight_bias())
# Serialize -> Deserialize -> Serialize should match Serialize
self.assertEqual(serialized, lin2._packed_params._packed_params.__getstate__())
self.assertEqual(
serialized, lin2._packed_params._packed_params.__getstate__()
)

# Test that op output is preserved by serialize -> deserialize
if qengine_is_qnnpack():
@ -157,29 +174,29 @@ class TestQlinearPackedParams(TestCase):
y2 = lin2(x)
self.assertEqual(y1, y2)


@skipIfNoFBGEMM
def test_qlinear_packed_params_fbgemm(self):
torch.manual_seed(0)
with override_quantized_engine('fbgemm'):
with override_quantized_engine("fbgemm"):
self.qlinear_packed_params_test(allow_non_zero_zero_points=False)


@skipIfNoQNNPACK
def test_qlinear_packed_params_qnnpack(self):
torch.manual_seed(0)
with override_quantized_engine('qnnpack'):
with override_quantized_engine("qnnpack"):
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
self.qlinear_packed_params_test(allow_non_zero_zero_points=True)

def test_qlinear_packed_params_fbgemm_qnnpack_cross_compatibility(self):
torch.manual_seed(0)

weight_fp32 = torch.Tensor([
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
])
weight_fp32 = torch.Tensor(
[
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
)

row_block_size = 1
col_block_size = 4
@ -229,17 +246,23 @@ class TestQlinearPackedParams(TestCase):
def packed_params_data_with_int32_indices(data_as_state_and_weight_bias):
(st, weight_bias) = data_as_state_and_weight_bias
(s0, s1) = st
s0_updated = tuple([
# 7 and 8 are row and col block indices respectively
v if (i != 7 and i != 8) else v.to(torch.int32) for (i, v) in enumerate(list(s0))
])
s0_updated = tuple(
[
# 7 and 8 are row and col block indices respectively
v if (i != 7 and i != 8) else v.to(torch.int32)
for (i, v) in enumerate(list(s0))
]
)
return ((s0_updated, s1), weight_bias)

# Test Fbgemm -> Qnnpack
with override_quantized_engine('fbgemm'):
packed_params_data_1a, file_buff_1 = make_lin_get_state_weight_bias_and_save()
with override_quantized_engine("fbgemm"):
(
packed_params_data_1a,
file_buff_1,
) = make_lin_get_state_weight_bias_and_save()

with override_quantized_engine('qnnpack'):
with override_quantized_engine("qnnpack"):
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
packed_params_data_1b = load_get_state_weight_bias(file_buff_1)

@ -249,11 +272,14 @@ class TestQlinearPackedParams(TestCase):
)

# Test Qnnpack -> Fbgemm
with override_quantized_engine('qnnpack'):
with override_quantized_engine("qnnpack"):
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
packed_params_data_2a, file_buff_2 = make_lin_get_state_weight_bias_and_save()
(
packed_params_data_2a,
file_buff_2,
) = make_lin_get_state_weight_bias_and_save()

with override_quantized_engine('fbgemm'):
with override_quantized_engine("fbgemm"):
packed_params_data_2b = load_get_state_weight_bias(file_buff_2)

self.assertEqual(
@ -1,43 +1,38 @@
# Owner(s): ["module: unknown"]

import warnings

from torch import nn

from torch.ao.pruning import WeightNormSparsifier
from torch.ao.pruning import BaseScheduler, LambdaSL, CubicSL
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier

from torch.testing._internal.common_utils import TestCase

import warnings

class ImplementedScheduler(BaseScheduler):
def get_sl(self):
if self.last_epoch > 0:
return [group['sparsity_level'] * 0.5
for group in self.sparsifier.groups]
return [group["sparsity_level"] * 0.5 for group in self.sparsifier.groups]
else:
return list(self.base_sl)


class TestScheduler(TestCase):
def test_constructor(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)

assert scheduler.sparsifier is sparsifier
assert scheduler._step_count == 1
assert scheduler.base_sl == [sparsifier.groups[0]['sparsity_level']]
assert scheduler.base_sl == [sparsifier.groups[0]["sparsity_level"]]

def test_order_of_steps(self):
"""Checks if the warning is thrown if the scheduler step is called
before the sparsifier step"""

model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)
@ -54,43 +49,41 @@ class TestScheduler(TestCase):
# Make sure there is no warning related to the base_scheduler
for warning in w:
fname = warning.filename
fname = '/'.join(fname.split('/')[-5:])
assert fname != 'torch/ao/sparsity/scheduler/base_scheduler.py'
fname = "/".join(fname.split("/")[-5:])
assert fname != "torch/ao/sparsity/scheduler/base_scheduler.py"

def test_step(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
assert sparsifier.groups[0]['sparsity_level'] == 0.5
assert sparsifier.groups[0]["sparsity_level"] == 0.5
scheduler = ImplementedScheduler(sparsifier)
assert sparsifier.groups[0]['sparsity_level'] == 0.5
assert sparsifier.groups[0]["sparsity_level"] == 0.5

sparsifier.step()
scheduler.step()
assert sparsifier.groups[0]['sparsity_level'] == 0.25
assert sparsifier.groups[0]["sparsity_level"] == 0.25

def test_lambda_scheduler(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
assert sparsifier.groups[0]['sparsity_level'] == 0.5
assert sparsifier.groups[0]["sparsity_level"] == 0.5
scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
assert sparsifier.groups[0]['sparsity_level'] == 0.0  # Epoch 0
assert sparsifier.groups[0]["sparsity_level"] == 0.0  # Epoch 0
scheduler.step()
assert sparsifier.groups[0]['sparsity_level'] == 5.0  # Epoch 1
assert sparsifier.groups[0]["sparsity_level"] == 5.0  # Epoch 1


class TestCubicScheduler(TestCase):
def setUp(self):
self.model_sparse_config = [
{'tensor_fqn': '0.weight', 'sparsity_level': 0.8},
{'tensor_fqn': '2.weight', 'sparsity_level': 0.4},
{"tensor_fqn": "0.weight", "sparsity_level": 0.8},
{"tensor_fqn": "2.weight", "sparsity_level": 0.4},
]
self.sorted_sparse_levels = [
conf["sparsity_level"] for conf in self.model_sparse_config
]
self.sorted_sparse_levels = [conf['sparsity_level'] for conf in self.model_sparse_config]
self.initial_sparsity = 0.1
self.initial_step = 3

@ -107,8 +100,8 @@ class TestCubicScheduler(TestCase):
sparsifier.prepare(model, config=self.model_sparse_config)

scheduler_args = {
'init_sl': self.initial_sparsity,
'init_t': self.initial_step,
"init_sl": self.initial_sparsity,
"init_t": self.initial_step,
}
scheduler_args.update(kwargs)

@ -118,25 +111,33 @@ class TestCubicScheduler(TestCase):
@staticmethod
def _get_sparsity_levels(sparsifier, precision=32):
r"""Gets the current levels of sparsity in a sparsifier."""
return [round(group['sparsity_level'], precision) for group in sparsifier.groups]
return [
round(group["sparsity_level"], precision) for group in sparsifier.groups
]

def test_constructor(self):
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
self.assertIs(
scheduler.sparsifier, sparsifier,
msg="Sparsifier is not properly attached")
scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached"
)
self.assertEqual(
scheduler._step_count, 1,
msg="Scheduler is initialized with incorrect step count")
scheduler._step_count,
1,
msg="Scheduler is initialized with incorrect step count",
)
self.assertEqual(
scheduler.base_sl, self.sorted_sparse_levels,
msg="Scheduler did not store the target sparsity levels correctly")
scheduler.base_sl,
self.sorted_sparse_levels,
msg="Scheduler did not store the target sparsity levels correctly",
)

# Value before t_0 is 0
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(0.0),
msg="Sparsifier is not reset correctly after attaching to the Scheduler",
)

# Value before t_0 is s_0
model = self._make_model()
@ -144,36 +145,50 @@ class TestCubicScheduler(TestCase):
self.assertEqual(
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
msg="Sparsifier is not reset correctly after attaching to the Scheduler",
)

def test_step(self):
# For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5)
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5
)

scheduler.step()
scheduler.step()
self.assertEqual(scheduler._step_count, 3, msg="Scheduler step_count is expected to increment")
self.assertEqual(
scheduler._step_count,
3,
msg="Scheduler step_count is expected to increment",
)
# Value before t_0 is supposed to be 0
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
msg="Scheduler step updating the sparsity level before t_0")
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(0.0),
msg="Scheduler step updating the sparsity level before t_0",
)

scheduler.step()  # Step = 3  =>  sparsity = initial_sparsity
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset to initial sparsity at the first step")
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset to initial sparsity at the first step",
)

scheduler.step()  # Step = 4  =>  sparsity ~ [0.3, 0.2]
self.assertEqual(
self._get_sparsity_levels(sparsifier, 1), [0.3, 0.2],
msg="Sparsity level is not set correctly after the first step")
self._get_sparsity_levels(sparsifier, 1),
[0.3, 0.2],
msg="Sparsity level is not set correctly after the first step",
)

current_step = scheduler._step_count - scheduler.init_t[0] - 1
more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
for _ in range(more_steps_needed):  # More steps needed to final sparsity level
scheduler.step()
self.assertEqual(
self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels,
msg="Sparsity level is not reaching the target level after delta_t * n steps")
self._get_sparsity_levels(sparsifier),
self.sorted_sparse_levels,
msg="Sparsity level is not reaching the target level after delta_t * n steps",
)
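For reference on the values asserted in test_step: the numbers are consistent with the standard cubic sparsity schedule (Zhu & Gupta, 2017), inferred here from the assertions rather than from the CubicSL source:

s_t = s_f + (s_0 - s_f) * (1 - (t - t_0) / (n * Δt))^3,  for t_0 <= t <= t_0 + n * Δt

With s_0 = 0.1, t_0 = 3, Δt = 2, n = 5, at t = t_0 + 1 the cubic factor is 0.9^3 = 0.729, giving 0.8 - 0.7 * 0.729 ≈ 0.29 and 0.4 - 0.3 * 0.729 ≈ 0.18 for the two layers, matching the asserted [0.3, 0.2] after rounding to one decimal place.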
@ -6,13 +6,24 @@ import re

import torch
from torch import nn
from torch.ao.pruning import BaseSparsifier, WeightNormSparsifier, FakeSparsity, NearlyDiagonalSparsifier
from torch.ao.pruning import (
BaseSparsifier,
FakeSparsity,
NearlyDiagonalSparsifier,
WeightNormSparsifier,
)
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import (
ImplementedSparsifier,
MockSparseLinear,
SimpleLinear,
)

from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_pruning import SimpleLinear, MockSparseLinear, ImplementedSparsifier

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestBaseSparsifier(TestCase):
@ -27,41 +38,44 @@ class TestBaseSparsifier(TestCase):
sparsifier.step()
# Can instantiate the model with configs
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
assert len(sparsifier.groups) == 1
assert sparsifier.groups[0]['tensor_fqn'] == 'linear1.weight'
assert 'test' in sparsifier.groups[0]
assert sparsifier.groups[0]['test'] == 3
assert sparsifier.groups[0]["tensor_fqn"] == "linear1.weight"
assert "test" in sparsifier.groups[0]
assert sparsifier.groups[0]["test"] == 3

def test_prepare_config(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
# Make sure there are no parametrizations before `prepare`
assert not hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.linear1, 'parametrizations')
assert not hasattr(model.linear2, 'parametrizations')
sparsifier.prepare(model, config=[
{'tensor_fqn': 'seq.0.weight', 'test': 42},
# No 'linear1' to make sure it will be skipped in the sparsification
{'tensor_fqn': 'linear2.weight'}
])
assert not hasattr(model.seq[0], "parametrizations")
assert not hasattr(model.linear1, "parametrizations")
assert not hasattr(model.linear2, "parametrizations")
sparsifier.prepare(
model,
config=[
{"tensor_fqn": "seq.0.weight", "test": 42},
# No 'linear1' to make sure it will be skipped in the sparsification
{"tensor_fqn": "linear2.weight"},
],
)
assert len(sparsifier.groups) == 2
# Check if default argument is not assigned if explicit
assert sparsifier.groups[0]['tensor_fqn'] == 'seq.0.weight'
assert sparsifier.groups[0]['test'] == 42
assert sparsifier.groups[0]["tensor_fqn"] == "seq.0.weight"
assert sparsifier.groups[0]["test"] == 42
# Check if FQN and module are pointing to the same location
assert sparsifier.groups[1]['tensor_fqn'] == 'linear2.weight'
assert sparsifier.groups[1]['module'] == model.linear2
assert sparsifier.groups[1]["tensor_fqn"] == "linear2.weight"
assert sparsifier.groups[1]["module"] == model.linear2
# Check if parameterizations are attached
assert hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.linear1, 'parametrizations')
assert hasattr(model.linear2, 'parametrizations')
assert hasattr(model.seq[0], "parametrizations")
assert not hasattr(model.linear1, "parametrizations")
assert hasattr(model.linear2, "parametrizations")

def test_step(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.enable_mask_update = True
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
sparsifier.step()
assert torch.all(model.linear1.parametrizations.weight[0].mask[0] == 0)

@ -69,21 +83,21 @@ class TestBaseSparsifier(TestCase):
step_count = 3
model0 = SimpleLinear()
sparsifier0 = ImplementedSparsifier(test=3)
sparsifier0.prepare(model0, [{'tensor_fqn': 'linear1.weight'}])
mask = model0.linear1.parametrizations['weight'][0].mask
sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}])
mask = model0.linear1.parametrizations["weight"][0].mask
mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
for step in range(step_count):
sparsifier0.step()
state_dict = sparsifier0.state_dict()

# Check the expected keys in the state_dict
assert 'state' in state_dict
assert 'step_count' in state_dict['state']['linear1.weight']
assert state_dict['state']['linear1.weight']['step_count'] == 3
assert 'groups' in state_dict
assert 'test' in state_dict['groups'][0]
assert 'tensor_fqn' in state_dict['groups'][0]
assert state_dict['groups'][0]['tensor_fqn'] == 'linear1.weight'
assert "state" in state_dict
assert "step_count" in state_dict["state"]["linear1.weight"]
assert state_dict["state"]["linear1.weight"]["step_count"] == 3
assert "groups" in state_dict
assert "test" in state_dict["groups"][0]
assert "tensor_fqn" in state_dict["groups"][0]
assert state_dict["groups"][0]["tensor_fqn"] == "linear1.weight"

# Check loading static_dict creates an equivalent model
model1 = SimpleLinear()
@ -94,11 +108,11 @@ class TestBaseSparsifier(TestCase):

# Make sure the masks are different in the beginning
for mg in sparsifier0.groups:
if mg['tensor_fqn'] == 'linear1.weight':
mask0 = mg['module'].parametrizations.weight[0].mask
if mg["tensor_fqn"] == "linear1.weight":
mask0 = mg["module"].parametrizations.weight[0].mask
for mg in sparsifier1.groups:
if mg['tensor_fqn'] == 'linear1.weight':
mask1 = mg['module'].parametrizations.weight[0].mask
if mg["tensor_fqn"] == "linear1.weight":
mask1 = mg["module"].parametrizations.weight[0].mask
self.assertNotEqual(mask0, mask1)

sparsifier1.load_state_dict(state_dict)
@ -113,12 +127,12 @@ class TestBaseSparsifier(TestCase):
mg1 = sparsifier1.groups[idx]
for key in mg0.keys():
assert key in mg1
if key == 'module':
if key == "module":
# We cannot compare modules as they are different
param0 = mg0[key].parametrizations.weight[0]
param1 = mg1[key].parametrizations.weight[0]
assert hasattr(param0, 'mask')
assert hasattr(param1, 'mask')
assert hasattr(param0, "mask")
assert hasattr(param1, "mask")
self.assertEqual(param0.__dict__, param1.__dict__)
else:
assert mg0[key] == mg1[key]
@ -126,80 +140,84 @@ class TestBaseSparsifier(TestCase):
def test_convert(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
new_model = sparsifier.convert(model, mapping={nn.Linear: MockSparseLinear}, inplace=False)
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
new_model = sparsifier.convert(
model, mapping={nn.Linear: MockSparseLinear}, inplace=False
)

assert isinstance(new_model.linear1, MockSparseLinear)
assert isinstance(new_model.seq[0], nn.Linear)
assert isinstance(new_model.linear2, nn.Linear)

def test_mask_squash(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
assert hasattr(model.linear1.parametrizations.weight[0], 'mask')
assert is_parametrized(model.linear1, 'weight')
assert not is_parametrized(model.seq[0], 'weight')
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
assert hasattr(model.linear1.parametrizations.weight[0], "mask")
assert is_parametrized(model.linear1, "weight")
assert not is_parametrized(model.seq[0], "weight")

sparsifier.squash_mask()
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")

def test_mask_squash_with_params1(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.prepare(
model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
)
sparsifier.squash_mask(
params_to_keep_per_layer={
'linear1': ('foo', 'bar'),
'seq.0': ('baz',)
})
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear1, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) is None
assert model.seq[0].sparse_params.get('bar', None) is None
assert model.seq[0].sparse_params.get('baz', None) == 1
assert model.linear1.sparse_params.get('foo', None) == 3
assert model.linear1.sparse_params.get('bar', None) == 2
assert model.linear1.sparse_params.get('baz', None) is None
params_to_keep_per_layer={"linear1": ("foo", "bar"), "seq.0": ("baz",)}
)
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")
assert hasattr(model.seq[0], "sparse_params")
assert hasattr(model.linear1, "sparse_params")
assert model.seq[0].sparse_params.get("foo", None) is None
assert model.seq[0].sparse_params.get("bar", None) is None
assert model.seq[0].sparse_params.get("baz", None) == 1
assert model.linear1.sparse_params.get("foo", None) == 3
assert model.linear1.sparse_params.get("bar", None) == 2
assert model.linear1.sparse_params.get("baz", None) is None

def test_mask_squash_with_params2(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear1, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) == 3
assert model.seq[0].sparse_params.get('bar', None) == 2
assert model.seq[0].sparse_params.get('baz', None) is None
assert model.linear1.sparse_params.get('foo', None) == 3
assert model.linear1.sparse_params.get('bar', None) == 2
assert model.linear1.sparse_params.get('baz', None) is None
sparsifier.prepare(
model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
)
sparsifier.squash_mask(params_to_keep=("foo", "bar"))
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")
assert hasattr(model.seq[0], "sparse_params")
assert hasattr(model.linear1, "sparse_params")
assert model.seq[0].sparse_params.get("foo", None) == 3
assert model.seq[0].sparse_params.get("bar", None) == 2
assert model.seq[0].sparse_params.get("baz", None) is None
assert model.linear1.sparse_params.get("foo", None) == 3
assert model.linear1.sparse_params.get("bar", None) == 2
assert model.linear1.sparse_params.get("baz", None) is None

def test_mask_squash_with_params3(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.prepare(
model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
)
sparsifier.squash_mask(
params_to_keep=('foo', 'bar'),
params_to_keep_per_layer={'seq.0': ('baz',)})
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear1, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) == 3
assert model.seq[0].sparse_params.get('bar', None) == 2
assert model.seq[0].sparse_params.get('baz', None) == 1
assert model.linear1.sparse_params.get('foo', None) == 3
assert model.linear1.sparse_params.get('bar', None) == 2
assert model.linear1.sparse_params.get('baz', None) is None
params_to_keep=("foo", "bar"), params_to_keep_per_layer={"seq.0": ("baz",)}
)
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")
assert hasattr(model.seq[0], "sparse_params")
assert hasattr(model.linear1, "sparse_params")
assert model.seq[0].sparse_params.get("foo", None) == 3
assert model.seq[0].sparse_params.get("bar", None) == 2
assert model.seq[0].sparse_params.get("baz", None) == 1
assert model.linear1.sparse_params.get("foo", None) == 3
assert model.linear1.sparse_params.get("bar", None) == 2
assert model.linear1.sparse_params.get("baz", None) is None
class TestWeightNormSparsifier(TestCase):
|
||||
@ -208,25 +226,33 @@ class TestWeightNormSparsifier(TestCase):
|
||||
sparsifier = WeightNormSparsifier()
|
||||
sparsifier.prepare(model, config=None)
|
||||
for g in sparsifier.groups:
|
||||
assert isinstance(g['module'], nn.Linear)
|
||||
assert isinstance(g["module"], nn.Linear)
|
||||
# The groups are unordered
|
||||
assert g['module_fqn'] in ('seq.0', 'seq.1', 'seq.2', 'linear1', 'linear2')
|
||||
assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")
|
||||
|
||||
def test_step(self):
|
||||
model = SimpleLinear()
|
||||
sparsifier = WeightNormSparsifier(sparsity_level=0.5)
|
||||
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear1.weight'}])
|
||||
sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
|
||||
for g in sparsifier.groups:
|
||||
# Before step
|
||||
module = g['module']
|
||||
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) == 0 # checking sparsity level is 0
|
||||
module = g["module"]
|
||||
assert (
|
||||
1.0 - module.parametrizations["weight"][0].mask.mean()
|
||||
) == 0 # checking sparsity level is 0
|
||||
sparsifier.enable_mask_update = True
|
||||
sparsifier.step()
|
||||
self.assertAlmostEqual(model.linear1.parametrizations['weight'][0].mask.mean().item(), 0.5, places=2)
|
||||
self.assertAlmostEqual(
|
||||
model.linear1.parametrizations["weight"][0].mask.mean().item(),
|
||||
0.5,
|
||||
places=2,
|
||||
)
|
||||
for g in sparsifier.groups:
|
||||
# After step
|
||||
module = g['module']
|
||||
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level has increased
|
||||
module = g["module"]
|
||||
assert (
|
||||
1.0 - module.parametrizations["weight"][0].mask.mean()
|
||||
) > 0 # checking sparsity level has increased
|
||||
# Test if the mask collapses to all zeros if the weights are randomized
|
||||
iters_before_collapse = 1000
|
||||
for _ in range(iters_before_collapse):
|
||||
@ -234,25 +260,29 @@ class TestWeightNormSparsifier(TestCase):
|
||||
sparsifier.step()
|
||||
for g in sparsifier.groups:
|
||||
# After step
|
||||
module = g['module']
|
||||
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level did not collapse
|
||||
module = g["module"]
|
||||
assert (
|
||||
1.0 - module.parametrizations["weight"][0].mask.mean()
|
||||
) > 0 # checking sparsity level did not collapse
|
||||
|
||||
def test_step_2_of_4(self):
|
||||
model = SimpleLinear()
|
||||
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
|
||||
sparse_block_shape=(1, 4),
|
||||
zeros_per_block=2)
|
||||
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear1.weight'}])
|
||||
sparsifier = WeightNormSparsifier(
|
||||
sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
|
||||
)
|
||||
sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
|
||||
sparsifier.step()
|
||||
# make sure the sparsity level is approximately 50%
|
||||
mask = model.linear1.parametrizations['weight'][0].mask.to(torch.float) # mean works on float only
|
||||
mask = model.linear1.parametrizations["weight"][0].mask.to(
|
||||
torch.float
|
||||
) # mean works on float only
|
||||
self.assertAlmostEqual(mask.mean().item(), 0.5, places=2)
|
||||
# Make sure each block has exactly 50% zeros
|
||||
module = sparsifier.groups[0]['module']
|
||||
mask = module.parametrizations['weight'][0].mask
|
||||
module = sparsifier.groups[0]["module"]
|
||||
mask = module.parametrizations["weight"][0].mask
|
||||
for row in mask:
|
||||
for idx in range(0, len(row), 4):
|
||||
block = row[idx:idx + 4]
|
||||
block = row[idx : idx + 4]
|
||||
block, _ = block.sort()
|
||||
assert (block[:2] == 0).all()
|
||||
assert (block[2:] != 0).all()
|
||||
@ -262,11 +292,11 @@ class TestWeightNormSparsifier(TestCase):
|
||||
sparsifier = WeightNormSparsifier()
|
||||
sparsifier.prepare(model, config=None)
|
||||
for g in sparsifier.groups:
|
||||
module = g['module']
|
||||
module = g["module"]
|
||||
# Check mask exists
|
||||
assert hasattr(module.parametrizations['weight'][0], 'mask')
|
||||
assert hasattr(module.parametrizations["weight"][0], "mask")
|
||||
# Check parametrization exists and is correct
|
||||
assert is_parametrized(module, 'weight')
|
||||
assert is_parametrized(module, "weight")
|
||||
assert type(module.parametrizations.weight[0]) == FakeSparsity
|
||||
|
||||
def test_mask_squash(self):
|
||||
@ -275,39 +305,39 @@ class TestWeightNormSparsifier(TestCase):
|
||||
sparsifier.prepare(model, config=None)
|
||||
sparsifier.squash_mask()
|
||||
for g in sparsifier.groups:
|
||||
module = g['module']
|
||||
assert not is_parametrized(module, 'weight')
|
||||
assert not hasattr(module, 'mask')
|
||||
module = g["module"]
|
||||
assert not is_parametrized(module, "weight")
|
||||
assert not hasattr(module, "mask")
|
||||
|
||||
def test_sparsity_levels(self):
|
||||
sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
|
||||
sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
|
||||
zeros_per_blocks = [0, 1, 2, 3, 4]
|
||||
|
||||
testcases = itertools.tee(itertools.product(sparsity_levels,
|
||||
sparse_block_shapes,
|
||||
zeros_per_blocks))
|
||||
testcases = itertools.tee(
|
||||
itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)
|
||||
)
|
||||
# Create a config and model with all the testcases
|
||||
model = nn.Sequential()
|
||||
sparsifier = WeightNormSparsifier()
|
||||
|
||||
sparsity_per_layer_config = []
|
||||
p = re.compile(r'[-\.\s]')
|
||||
p = re.compile(r"[-\.\s]")
|
||||
for sl, sbs, zpb in testcases[0]:
|
||||
# Make sure the number of zeros is not > values in a block
|
||||
if zpb > sbs[0] * sbs[1]:
|
||||
continue
|
||||
layer_name = f'{sl}_{sbs}_{zpb}'
|
||||
layer_name = p.sub('_', layer_name)
|
||||
layer_name = f"{sl}_{sbs}_{zpb}"
|
||||
layer_name = p.sub("_", layer_name)
|
||||
|
||||
layer = nn.Linear(12, 12, bias=False)
|
||||
layer.weight = nn.Parameter(torch.ones(12, 12))
|
||||
model.add_module(layer_name, layer)
|
||||
config = {
|
||||
'tensor_fqn': layer_name + ".weight",
|
||||
'sparsity_level': sl,
|
||||
'sparse_block_shape': sbs,
|
||||
'zeros_per_block': zpb
|
||||
"tensor_fqn": layer_name + ".weight",
|
||||
"sparsity_level": sl,
|
||||
"sparse_block_shape": sbs,
|
||||
"zeros_per_block": zpb,
|
||||
}
|
||||
sparsity_per_layer_config.append(config)
|
||||
|
||||
@ -319,8 +349,8 @@ class TestWeightNormSparsifier(TestCase):
|
||||
for sl, sbs, zpb in testcases[1]:
|
||||
if zpb > sbs[0] * sbs[1]:
|
||||
continue
|
||||
layer_name = f'{sl}_{sbs}_{zpb}'
|
||||
layer_name = p.sub('_', layer_name)
|
||||
layer_name = f"{sl}_{sbs}_{zpb}"
|
||||
layer_name = p.sub("_", layer_name)
|
||||
layer = getattr(model, layer_name)
|
||||
|
||||
# Level of sparsity is achieved
|
||||
@ -340,30 +370,34 @@ class TestNearlyDiagonalSparsifier(TestCase):
|
||||
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
|
||||
sparsifier.prepare(model, config=None)
|
||||
for g in sparsifier.groups:
|
||||
assert isinstance(g['module'], nn.Linear)
|
||||
assert isinstance(g["module"], nn.Linear)
|
||||
# The groups are unordered
|
||||
assert g['module_fqn'] in ('seq.0', 'seq.1', 'seq.2', 'linear1', 'linear2')
|
||||
assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")
|
||||
|
||||
def test_step(self):
|
||||
model = SimpleLinear()
|
||||
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
|
||||
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear1.weight'}])
|
||||
sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
|
||||
|
||||
for g in sparsifier.groups:
|
||||
# Before step
|
||||
module = g['module']
|
||||
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) == 0 # checking sparsity level is 0
|
||||
module = g["module"]
|
||||
assert (
|
||||
1.0 - module.parametrizations["weight"][0].mask.mean()
|
||||
) == 0 # checking sparsity level is 0
|
||||
|
||||
sparsifier.enable_mask_update = True
|
||||
sparsifier.step()
|
||||
mask = module.parametrizations['weight'][0].mask
|
||||
mask = module.parametrizations["weight"][0].mask
|
||||
height, width = mask.shape
|
||||
assert torch.all(mask == torch.eye(height, width))
|
||||
|
||||
for g in sparsifier.groups:
|
||||
# After step
|
||||
module = g['module']
|
||||
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level has increased
|
||||
module = g["module"]
|
||||
assert (
|
||||
1.0 - module.parametrizations["weight"][0].mask.mean()
|
||||
) > 0 # checking sparsity level has increased
|
||||
|
||||
# Test if the mask collapses to all zeros if the weights are randomized
|
||||
iters_before_collapse = 1000
|
||||
@ -372,19 +406,21 @@ class TestNearlyDiagonalSparsifier(TestCase):
|
||||
sparsifier.step()
|
||||
for g in sparsifier.groups:
|
||||
# After step
|
||||
module = g['module']
|
||||
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level did not collapse
|
||||
module = g["module"]
|
||||
assert (
|
||||
1.0 - module.parametrizations["weight"][0].mask.mean()
|
||||
) > 0 # checking sparsity level did not collapse
|
||||
|
||||
def test_prepare(self):
|
||||
model = SimpleLinear()
|
||||
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
|
||||
sparsifier.prepare(model, config=None)
|
||||
for g in sparsifier.groups:
|
||||
module = g['module']
|
||||
module = g["module"]
|
||||
# Check mask exists
|
||||
assert hasattr(module.parametrizations['weight'][0], 'mask')
|
||||
assert hasattr(module.parametrizations["weight"][0], "mask")
|
||||
# Check parametrization exists and is correct
|
||||
assert is_parametrized(module, 'weight')
|
||||
assert is_parametrized(module, "weight")
|
||||
assert type(module.parametrizations.weight[0]) == FakeSparsity
|
||||
|
||||
def test_mask_squash(self):
|
||||
@ -394,36 +430,36 @@ class TestNearlyDiagonalSparsifier(TestCase):
|
||||
sparsifier.step()
|
||||
sparsifier.squash_mask()
|
||||
for g in sparsifier.groups:
|
||||
module = g['module']
|
||||
assert not is_parametrized(module, 'weight')
|
||||
assert not hasattr(module, 'mask')
|
||||
module = g["module"]
|
||||
assert not is_parametrized(module, "weight")
|
||||
assert not hasattr(module, "mask")
|
||||
weights = module.weight
|
||||
height, width = weights.shape
|
||||
assert torch.all(weights == torch.eye(height, width) * weights) # only diagonal to be present
|
||||
|
||||
assert torch.all(
|
||||
weights == torch.eye(height, width) * weights
|
||||
) # only diagonal to be present
|
||||
|
||||
def test_sparsity_levels(self):
|
||||
nearliness_levels = list(range(-1, 100))
|
||||
model = nn.Sequential()
|
||||
|
||||
p = re.compile(r'[-\.\s]')
|
||||
p = re.compile(r"[-\.\s]")
|
||||
for nearliness in nearliness_levels:
|
||||
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
|
||||
layer_name = f'{nearliness}'
|
||||
layer_name = p.sub('_', layer_name)
|
||||
layer_name = f"{nearliness}"
|
||||
layer_name = p.sub("_", layer_name)
|
||||
|
||||
layer = nn.Linear(32, 32, bias=False)
|
||||
layer.weight = nn.Parameter(torch.ones(32, 32))
|
||||
width, height = layer.weight.shape
|
||||
model.add_module(layer_name, layer)
|
||||
config = {
|
||||
'tensor_fqn': layer_name + ".weight",
|
||||
'nearliness': nearliness
|
||||
}
|
||||
config = {"tensor_fqn": layer_name + ".weight", "nearliness": nearliness}
|
||||
|
||||
sparsifier.prepare(model, [config])
|
||||
# should raise a ValueError when nearliness arg is illegal
|
||||
if (nearliness > 0 and nearliness % 2 == 0) or (nearliness // 2 >= min(width, height)):
|
||||
if (nearliness > 0 and nearliness % 2 == 0) or (
|
||||
nearliness // 2 >= min(width, height)
|
||||
):
|
||||
with self.assertRaises(ValueError):
|
||||
sparsifier.step()
|
||||
else:
|
||||
|
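Aside: every hunk in this file is the same mechanical rewrite. A minimal sketch of the pattern, with both the before and after forms taken verbatim from the hunks above (assuming black's default 88-column line limit, as applied here by ufmt):

# Before: one long single-quoted call over the line-length limit
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}, {'tensor_fqn': 'seq.0.weight'}])

# After: double quotes, arguments wrapped onto an indented continuation line
sparsifier.prepare(
    model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
)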
@ -6,32 +6,32 @@ import random

import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import (
    SaliencyPruner,
    LSTMSaliencyPruner,
    BaseStructuredSparsifier,
    FakeStructuredSparsity,
    FPGMPruner
    FPGMPruner,
    LSTMSaliencyPruner,
    SaliencyPruner,
)
from torch.nn.utils import parametrize

from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
from torch.testing._internal.common_pruning import (
    SimpleLinear,
    LinearBias,
    LinearActivation,
    LinearActivationFunctional,
    SimpleConv2d,
    Conv2dBias,
    Conv2dActivation,
    Conv2dBias,
    Conv2dPadBias,
    Conv2dPool,
    Conv2dPoolFlatten,
    Conv2dPoolFlattenFunctional,
    LSTMLinearModel,
    LinearActivation,
    LinearActivationFunctional,
    LinearBias,
    LSTMLayerNormLinearModel,
    LSTMLinearModel,
    rows_are_subset,
    SimpleConv2d,
    SimpleLinear,
)

from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@ -73,6 +73,7 @@ class BottomHalfLSTMPruner(BaseStructuredSparsifier):
                new_mask = torch.cat(masks)
                mask.data = new_mask.data


class TestSaliencyPruner(TestCase):
    def test_saliency_pruner_update_mask(self):
        """Test that we prune out the row with the lowest saliency (first row)"""
@ -103,14 +104,9 @@ class TestSaliencyPruner(TestCase):
            num_layers=1,
        )

        manual_weights = torch.Tensor([[1, 1],
                                       [2, 2],
                                       [2, 2],
                                       [1, 1],
                                       [-1, -1],
                                       [-2, -2],
                                       [-2, -2],
                                       [-1, -1]])
        manual_weights = torch.Tensor(
            [[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]]
        )

        with torch.no_grad():
            model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
@ -137,18 +133,12 @@ class TestSaliencyPruner(TestCase):
        pruned_model(lstm_input)

        # make sure lowest saliency rows are pruned
        expected = torch.Tensor([[2, 2],
                                 [2, 2],
                                 [-2, -2],
                                 [-2, -2]])
        expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]])
        pruned = model.lstm.weight_ih_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([[2],
                                 [2],
                                 [-2],
                                 [-2]])
        expected = torch.Tensor([[2], [2], [-2], [-2]])
        pruned = model.lstm.weight_hh_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
@ -159,7 +149,6 @@ class TestSaliencyPruner(TestCase):
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()



class TestBaseStructuredSparsifier(TestCase):
    def _check_pruner_prepared(self, model, pruner, device):
        for config in pruner.groups:
@ -916,15 +905,19 @@ class TestBaseStructuredSparsifier(TestCase):
        # linear columns correctly.
        assert out_expected.shape == out_pruned.shape


class TestFPGMPruner(TestCase):
    """
    Test case for the implementation of paper:
    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
    """

    class SimpleConvFPGM(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv2d1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False)
            self.conv2d1 = nn.Conv2d(
                in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
            )
            # Manually set the filter weights for demonstration purposes
            """
            Three filters' weight are manually set to values 3.0, 2.0, and 0.1.
@ -933,13 +926,19 @@ class TestFPGMPruner(TestCase):
            """
            weights = torch.tensor([3.0, 2.0, 0.1])  # Weight weights for each filter
            weights = weights[:, None, None, None]  # broadcasting
            self.conv2d1.weight.data.copy_(torch.ones(self.conv2d1.weight.shape) * weights)
            self.conv2d1.weight.data.copy_(
                torch.ones(self.conv2d1.weight.shape) * weights
            )

            # Second Convolutional Layer
            self.conv2d2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False)
            self.conv2d2 = nn.Conv2d(
                in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False
            )
            weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
            weights = weights[:, None, None, None]
            self.conv2d2.weight.data.copy_(torch.ones(self.conv2d2.weight.shape) * weights)
            self.conv2d2.weight.data.copy_(
                torch.ones(self.conv2d2.weight.shape) * weights
            )

        def forward(self, x):
            x = self.conv2d1(x)
@ -953,11 +952,43 @@ class TestFPGMPruner(TestCase):
        dist_conv1 = pruner._compute_distance(model.conv2d1.weight)

        # compute the distance matrix using torch.cdist
        flattened_filters = torch.Tensor([
            [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],
            [2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
            [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000]
        ])
        flattened_filters = torch.Tensor(
            [
                [
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                ],
                [
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                ],
                [
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                ],
            ]
        )

        """
        Expected distance matrix should have the following values:
@ -967,9 +998,13 @@ class TestFPGMPruner(TestCase):
        the distance should therefore be:
        [11.7000, 8.7000, 14.4000]
        """
        expected_dist_matrix_conv1 = torch.cdist(flattened_filters, flattened_filters, p=2)
        expected_dist_matrix_conv1 = torch.cdist(
            flattened_filters, flattened_filters, p=2
        )
        expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
        assert torch.isclose(dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07).all()
        assert torch.isclose(
            dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_single_layer(self, expected_conv1, device):
        """Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value"""
@ -981,8 +1016,10 @@ class TestFPGMPruner(TestCase):
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False, \
            "do not prune the least-norm filter"
        assert (
            pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item()
            is not False
        ), "do not prune the least-norm filter"

        # fusion step
        pruned_model = pruner.prune()
@ -992,27 +1029,38 @@ class TestFPGMPruner(TestCase):
        expected_conv1 = expected_conv1.to(device)
        assert pruned_y.shape == (1, 4, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == (4, 2, 3, 3), "conv2d2 should have input channel pruned"
        assert pruned_model.conv2d2.weight.shape == (
            4,
            2,
            3,
            3,
        ), "conv2d2 should have input channel pruned"
        # assert value
        assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_multiple_layer(self, expected_conv1, expected_conv2, device):
    def _test_update_mask_on_multiple_layer(
        self, expected_conv1, expected_conv2, device
    ):
        # the second setting
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5}
            {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
        ]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        # Get the masks for the two least-norm filters
        mask1 = pruner.groups[0]['module'].parametrizations.weight[0].mask[-1]
        mask2 = pruner.groups[0]['module'].parametrizations.weight[0].mask[-2]
        mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
        mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
        # Check if either of the least-norm filters is not pruned
        assert mask1.item() is not False or mask2.item() is not False, "Do not prune all least-norm filters"
        assert (
            mask1.item() is not False or mask2.item() is not False
        ), "Do not prune all least-norm filters"

        # fusion step
        pruned_model = pruner.prune()
@ -1024,8 +1072,12 @@ class TestFPGMPruner(TestCase):
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
        # assert values
        assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()
        assert torch.isclose(pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07).all()
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()
        assert torch.isclose(
            pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07
        ).all()

    def test_update_mask(self):
        weights = torch.tensor([3.0, 0.1])
@ -1036,4 +1088,6 @@ class TestFPGMPruner(TestCase):

        for device in DEVICES:
            self._test_update_mask_on_single_layer(expected_conv1, device)
            self._test_update_mask_on_multiple_layer(expected_conv1, expected_conv2, device)
            self._test_update_mask_on_multiple_layer(
                expected_conv1, expected_conv2, device
            )
@ -2,7 +2,7 @@

import torch

from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck
from torch.testing._internal.common_utils import gradcheck, run_tests, TestCase


class TestAutogradComplex(TestCase):
@ -71,7 +71,9 @@ class TestAutogradComplex(TestCase):
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
        with self.assertRaisesRegex(
            RuntimeError, "output of a function that returns multiple views"
        ):
            res[0] += torch.rand(2, requires_grad=True)

        x.requires_grad_(True)
@ -80,7 +82,9 @@ class TestAutogradComplex(TestCase):
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
        with self.assertRaisesRegex(
            RuntimeError, "output of a function that returns multiple views"
        ):
            res[0] += torch.rand(2, requires_grad=True)

    def as_identity(self):
@ -101,5 +105,5 @@ class TestAutogradComplex(TestCase):
        self.assertEqual(z.grad, z1.grad)


if __name__ == '__main__':
if __name__ == "__main__":
    run_tests()
File diff suppressed because it is too large
@ -1,11 +1,13 @@
# Owner(s): ["module: intel"]

from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX
import shutil
import subprocess
import tempfile
import unittest

from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase


@unittest.skipIf(not IS_LINUX, "Only works on linux")
class TestTorchrun(TestCase):
    def setUp(self):
@ -37,6 +39,7 @@ class TestTorchrun(TestCase):
15,7,1,1
"""
        from torch.backends.xeon.run_cpu import _CPUinfo

        cpuinfo = _CPUinfo(lscpu_info)
        assert cpuinfo._physical_core_nums() == 8
        assert cpuinfo._logical_core_nums() == 16
@ -45,21 +48,43 @@ class TestTorchrun(TestCase):
        assert cpuinfo.get_node_logical_cores(0) == [0, 1, 2, 3, 8, 9, 10, 11]
        assert cpuinfo.get_node_logical_cores(1) == [4, 5, 6, 7, 12, 13, 14, 15]
        assert cpuinfo.get_all_physical_cores() == [0, 1, 2, 3, 4, 5, 6, 7]
        assert cpuinfo.get_all_logical_cores() == [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]
        assert cpuinfo.get_all_logical_cores() == [
            0,
            1,
            2,
            3,
            8,
            9,
            10,
            11,
            4,
            5,
            6,
            7,
            12,
            13,
            14,
            15,
        ]
        assert cpuinfo.numa_aware_check([0, 1, 2, 3]) == [0]
        assert cpuinfo.numa_aware_check([4, 5, 6, 7]) == [1]
        assert cpuinfo.numa_aware_check([2, 3, 4, 5]) == [0, 1]

    def test_multi_threads(self):
        num = 0
        with subprocess.Popen(f"python -m torch.backends.xeon.run_cpu --ninstances 4 --use-default-allocator \
        with subprocess.Popen(
            f"python -m torch.backends.xeon.run_cpu --ninstances 4 --use-default-allocator \
            --disable-iomp --disable-numactl --disable-taskset --log-path {self._test_dir} --no-python pwd",
            shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        ) as p:
            for line in p.stdout.readlines():
                segs = str(line, "utf-8").strip().split("-")
                if segs[-1].strip() == "pwd":
                    num += 1
        assert num == 4, "Failed to launch multiple instances for inference"


if __name__ == "__main__":
    run_tests()
@ -6,19 +6,26 @@ import os
import re
import textwrap
import timeit
from typing import Any, List, Tuple
import unittest
from typing import Any, List, Tuple

import torch
import torch.utils.benchmark as benchmark_utils
from torch.testing._internal.common_utils import TestCase, run_tests, IS_SANDCASTLE, IS_WINDOWS, slowTest, TEST_WITH_ASAN
import expecttest
import numpy as np

import torch
import torch.utils.benchmark as benchmark_utils
from torch.testing._internal.common_utils import (
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    slowTest,
    TEST_WITH_ASAN,
    TestCase,
)


CALLGRIND_ARTIFACTS: str = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],
    "callgrind_artifacts.json"
    os.path.split(os.path.abspath(__file__))[0], "callgrind_artifacts.json"
)


@ -33,13 +40,13 @@ def generate_callgrind_artifacts() -> None:
    """
    print("Regenerating callgrind artifact.")

    stats_no_data = benchmark_utils.Timer(
        "y = torch.ones(())"
    ).collect_callgrind(number=1000)
    stats_no_data = benchmark_utils.Timer("y = torch.ones(())").collect_callgrind(
        number=1000
    )

    stats_with_data = benchmark_utils.Timer(
        "y = torch.ones((1,))"
    ).collect_callgrind(number=1000)
    stats_with_data = benchmark_utils.Timer("y = torch.ones((1,))").collect_callgrind(
        number=1000
    )

    user = os.getenv("USER")

@ -59,7 +66,9 @@ def generate_callgrind_artifacts() -> None:
        json.dump(artifacts, f, indent=4)


def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]:
def load_callgrind_artifacts() -> (
    Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
):
    """Hermetic artifact to unit test Callgrind wrapper.

    In addition to collecting counts, this wrapper provides some facilities for
@ -76,8 +85,7 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
    pattern = re.compile(r"^\s*([0-9]+)\s(.+)$")

    def to_function_counts(
        count_strings: List[str],
        inclusive: bool
        count_strings: List[str], inclusive: bool
    ) -> benchmark_utils.FunctionCounts:
        data: List[benchmark_utils.FunctionCount] = []
        for cs in count_strings:
@ -89,8 +97,8 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
            data.append(benchmark_utils.FunctionCount(count=int(c), function=fn))

        return benchmark_utils.FunctionCounts(
            tuple(sorted(data, reverse=True)),
            inclusive=inclusive)
            tuple(sorted(data, reverse=True)), inclusive=inclusive
        )

    baseline_inclusive = to_function_counts(artifacts["baseline_inclusive"], True)
    baseline_exclusive = to_function_counts(artifacts["baseline_exclusive"], False)
@ -101,8 +109,12 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
        built_with_debug_symbols=True,
        baseline_inclusive_stats=baseline_inclusive,
        baseline_exclusive_stats=baseline_exclusive,
        stmt_inclusive_stats=to_function_counts(artifacts["ones_no_data_inclusive"], True),
        stmt_exclusive_stats=to_function_counts(artifacts["ones_no_data_exclusive"], False),
        stmt_inclusive_stats=to_function_counts(
            artifacts["ones_no_data_inclusive"], True
        ),
        stmt_exclusive_stats=to_function_counts(
            artifacts["ones_no_data_exclusive"], False
        ),
        stmt_callgrind_out=None,
    )

@ -112,8 +124,12 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
        built_with_debug_symbols=True,
        baseline_inclusive_stats=baseline_inclusive,
        baseline_exclusive_stats=baseline_exclusive,
        stmt_inclusive_stats=to_function_counts(artifacts["ones_with_data_inclusive"], True),
        stmt_exclusive_stats=to_function_counts(artifacts["ones_with_data_exclusive"], False),
        stmt_inclusive_stats=to_function_counts(
            artifacts["ones_with_data_inclusive"], True
        ),
        stmt_exclusive_stats=to_function_counts(
            artifacts["ones_with_data_exclusive"], False
        ),
        stmt_callgrind_out=None,
    )

@ -127,14 +143,12 @@ class MyModule(torch.nn.Module):

class TestBenchmarkUtils(TestCase):
    def regularizeAndAssertExpectedInline(
        self, x: Any,
        expect: str,
        indent: int = 12
        self, x: Any, expect: str, indent: int = 12
    ) -> None:
        x_str: str = re.sub(
            "object at 0x[0-9a-fA-F]+>",
            "object at 0xXXXXXXXXXXXX>",
            x if isinstance(x, str) else repr(x)
            x if isinstance(x, str) else repr(x),
        )
        if "\n" in x_str:
            # Indent makes the reference align at the call site.
@ -157,15 +171,19 @@ class TestBenchmarkUtils(TestCase):
        median = timer.adaptive_autorange(threshold=0.5).median

        # Test that multi-line statements work properly.
        median = benchmark_utils.Timer(
            stmt="""
        median = (
            benchmark_utils.Timer(
                stmt="""
with torch.no_grad():
    y = x + 1""",
            setup="""
                setup="""
x = torch.ones((1,), requires_grad=True)
for _ in range(5):
    x = x + 1.0""",
        ).timeit(5).median
            )
            .timeit(5)
            .median
        )
        self.assertIsInstance(sample, float)

    @slowTest
@ -173,7 +191,7 @@ class TestBenchmarkUtils(TestCase):
    @unittest.skipIf(True, "Failing on clang, see 74398")
    def test_timer_tiny_fast_snippet(self):
        timer = benchmark_utils.Timer(
            'auto x = 1;(void)x;',
            "auto x = 1;(void)x;",
            timer=timeit.default_timer,
            language=benchmark_utils.Language.CPP,
        )
@ -222,16 +240,16 @@ class TestBenchmarkUtils(TestCase):
            return max(self._random_state.normal(mean, mean * noise_level), 5e-9)

        def timeit(self, number):
            return sum([
                # First timer invocation
                self.sample(self._timer_cost, self._timer_noise_level),

                # Stmt body
                self.sample(self._mean_cost * number, self._function_noise_level),

                # Second timer invocation
                self.sample(self._timer_cost, self._timer_noise_level),
            ])
            return sum(
                [
                    # First timer invocation
                    self.sample(self._timer_cost, self._timer_noise_level),
                    # Stmt body
                    self.sample(self._mean_cost * number, self._function_noise_level),
                    # Second timer invocation
                    self.sample(self._timer_cost, self._timer_noise_level),
                ]
            )

    def test_adaptive_timer(self):
        class MockTimer(benchmark_utils.Timer):
@ -245,7 +263,6 @@ class TestBenchmarkUtils(TestCase):
            _function_costs = (
                self._MockTimer._function_costs[0],
                self._MockTimer._function_costs[1],

                # GPU should be faster once there is enough work.
                ("expensive_fn()", 5e-6),
            )
@ -261,7 +278,7 @@ class TestBenchmarkUtils(TestCase):
              pass
            Median: 7.98 ns
            IQR: 0.52 ns (7.74 to 8.26)
            125 measurements, 10000000 runs per measurement, 1 thread"""
            125 measurements, 10000000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -271,7 +288,7 @@ class TestBenchmarkUtils(TestCase):
              pass
            Median: 7.86 ns
            IQR: 0.71 ns (7.63 to 8.34)
            6 measurements, 1000000 runs per measurement, 1 thread"""
            6 measurements, 1000000 runs per measurement, 1 thread""",
        )

        # Check against strings so we can reuse expect infra.
@ -287,7 +304,7 @@ class TestBenchmarkUtils(TestCase):
              cheap_fn()
            Median: 3.98 us
            IQR: 0.27 us (3.85 to 4.12)
            252 measurements, 10000 runs per measurement, 1 thread"""
            252 measurements, 10000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -297,7 +314,7 @@ class TestBenchmarkUtils(TestCase):
              cheap_fn()
            Median: 4.16 us
            IQR: 0.22 us (4.04 to 4.26)
            4 measurements, 1000 runs per measurement, 1 thread"""
            4 measurements, 1000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -307,7 +324,7 @@ class TestBenchmarkUtils(TestCase):
              expensive_fn()
            Median: 19.97 us
            IQR: 1.35 us (19.31 to 20.65)
            501 measurements, 1000 runs per measurement, 1 thread"""
            501 measurements, 1000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -317,7 +334,7 @@ class TestBenchmarkUtils(TestCase):
              expensive_fn()
            Median: 20.79 us
            IQR: 1.09 us (20.20 to 21.29)
            4 measurements, 1000 runs per measurement, 1 thread"""
            4 measurements, 1000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -327,7 +344,7 @@ class TestBenchmarkUtils(TestCase):
              pass
            Median: 7.92 ns
            IQR: 0.43 ns (7.75 to 8.17)
            13 measurements, 100000000 runs per measurement, 1 thread"""
            13 measurements, 100000000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -337,7 +354,7 @@ class TestBenchmarkUtils(TestCase):
              pass
            Median: 7.75 ns
            IQR: 0.57 ns (7.56 to 8.13)
            4 measurements, 10000000 runs per measurement, 1 thread"""
            4 measurements, 10000000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -347,7 +364,7 @@ class TestBenchmarkUtils(TestCase):
              cheap_fn()
            Median: 4.04 us
            IQR: 0.30 us (3.90 to 4.19)
            25 measurements, 100000 runs per measurement, 1 thread"""
            25 measurements, 100000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -357,7 +374,7 @@ class TestBenchmarkUtils(TestCase):
              cheap_fn()
            Median: 4.09 us
            IQR: 0.38 us (3.90 to 4.28)
            4 measurements, 100000 runs per measurement, 1 thread"""
            4 measurements, 100000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -367,7 +384,7 @@ class TestBenchmarkUtils(TestCase):
              expensive_fn()
            Median: 4.98 us
            IQR: 0.31 us (4.83 to 5.13)
            20 measurements, 100000 runs per measurement, 1 thread"""
            20 measurements, 100000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -377,7 +394,7 @@ class TestBenchmarkUtils(TestCase):
              expensive_fn()
            Median: 5.01 us
            IQR: 0.28 us (4.87 to 5.15)
            4 measurements, 10000 runs per measurement, 1 thread"""
            4 measurements, 10000 runs per measurement, 1 thread""",
        )

        # Make sure __repr__ is reasonable for
@ -398,7 +415,7 @@ class TestBenchmarkUtils(TestCase):

            Median: 10.06 us
            IQR: 0.54 us (9.73 to 10.27)
            20 measurements, 1000 runs per measurement, 1 thread"""
            20 measurements, 1000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -411,7 +428,7 @@ class TestBenchmarkUtils(TestCase):

            Median: 10.06 us
            IQR: 0.54 us (9.73 to 10.27)
            20 measurements, 1000 runs per measurement, 1 thread"""
            20 measurements, 1000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -425,7 +442,7 @@ class TestBenchmarkUtils(TestCase):
            x + 1 (no grad): scalar_add
            Median: 10.06 us
            IQR: 0.54 us (9.73 to 10.27)
            20 measurements, 1000 runs per measurement, 1 thread"""
            20 measurements, 1000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -443,7 +460,7 @@ class TestBenchmarkUtils(TestCase):
            setup: setup_fn()
            Median: 10.06 us
            IQR: 0.54 us (9.73 to 10.27)
            20 measurements, 1000 runs per measurement, 1 thread"""
            20 measurements, 1000 runs per measurement, 1 thread""",
        )

        self.regularizeAndAssertExpectedInline(
@ -471,7 +488,7 @@ class TestBenchmarkUtils(TestCase):

            Median: 10.06 us
            IQR: 0.54 us (9.73 to 10.27)
            20 measurements, 1000 runs per measurement, 16 threads"""
            20 measurements, 1000 runs per measurement, 16 threads""",
        )

    @slowTest
@ -482,24 +499,23 @@ class TestBenchmarkUtils(TestCase):
        with self.assertRaisesRegex(
            ValueError,
            r"`collect_callgrind` requires that globals be wrapped "
            r"in `CopyIfCallgrind` so that serialization is explicit."
            r"in `CopyIfCallgrind` so that serialization is explicit.",
        ):
            benchmark_utils.Timer(
                "pass",
                globals={"x": 1}
            ).collect_callgrind(collect_baseline=False)
            benchmark_utils.Timer("pass", globals={"x": 1}).collect_callgrind(
                collect_baseline=False
            )

        with self.assertRaisesRegex(
            # Subprocess raises AttributeError (from pickle),
            # _ValgrindWrapper re-raises as generic OSError.
            OSError, "AttributeError: Can't get attribute 'MyModule'"
            OSError,
            "AttributeError: Can't get attribute 'MyModule'",
        ):
            benchmark_utils.Timer(
                "model(1)",
                globals={"model": benchmark_utils.CopyIfCallgrind(MyModule())}
                globals={"model": benchmark_utils.CopyIfCallgrind(MyModule())},
            ).collect_callgrind(collect_baseline=False)


@torch.jit.script
def add_one(x):
    return x + 1
@ -516,9 +532,9 @@ class TestBenchmarkUtils(TestCase):
                    import sys
                    sys.path.append({repr(os.path.split(os.path.abspath(__file__))[0])})
                    from test_benchmark_utils import MyModule
                    """
                )
            }
                    """,
                ),
            },
        )

        stats = timer.collect_callgrind(number=1000)
@ -538,13 +554,22 @@ class TestBenchmarkUtils(TestCase):
        assert isinstance(stats, tuple)

        # Check that the repeats are at least somewhat repeatable. (within 10 instructions per iter)
        counts = collections.Counter([s.counts(denoise=True) // 10_000 * 10_000 for s in stats])
        self.assertGreater(max(counts.values()), 1, f"Every instruction count total was unique: {counts}")
        counts = collections.Counter(
            [s.counts(denoise=True) // 10_000 * 10_000 for s in stats]
        )
        self.assertGreater(
            max(counts.values()),
            1,
            f"Every instruction count total was unique: {counts}",
        )

        from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import (
            wrapper_singleton,
        )

        from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import wrapper_singleton
        self.assertIsNone(
            wrapper_singleton()._bindings_module,
            "JIT'd bindings are only for back testing."
            "JIT'd bindings are only for back testing.",
        )

    @slowTest
@ -558,28 +583,29 @@ class TestBenchmarkUtils(TestCase):
            timer=timeit.default_timer,
            language="c++",
        )
        stats = [
            timer.collect_callgrind()
            for _ in range(3)
        ]
        stats = [timer.collect_callgrind() for _ in range(3)]
        counts = [s.counts() for s in stats]

        self.assertGreater(
            min(counts), 0, "No stats were collected")
        self.assertGreater(min(counts), 0, "No stats were collected")
        self.assertEqual(
            min(counts), max(counts), "C++ Callgrind should be deterministic")
            min(counts), max(counts), "C++ Callgrind should be deterministic"
        )

        for s in stats:
            self.assertEqual(
                s.counts(denoise=True), s.counts(denoise=False),
                "De-noising should not apply to C++.")
                s.counts(denoise=True),
                s.counts(denoise=False),
                "De-noising should not apply to C++.",
            )

        stats = timer.collect_callgrind(number=1000, repeats=20)
        assert isinstance(stats, tuple)

        # NB: Unlike the example above, there is no expectation that all
        # repeats will be identical.
        counts = collections.Counter([s.counts(denoise=True) // 10_000 * 10_000 for s in stats])
        counts = collections.Counter(
            [s.counts(denoise=True) // 10_000 * 10_000 for s in stats]
        )
        self.assertGreater(max(counts.values()), 1, repr(counts))

    def test_manipulate_callgrind_stats(self):
@ -587,7 +613,8 @@ class TestBenchmarkUtils(TestCase):

        # Mock `torch.set_printoptions(linewidth=160)`
        wide_linewidth = benchmark_utils.FunctionCounts(
            stats_no_data.stats(inclusive=False)._data, False, _linewidth=160)
            stats_no_data.stats(inclusive=False)._data, False, _linewidth=160
        )

        for l in repr(wide_linewidth).splitlines(keepends=False):
            self.assertLessEqual(len(l), 160)
@ -595,10 +622,12 @@ class TestBenchmarkUtils(TestCase):
        self.assertEqual(
            # `delta` is just a convenience method.
            stats_with_data.delta(stats_no_data)._data,
            (stats_with_data.stats() - stats_no_data.stats())._data
            (stats_with_data.stats() - stats_no_data.stats())._data,
        )

        deltas = stats_with_data.as_standardized().delta(stats_no_data.as_standardized())
        deltas = stats_with_data.as_standardized().delta(
            stats_no_data.as_standardized()
        )

        def custom_transforms(fn: str):
            fn = re.sub(re.escape("/usr/include/c++/8/bits/"), "", fn)
@ -703,7 +732,7 @@ class TestBenchmarkUtils(TestCase):
            2000 /usr/include/c++/8/bits/atomic_base.h:at::Tensor at::detail::make_tensor ... t_null_type<c10::StorageImpl> >&&, c10::DispatchKey&&, caffe2::TypeMeta&)
            2000 /usr/include/c++/8/array:at::Tensor& c10::Dispatcher::callWithDispatchKe ... , c10::Scalar)> const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const

            Total: 8869966"""  # noqa: B950
            Total: 8869966""",  # noqa: B950
        )

        self.regularizeAndAssertExpectedInline(
@ -821,10 +850,8 @@ class TestBenchmarkUtils(TestCase):
        costs = (
            # overhead_optimized_fn()
            (1e-6, 1e-9),

            # compute_optimized_fn()
            (3e-6, 5e-10),

            # special_case_fn()  [square inputs only]
            (1e-6, 4e-10),
        )
@ -840,8 +867,7 @@ class TestBenchmarkUtils(TestCase):
        # overhead_optimized_fn()
        class _MockTimer_0(self._MockTimer):
            _function_costs = tuple(
                (f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j)
                for i, j in sizes
                (f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j) for i, j in sizes
            )

        class MockTimer_0(benchmark_utils.Timer):
@ -850,8 +876,7 @@ class TestBenchmarkUtils(TestCase):
        # compute_optimized_fn()
        class _MockTimer_1(self._MockTimer):
            _function_costs = tuple(
                (f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j)
                for i, j in sizes
                (f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j) for i, j in sizes
            )

        class MockTimer_1(benchmark_utils.Timer):
@ -861,7 +886,8 @@ class TestBenchmarkUtils(TestCase):
        class _MockTimer_2(self._MockTimer):
            _function_costs = tuple(
                (f"fn({i}, {j})", costs[2][0] + costs[2][1] * i * j)
                for i, j in sizes if i == j
                for i, j in sizes
                if i == j
            )

        class MockTimer_2(benchmark_utils.Timer):
@ -913,7 +939,7 @@ class TestBenchmarkUtils(TestCase):
            compute_optimized      |    3.1  |    4.0  |   11.2  |  2099.3  |       2099.3
            special_case (square)  |    1.1  |         |    7.5  |          |       1674.7

            Times are in microseconds (us)."""
            Times are in microseconds (us).""",
        )

        compare.trim_significant_figures()
@ -927,7 +953,7 @@ class TestBenchmarkUtils(TestCase):
            compute_optimized      |    3    |    4.0  |   11    |  2100    |         2100
            special_case (square)  |    1    |         |    8    |          |         1700

            Times are in microseconds (us)."""
            Times are in microseconds (us).""",
        )

        compare.colorize()
@ -978,12 +1004,16 @@ class TestBenchmarkUtils(TestCase):
        self.assertEqual(columnwise_colored_actual, columnwise_colored_expected)
        self.assertEqual(rowwise_colored_actual, rowwise_colored_expected)

    @unittest.skipIf(IS_WINDOWS and os.getenv("VC_YEAR") == "2019", "Random seed only accepts int32")
    @unittest.skipIf(
        IS_WINDOWS and os.getenv("VC_YEAR") == "2019", "Random seed only accepts int32"
    )
    def test_fuzzer(self):
        fuzzer = benchmark_utils.Fuzzer(
            parameters=[
                benchmark_utils.FuzzedParameter(
                    "n", minval=1, maxval=16, distribution="loguniform")],
                    "n", minval=1, maxval=16, distribution="loguniform"
                )
            ],
            tensors=[benchmark_utils.FuzzedTensor("x", size=("n",))],
            seed=0,
        )
@ -995,9 +1025,8 @@ class TestBenchmarkUtils(TestCase):

        for i, (tensors, _, _) in enumerate(fuzzer.take(2)):
            x = tensors["x"]
            self.assertEqual(
                x, torch.tensor(expected_results[i]), rtol=1e-3, atol=1e-3)
            self.assertEqual(x, torch.tensor(expected_results[i]), rtol=1e-3, atol=1e-3)


if __name__ == '__main__':
if __name__ == "__main__":
    run_tests()
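The import reshuffling in the hunk above is the sorter half of the same toolchain. A sketch of the resulting order, taken verbatim from the hunk above; the ordering appears to be case-insensitive alphabetical within the parenthesized group, which is why run_tests, slowTest, TEST_WITH_ASAN, and TestCase interleave:

from torch.testing._internal.common_utils import (
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    slowTest,
    TEST_WITH_ASAN,
    TestCase,
)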
@ -1,14 +1,15 @@
# Owner(s): ["module: unknown"]

import argparse

import torch

if __name__ == '__main__':
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required args. Raises error if they aren't passed.
    parser.add_argument('--foo', help='foo', required=True)
    parser.add_argument('--bar', help='bar', required=True)
    parser.add_argument("--foo", help="foo", required=True)
    parser.add_argument("--bar", help="bar", required=True)
    _ = parser.parse_args()

    x = torch.ones((3, 3), requires_grad=True)
@ -25,5 +25,5 @@ def main():
        optimizer.step()


if __name__ == '__main__':
if __name__ == "__main__":
    main()
@ -1,25 +1,24 @@
from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
from _pytest.terminal import _get_raw_skip_reason
from _pytest.stash import StashKey
from _pytest.reports import TestReport
from _pytest.config.argparsing import Parser
from _pytest.config import filename_arg
from _pytest.config import Config
from _pytest._code.code import ReprFileLocation
from _pytest.python import Module
from typing import Any, List, Union
from typing import Optional
from types import MethodType
import xml.etree.ElementTree as ET
import functools
import pytest
import sys
import os
import copy
import functools
import json
import os
import re
import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from pytest_shard_custom import PytestShardPlugin, pytest_addoptions as shard_addoptions
from types import MethodType
from typing import Any, List, Optional, Union

import pytest
from _pytest._code.code import ReprFileLocation
from _pytest.config import Config, filename_arg
from _pytest.config.argparsing import Parser
from _pytest.junitxml import _NodeReporter, bin_xml_escape, LogXML
from _pytest.python import Module
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from _pytest.terminal import _get_raw_skip_reason
from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin

# a lot of this file is copied from _pytest.junitxml and modified to get rerun info

@ -42,7 +41,7 @@ def pytest_addoption(parser: Parser) -> None:
        dest="stepcurrent",
    )

    parser.addoption("--use-main-module", action='store_true')
    parser.addoption("--use-main-module", action="store_true")
    group = parser.getgroup("terminal reporting")
    group.addoption(
        "--junit-xml-reruns",
@ -143,11 +142,14 @@ class _NodeReporterReruns(_NodeReporter):
            skipreason = skipreason[9:]
        details = f"{filename}:{lineno}: {skipreason}"

        skipped = ET.Element("skipped", type="pytest.skip", message=bin_xml_escape(skipreason))
        skipped = ET.Element(
            "skipped", type="pytest.skip", message=bin_xml_escape(skipreason)
        )
        skipped.text = bin_xml_escape(details)
        self.append(skipped)
        self.write_captured_output(report)


class LogXMLReruns(LogXML):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@ -223,7 +225,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
def pytest_pycollect_makemodule(module_path, path, parent) -> Module:
    if parent.config.getoption("--use-main-module"):
        mod = Module.from_parent(parent, path=module_path)
        mod._getobj = MethodType(lambda x: sys.modules['__main__'], mod)
        mod._getobj = MethodType(lambda x: sys.modules["__main__"], mod)
        return mod


@ -275,7 +277,10 @@ def pytest_collection_modifyitems(items: List[Any]) -> None:
        test_name = item.name
        test_class = item.parent.name

        if test_class not in disabled_tests or test_name not in disabled_tests[test_class]:
        if (
            test_class not in disabled_tests
            or test_name not in disabled_tests[test_class]
        ):
            continue

        cpy = copy.copy(item)
@ -1,10 +1,10 @@

import torch
from torch._export import aot_compile
from torch.export import Dim

torch.manual_seed(1337)


class Net(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
@ -17,6 +17,7 @@ class Net(torch.nn.Module):
        w = w_relu + self.w_add
        return torch.matmul(x, w)


class NetWithTensorConstants(torch.nn.Module):
    def __init__(self):
        super().__init__()
@ -26,9 +27,11 @@ class NetWithTensorConstants(torch.nn.Module):
        z = self.w * x * y
        return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]


data = {}
data_with_tensor_constants = {}


# Basic AOTI model test generation.
def generate_basic_tests():
    for device in ["cpu", "cuda"]:
@ -49,18 +52,24 @@ def generate_basic_tests():
            model,
            (x,),
            dynamic_shapes=dynamic_shapes,
            options={"aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding})
            options={
                "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
            },
        )

        suffix = f"{device}"
        if use_runtime_constant_folding:
            suffix += "_use_runtime_constant_folding"
        data.update({
            f"model_so_path_{suffix}": model_so_path,
            f"inputs_{suffix}": [x],
            f"outputs_{suffix}": [ref_output],
            f"w_pre_{suffix}": model.w_pre,
            f"w_add_{suffix}": model.w_add,
        })
        data.update(
            {
                f"model_so_path_{suffix}": model_so_path,
                f"inputs_{suffix}": [x],
                f"outputs_{suffix}": [ref_output],
                f"w_pre_{suffix}": model.w_pre,
                f"w_add_{suffix}": model.w_add,
            }
        )


# AOTI model which will create additional tensors during autograd.
def generate_test_with_additional_tensors():
@ -72,20 +81,22 @@ def generate_test_with_additional_tensors():

    torch._dynamo.reset()
    with torch.no_grad():
        model_so_path = aot_compile(
            model,
            (x, y))
        model_so_path = aot_compile(model, (x, y))

    data_with_tensor_constants.update(
        {
            "model_so_path": model_so_path,
            "inputs": [x, y],
            "outputs": [ref_output],
            "w": model.w,
        }
    )

    data_with_tensor_constants.update({
        "model_so_path": model_so_path,
        "inputs": [x, y],
        "outputs": [ref_output],
        "w": model.w,
    })

generate_basic_tests()
generate_test_with_additional_tensors()


# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
    def __init__(self, data):
@ -93,5 +104,8 @@ class Serializer(torch.nn.Module):
        for key in data:
            setattr(self, key, data[key])


torch.jit.script(Serializer(data)).save("data.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save("data_with_tensor_constants.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save(
    "data_with_tensor_constants.pt"
)
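For context on the code being reformatted: this generator ahead-of-time compiles each model to a shared library with torch._export.aot_compile and saves the inputs/outputs for the C++ test runner to replay. A minimal sketch of the same export pattern, assuming a PyTorch build of this vintage where torch._export.aot_compile is available (TinyNet and dim0 are illustrative stand-ins, not names from the diff):

    import torch
    from torch._export import aot_compile
    from torch.export import Dim


    class TinyNet(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) @ x


    x = torch.randn(8, 8)
    dim0 = Dim("dim0", min=2, max=1024)  # mark the leading dimension as dynamic
    so_path = aot_compile(
        TinyNet(),
        (x,),
        dynamic_shapes={"x": {0: dim0}},
        options={"aot_inductor.use_runtime_constant_folding": False},
    )
    print(so_path)  # path to the compiled shared library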
@ -1,6 +1,7 @@
"""Script to generate baseline values from PyTorch initialization algorithms"""

import sys

import torch

HEADER = """
@ -19,13 +20,13 @@ INITIALIZERS = {
    "Xavier_Uniform": lambda w: torch.nn.init.xavier_uniform(w),
    "Xavier_Normal": lambda w: torch.nn.init.xavier_normal(w),
    "Kaiming_Normal": lambda w: torch.nn.init.kaiming_normal(w),
    "Kaiming_Uniform": lambda w: torch.nn.init.kaiming_uniform(w)
    "Kaiming_Uniform": lambda w: torch.nn.init.kaiming_uniform(w),
}


def emit(initializer_parameter_map):
    # Don't write generated with an @ in front, else this file is recognized as generated.
    print("// @{} from {}".format('generated', __file__))
    print("// @{} from {}".format("generated", __file__))
    print(HEADER)
    for initializer_name, weights in initializer_parameter_map.items():
        print(PARAMETERS.format(initializer_name))
@ -63,10 +64,11 @@ def run(initializer):
def main():
    initializer_parameter_map = {}
    for initializer in INITIALIZERS.keys():
        sys.stderr.write(f'Evaluating {initializer} ...\n')
        sys.stderr.write(f"Evaluating {initializer} ...\n")
        initializer_parameter_map[initializer] = run(initializer)

    emit(initializer_parameter_map)


if __name__ == "__main__":
    main()
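One note on the INITIALIZERS table itself: it uses the legacy non-in-place init names (torch.nn.init.xavier_uniform and friends), which newer PyTorch releases deprecate in favor of the trailing-underscore, in-place variants. A short sketch of the modern equivalents (illustrative only; the baseline script keeps the old names, presumably so the generated C++ baselines stay stable):

    import torch

    w = torch.empty(3, 5)
    torch.nn.init.xavier_uniform_(w)  # in-place; returns w for chaining
    torch.nn.init.xavier_normal_(w)
    torch.nn.init.kaiming_normal_(w)
    torch.nn.init.kaiming_uniform_(w)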
@ -21,27 +21,43 @@ FOOTER = "} // namespace expected_parameters"
PARAMETERS = "inline std::vector<std::vector<torch::Tensor>> {}() {{"

OPTIMIZERS = {
    "LBFGS" : lambda p: torch.optim.LBFGS(p, 1.0),
    "LBFGS_with_line_search" : lambda p: torch.optim.LBFGS(p, 1.0, line_search_fn="strong_wolfe"),
    "LBFGS": lambda p: torch.optim.LBFGS(p, 1.0),
    "LBFGS_with_line_search": lambda p: torch.optim.LBFGS(
        p, 1.0, line_search_fn="strong_wolfe"
    ),
    "Adam": lambda p: torch.optim.Adam(p, 1.0),
    "Adam_with_weight_decay": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-2),
    "Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-6, amsgrad=True),
    "Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(
        p, 1.0, weight_decay=1e-6, amsgrad=True
    ),
    "AdamW": lambda p: torch.optim.AdamW(p, 1.0),
    "AdamW_without_weight_decay": lambda p: torch.optim.AdamW(p, 1.0, weight_decay=0),
    "AdamW_with_amsgrad": lambda p: torch.optim.AdamW(p, 1.0, amsgrad=True),
    "Adagrad": lambda p: torch.optim.Adagrad(p, 1.0),
    "Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-2),
    "Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-6, lr_decay=1e-3),
    "Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(
        p, 1.0, weight_decay=1e-2
    ),
    "Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(
        p, 1.0, weight_decay=1e-6, lr_decay=1e-3
    ),
    "RMSprop": lambda p: torch.optim.RMSprop(p, 0.1),
    "RMSprop_with_weight_decay": lambda p: torch.optim.RMSprop(p, 0.1, weight_decay=1e-2),
    "RMSprop_with_weight_decay_and_centered": lambda p: torch.optim.RMSprop(p, 0.1, weight_decay=1e-6, centered=True),
    "RMSprop_with_weight_decay_and_centered_and_momentum":
        lambda p: torch.optim.RMSprop(p, 0.1, weight_decay=1e-6, centered=True, momentum=0.9),
    "RMSprop_with_weight_decay": lambda p: torch.optim.RMSprop(
        p, 0.1, weight_decay=1e-2
    ),
    "RMSprop_with_weight_decay_and_centered": lambda p: torch.optim.RMSprop(
        p, 0.1, weight_decay=1e-6, centered=True
    ),
    "RMSprop_with_weight_decay_and_centered_and_momentum": lambda p: torch.optim.RMSprop(
        p, 0.1, weight_decay=1e-6, centered=True, momentum=0.9
    ),
    "SGD": lambda p: torch.optim.SGD(p, 0.1),
    "SGD_with_weight_decay": lambda p: torch.optim.SGD(p, 0.1, weight_decay=1e-2),
    "SGD_with_weight_decay_and_momentum": lambda p: torch.optim.SGD(p, 0.1, momentum=0.9, weight_decay=1e-2),
    "SGD_with_weight_decay_and_nesterov_momentum":
        lambda p: torch.optim.SGD(p, 0.1, momentum=0.9, weight_decay=1e-6, nesterov=True),
    "SGD_with_weight_decay_and_momentum": lambda p: torch.optim.SGD(
        p, 0.1, momentum=0.9, weight_decay=1e-2
    ),
    "SGD_with_weight_decay_and_nesterov_momentum": lambda p: torch.optim.SGD(
        p, 0.1, momentum=0.9, weight_decay=1e-6, nesterov=True
    ),
}


@ -75,11 +91,11 @@ def run(optimizer_name, iterations, sample_every):
        loss.backward()

        def closure():
            return torch.tensor([10.])
            return torch.tensor([10.0])

        optimizer.step(closure)

        if i % sample_every == 0:
            values.append(
                [p.clone().flatten().data.numpy() for p in model.parameters()]
            )
@ -89,7 +105,7 @@ def run(optimizer_name, iterations, sample_every):

def emit(optimizer_parameter_map):
    # Don't write generated with an @ in front, else this file is recognized as generated.
    print("// @{} from {}".format('generated', __file__))
    print("// @{} from {}".format("generated", __file__))
    print(HEADER)
    for optimizer_name, parameters in optimizer_parameter_map.items():
        print(PARAMETERS.format(optimizer_name))
@ -115,7 +131,7 @@ def main():

    optimizer_parameter_map = {}
    for optimizer in OPTIMIZERS.keys():
        sys.stderr.write(f'Evaluating {optimizer} ...\n')
        sys.stderr.write(f"Evaluating {optimizer} ...\n")
        optimizer_parameter_map[optimizer] = run(
            optimizer, options.iterations, options.sample_every
        )
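A note on the closure in the run() hunk: torch.optim.LBFGS is the one optimizer in this table that re-evaluates the objective multiple times within a single step (and line-searches when line_search_fn is set), so step() must be given a closure; the other optimizers accept one but call it at most once. A minimal self-contained sketch of the required pattern (the Linear model and loss here are illustrative, not from the baseline script):

    import torch

    model = torch.nn.Linear(2, 1)
    opt = torch.optim.LBFGS(model.parameters(), lr=1.0)


    def closure():
        opt.zero_grad()
        loss = model(torch.ones(4, 2)).pow(2).sum()
        loss.backward()
        return loss


    opt.step(closure)  # LBFGS invokes closure() repeatedly within this one step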
@ -1,5 +1,6 @@
import sys
import os
import sys

import torch


@ -21,7 +22,7 @@ class FileSetup:


class EvalModeForLoadedModule(FileSetup):
    path = 'dropout_model.pt'
    path = "dropout_model.pt"

    def setup(self):
        class Model(torch.jit.ScriptModule):
@ -40,7 +41,7 @@ class EvalModeForLoadedModule(FileSetup):


class SerializationInterop(FileSetup):
    path = 'ivalue.pt'
    path = "ivalue.pt"

    def setup(self):
        ones = torch.ones(2, 2)
@ -53,7 +54,7 @@ class SerializationInterop(FileSetup):

# See testTorchSaveError in test/cpp/jit/tests.h for usage
class TorchSaveError(FileSetup):
    path = 'eager_value.pt'
    path = "eager_value.pt"

    def setup(self):
        ones = torch.ones(2, 2)
@ -63,8 +64,9 @@ class TorchSaveError(FileSetup):

    torch.save(value, self.path, _use_new_zipfile_serialization=False)


class TorchSaveJitStream_CUDA(FileSetup):
    path = 'saved_stream_model.pt'
    path = "saved_stream_model.pt"

    def setup(self):
        if not torch.cuda.is_available():
@ -77,7 +79,9 @@ class TorchSaveJitStream_CUDA(FileSetup):
        b = torch.rand(3, 4, device="cuda")

        with torch.cuda.stream(s):
            is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
            is_stream_s = (
                torch.cuda.current_stream(s.device_index()).id() == s.id()
            )
            c = torch.cat((a, b), 0).to("cuda")
        s.synchronize()
        return is_stream_s, a, b, c
@ -93,9 +97,10 @@ tests = [
    EvalModeForLoadedModule(),
    SerializationInterop(),
    TorchSaveError(),
    TorchSaveJitStream_CUDA()
    TorchSaveJitStream_CUDA(),
]


def setup():
    for test in tests:
        test.setup()
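The TorchSaveJitStream_CUDA fixture above checks, from TorchScript, that work really runs on a side stream; s.device_index() and s.id() are the TorchScript-visible stream accessors. The eager-mode analogue looks roughly like this (a sketch, assuming a CUDA device is present; eager code can compare streams directly):

    import torch

    if torch.cuda.is_available():
        s = torch.cuda.Stream()
        a = torch.rand(3, 4, device="cuda")
        b = torch.rand(3, 4, device="cuda")
        with torch.cuda.stream(s):
            on_side_stream = torch.cuda.current_stream() == s  # True inside the context
            c = torch.cat((a, b), 0)
        s.synchronize()  # wait for the side stream's work before using c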