UFMT formatting on test/autograd test/ao test/cpp test/backends (#123369)

Partially addresses #123062

Ran lintrunner on:
- test/_test_bazel.py
- test/ao
- test/autograd
- test/backends
- test/benchmark_utils
- test/conftest.py
- test/bottleneck_test
- test/cpp
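For reference, a pass like this can usually be reproduced locally with something along these lines (a sketch, assuming lintrunner is already set up for the repo; the linter name UFMT comes from .lintrunner.toml):

    lintrunner --take UFMT --apply-patches test/_test_bazel.py test/ao test/autograd \
        test/backends test/benchmark_utils test/conftest.py test/bottleneck_test test/cpp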

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123369
Approved by: https://github.com/huydhn
Commit: f71e368969 (parent de7edeea25)
Author: Arun Pa
Date: 2024-04-05 18:51:38 +00:00
Committed by: PyTorch MergeBot
23 changed files with 1914 additions and 1035 deletions
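UFMT chains usort (import sorting) and black (code formatting), which is why the hunks below normalize quotes, wrap long lines, and reorder imports. As a rough illustration of the black half only (a minimal sketch, assuming the black package is installed; this mirrors rather than reproduces the full UFMT pipeline):

    import black

    # black rewrites single quotes to double quotes and wraps long lines,
    # matching the mechanical changes in the diffs below
    src = "combined_defaults = {**defaults, 'sparse_config': sparse_config}\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # prints: combined_defaults = {**defaults, "sparse_config": sparse_config}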

File: .lintrunner.toml (these paths are removed from exclude_patterns so that UFMT now covers them)

@ -1014,33 +1014,6 @@ exclude_patterns = [
'test/_nvfuser/test_dynamo.py',
'test/_nvfuser/test_python_frontend.py',
'test/_nvfuser/test_torchscript.py',
'test/_test_bazel.py',
'test/ao/sparsity/test_activation_sparsifier.py',
'test/ao/sparsity/test_composability.py',
'test/ao/sparsity/test_data_scheduler.py',
'test/ao/sparsity/test_data_sparsifier.py',
'test/ao/sparsity/test_kernels.py',
'test/ao/sparsity/test_parametrization.py',
'test/ao/sparsity/test_qlinear_packed_params.py',
'test/ao/sparsity/test_scheduler.py',
'test/ao/sparsity/test_sparsifier.py',
'test/ao/sparsity/test_sparsity_utils.py',
'test/ao/sparsity/test_structured_sparsifier.py',
'test/autograd/test_complex.py',
'test/autograd/test_fallback.py',
'test/autograd/test_functional.py',
'test/backends/xeon/test_launch.py',
'test/benchmark_utils/test_benchmark_utils.py',
'test/bottleneck_test/test.py',
'test/bottleneck_test/test_args.py',
'test/bottleneck_test/test_cuda.py',
'test/conftest.py',
'test/cpp/__init__.py',
'test/cpp/aot_inductor/test.py',
'test/cpp/api/init_baseline.py',
'test/cpp/api/optim_baseline.py',
'test/cpp/jit/__init__.py',
'test/cpp/jit/tests_setup.py',
'test/cpp_api_parity/__init__.py',
'test/cpp_api_parity/functional_impl_check.py',
'test/cpp_api_parity/module_impl_check.py',

File: test/_test_bazel.py

@ -11,11 +11,14 @@ The name uses underscore `_test_bazel.py` to avoid globbing into other non-bazel
import torch
def test_sum() -> None:
assert torch.eq(torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])).all()
assert torch.eq(
torch.tensor([[1, 2, 3]]) + torch.tensor([[4, 5, 6]]), torch.tensor([[5, 7, 9]])
).all()
def test_simple_compile_eager() -> None:
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)

File: test/ao/sparsity/test_activation_sparsifier.py

@ -1,16 +1,21 @@
# Owner(s): ["module: unknown"]
import copy
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
import logging
import torch
from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import ActivationSparsifier
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.pruning.sparsifier.utils import module_to_fqn
from typing import List
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import (
ActivationSparsifier,
)
from torch.ao.pruning.sparsifier.utils import module_to_fqn
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class Model(nn.Module):
@ -45,7 +50,7 @@ class TestActivationSparsifier(TestCase):
in the activation sparsifier
"""
sparsifier_defaults = activation_sparsifier.defaults
combined_defaults = {**defaults, 'sparse_config': sparse_config}
combined_defaults = {**defaults, "sparse_config": sparse_config}
# more keys are populated in activation sparsifier (even though they may be None)
assert len(combined_defaults) <= len(activation_sparsifier.defaults)
@ -54,7 +59,9 @@ class TestActivationSparsifier(TestCase):
# all the keys in combined_defaults should be present in sparsifier defaults
assert config == combined_defaults.get(key, None)
def _check_register_layer(self, activation_sparsifier, defaults, sparse_config, layer_args_list):
def _check_register_layer(
self, activation_sparsifier, defaults, sparse_config, layer_args_list
):
"""Checks if layers in the model are correctly mapped to it's arguments.
Args:
@ -82,14 +89,14 @@ class TestActivationSparsifier(TestCase):
sparse_config_actual = copy.deepcopy(sparse_config)
sparse_config_actual.update(sparse_config_layer)
name = module_to_fqn(activation_sparsifier.model, layer_arg['layer'])
name = module_to_fqn(activation_sparsifier.model, layer_arg["layer"])
assert data_groups[name]['sparse_config'] == sparse_config_actual
assert data_groups[name]["sparse_config"] == sparse_config_actual
# assert the rest
other_config_actual = copy.deepcopy(defaults)
other_config_actual.update(layer_arg)
other_config_actual.pop('layer')
other_config_actual.pop("layer")
for key, value in other_config_actual.items():
assert key in data_groups[name]
@ -119,13 +126,15 @@ class TestActivationSparsifier(TestCase):
data_agg_actual = data_list[0]
model = activation_sparsifier.model
layer_name = module_to_fqn(model, model.conv1)
agg_fn = activation_sparsifier.data_groups[layer_name]['aggregate_fn']
agg_fn = activation_sparsifier.data_groups[layer_name]["aggregate_fn"]
for i in range(1, len(data_list)):
data_agg_actual = agg_fn(data_agg_actual, data_list[i])
assert 'data' in activation_sparsifier.data_groups[layer_name]
assert torch.all(activation_sparsifier.data_groups[layer_name]['data'] == data_agg_actual)
assert "data" in activation_sparsifier.data_groups[layer_name]
assert torch.all(
activation_sparsifier.data_groups[layer_name]["data"] == data_agg_actual
)
return data_agg_actual
@ -144,11 +153,11 @@ class TestActivationSparsifier(TestCase):
layer_name = module_to_fqn(model, model.conv1)
assert layer_name is not None
reduce_fn = activation_sparsifier.data_groups[layer_name]['reduce_fn']
reduce_fn = activation_sparsifier.data_groups[layer_name]["reduce_fn"]
data_reduce_actual = reduce_fn(data_agg_actual)
mask_fn = activation_sparsifier.data_groups[layer_name]['mask_fn']
sparse_config = activation_sparsifier.data_groups[layer_name]['sparse_config']
mask_fn = activation_sparsifier.data_groups[layer_name]["mask_fn"]
sparse_config = activation_sparsifier.data_groups[layer_name]["sparse_config"]
mask_actual = mask_fn(data_reduce_actual, **sparse_config)
mask_model = activation_sparsifier.get_mask(layer_name)
@ -156,8 +165,7 @@ class TestActivationSparsifier(TestCase):
assert torch.all(mask_model == mask_actual)
for config in activation_sparsifier.data_groups.values():
assert 'data' not in config
assert "data" not in config
def _check_squash_mask(self, activation_sparsifier, data):
"""Makes sure that squash_mask() works as usual. Specifically, checks
@ -172,11 +180,12 @@ class TestActivationSparsifier(TestCase):
data (torch tensor)
dummy batched data
"""
# create a forward hook for checking output == layer(input * mask)
def check_output(name):
mask = activation_sparsifier.get_mask(name)
features = activation_sparsifier.data_groups[name].get('features')
feature_dim = activation_sparsifier.data_groups[name].get('feature_dim')
features = activation_sparsifier.data_groups[name].get("features")
feature_dim = activation_sparsifier.data_groups[name].get("feature_dim")
def hook(module, input, output):
input_data = input[0]
@ -184,20 +193,28 @@ class TestActivationSparsifier(TestCase):
assert torch.all(mask * input_data == output)
else:
for feature_idx in range(0, len(features)):
feature = torch.Tensor([features[feature_idx]], device=input_data.device).long()
inp_data_feature = torch.index_select(input_data, feature_dim, feature)
out_data_feature = torch.index_select(output, feature_dim, feature)
feature = torch.Tensor(
[features[feature_idx]], device=input_data.device
).long()
inp_data_feature = torch.index_select(
input_data, feature_dim, feature
)
out_data_feature = torch.index_select(
output, feature_dim, feature
)
assert torch.all(
mask[feature_idx] * inp_data_feature == out_data_feature
)
assert torch.all(mask[feature_idx] * inp_data_feature == out_data_feature)
return hook
for name, config in activation_sparsifier.data_groups.items():
if 'identity' in name:
config['layer'].register_forward_hook(check_output(name))
if "identity" in name:
config["layer"].register_forward_hook(check_output(name))
activation_sparsifier.model(data)
def _check_state_dict(self, sparsifier1):
"""Checks if loading and restoring of state_dict() works as expected.
Basically, dumps the state of the sparsifier and loads it in the other sparsifier
@ -222,8 +239,8 @@ class TestActivationSparsifier(TestCase):
for name, state in sparsifier2.state.items():
assert name in sparsifier1.state
mask1 = sparsifier1.state[name]['mask']
mask2 = state['mask']
mask1 = sparsifier1.state[name]["mask"]
mask2 = state["mask"]
if mask1 is None:
assert mask2 is None
@ -237,8 +254,8 @@ class TestActivationSparsifier(TestCase):
assert torch.all(mask1 == mask2)
# make sure that the state dict is stored as torch sparse
for state in state_dict['state'].values():
mask = state['mask']
for state in state_dict["state"].values():
mask = state["mask"]
if mask is not None:
if isinstance(mask, List):
for idx in range(len(mask)):
@ -252,8 +269,16 @@ class TestActivationSparsifier(TestCase):
assert layer_name in dg2
# exclude hook and layer
config1 = {key: value for key, value in config.items() if key not in ['hook', 'layer']}
config2 = {key: value for key, value in dg2[layer_name].items() if key not in ['hook', 'layer']}
config1 = {
key: value
for key, value in config.items()
if key not in ["hook", "layer"]
}
config2 = {
key: value
for key, value in dg2[layer_name].items()
if key not in ["hook", "layer"]
}
assert config1 == config2
@ -263,6 +288,7 @@ class TestActivationSparsifier(TestCase):
till squash_mask().
The idea is to check that everything works as expected while in the workflow.
"""
# defining aggregate, reduce and mask functions
def agg_fn(x, y):
return x + y
@ -287,14 +313,9 @@ class TestActivationSparsifier(TestCase):
# Creating default function and sparse configs
# default sparse_config
sparse_config = {
'sparsity_level': 0.5
}
sparse_config = {"sparsity_level": 0.5}
defaults = {
'aggregate_fn': agg_fn,
'reduce_fn': reduce_fn
}
defaults = {"aggregate_fn": agg_fn, "reduce_fn": reduce_fn}
# simulate the workflow
# STEP 1: make data and activation sparsifier object
@ -306,35 +327,41 @@ class TestActivationSparsifier(TestCase):
# STEP 2: Register some layers
register_layer1_args = {
'layer': model.conv1,
'mask_fn': _vanilla_norm_sparsifier
"layer": model.conv1,
"mask_fn": _vanilla_norm_sparsifier,
}
sparse_config_layer1 = {'sparsity_level': 0.3}
sparse_config_layer1 = {"sparsity_level": 0.3}
register_layer2_args = {
'layer': model.linear1,
'features': [0, 10, 234],
'feature_dim': 1,
'mask_fn': _vanilla_norm_sparsifier
"layer": model.linear1,
"features": [0, 10, 234],
"feature_dim": 1,
"mask_fn": _vanilla_norm_sparsifier,
}
sparse_config_layer2 = {'sparsity_level': 0.1}
sparse_config_layer2 = {"sparsity_level": 0.1}
register_layer3_args = {
'layer': model.identity1,
'mask_fn': _vanilla_norm_sparsifier
"layer": model.identity1,
"mask_fn": _vanilla_norm_sparsifier,
}
sparse_config_layer3 = {'sparsity_level': 0.3}
sparse_config_layer3 = {"sparsity_level": 0.3}
register_layer4_args = {
'layer': model.identity2,
'features': [0, 10, 20],
'feature_dim': 1,
'mask_fn': _vanilla_norm_sparsifier
"layer": model.identity2,
"features": [0, 10, 20],
"feature_dim": 1,
"mask_fn": _vanilla_norm_sparsifier,
}
sparse_config_layer4 = {'sparsity_level': 0.1}
sparse_config_layer4 = {"sparsity_level": 0.1}
layer_args_list = [(register_layer1_args, sparse_config_layer1), (register_layer2_args, sparse_config_layer2)]
layer_args_list += [(register_layer3_args, sparse_config_layer3), (register_layer4_args, sparse_config_layer4)]
layer_args_list = [
(register_layer1_args, sparse_config_layer1),
(register_layer2_args, sparse_config_layer2),
]
layer_args_list += [
(register_layer3_args, sparse_config_layer3),
(register_layer4_args, sparse_config_layer4),
]
# Registering..
for layer_args in layer_args_list:
@ -342,7 +369,9 @@ class TestActivationSparsifier(TestCase):
activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer)
# check if things are registered correctly
self._check_register_layer(activation_sparsifier, defaults, sparse_config, layer_args_list)
self._check_register_layer(
activation_sparsifier, defaults, sparse_config, layer_args_list
)
# check state_dict after registering and before model forward
self._check_state_dict(activation_sparsifier)

File: test/ao/sparsity/test_composability.py

@ -7,9 +7,14 @@ import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao import pruning
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, convert_to_reference_fx, prepare_qat_fx
from torch.ao.pruning import fqn_to_module
from torch.ao.quantization.quantize_fx import (
convert_fx,
convert_to_reference_fx,
prepare_fx,
prepare_qat_fx,
)
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@ -21,6 +26,7 @@ sparse_defaults = {
"zeros_per_block": 4,
}
def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
model = nn.Sequential(
nn.Linear(4, 4), # 0
@ -40,7 +46,7 @@ def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
sparse_config = [
{
"tensor_fqn": '5.weight',
"tensor_fqn": "5.weight",
"sparsity_level": 0.7,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
@ -49,15 +55,18 @@ def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
]
return model, sparsifier, sparse_config
def _squash_mask_calibrate_and_convert(model, sparsifier, input):
sparsifier.step()
sparsifier.squash_mask()
model(input)
tq.convert(model, inplace=True)
def _calculate_sparsity(tensor):
return ((tensor == 0).sum() / tensor.numel()).item()
# This series of tests is to check the composability goals for sparsity and quantization. Namely
# that performing quantization and sparsity model manipulations in various orderings
# does not cause problems
@ -70,7 +79,9 @@ class TestComposability(TestCase):
mod,
sparsifier,
sparse_config,
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
) = _get_model_and_sparsifier_and_sparse_config(
tq.get_default_qconfig("fbgemm")
)
tq.prepare(mod, inplace=True)
sparsifier.prepare(mod, config=sparse_config)
@ -81,9 +92,7 @@ class TestComposability(TestCase):
# check that correct observers were inserted
self.assertTrue(hasattr(mod[5], "activation_post_process"))
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
_squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
@ -99,7 +108,9 @@ class TestComposability(TestCase):
mod,
sparsifier,
sparse_config,
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
) = _get_model_and_sparsifier_and_sparse_config(
tq.get_default_qconfig("fbgemm")
)
sparsifier.prepare(mod, config=sparse_config)
tq.prepare(mod, inplace=True)
@ -113,9 +124,7 @@ class TestComposability(TestCase):
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
_squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
@ -130,7 +139,9 @@ class TestComposability(TestCase):
mod,
sparsifier,
sparse_config,
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
) = _get_model_and_sparsifier_and_sparse_config(
tq.get_default_qconfig("fbgemm")
)
sparsifier.prepare(mod, config=sparse_config)
tq.prepare(mod, inplace=True)
@ -168,7 +179,9 @@ class TestComposability(TestCase):
mod,
sparsifier,
sparse_config,
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
) = _get_model_and_sparsifier_and_sparse_config(
tq.get_default_qconfig("fbgemm")
)
sparsifier.prepare(mod, config=sparse_config)
tq.fuse_modules(mod, [["5", "6"]], inplace=True)
mod[5].qconfig = tq.get_default_qconfig("fbgemm")
@ -182,9 +195,7 @@ class TestComposability(TestCase):
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
_squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
@ -197,7 +208,9 @@ class TestComposability(TestCase):
mod,
sparsifier,
_,
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
) = _get_model_and_sparsifier_and_sparse_config(
tq.get_default_qconfig("fbgemm")
)
tq.fuse_modules(mod, [["5", "6"]], inplace=True)
# it's absolutely broken by fusion but will still work if you put the correct fqn in
@ -261,9 +274,7 @@ class TestComposability(TestCase):
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
_squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
@ -301,9 +312,7 @@ class TestComposability(TestCase):
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
_squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
@ -313,6 +322,7 @@ class TestComposability(TestCase):
cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
def _module_has_activation_post_process(model, fqn_of_module):
for node in model.graph.nodes:
# look for an observer whose arg is the target module
@ -321,10 +331,12 @@ def _module_has_activation_post_process(model, fqn_of_module):
return True
return False
class TestFxComposability(TestCase):
r"""This series of tests checks that various steps of the quantization and sparsity flow
compose cleanly despite variation in sequencing.
"""
def test_q_prep_fx_before_s_prep(self):
r"""
This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
@ -340,10 +352,11 @@ class TestFxComposability(TestCase):
example = torch.randn(1, 4, 4, 4)
qconfig = tq.get_default_qconfig("fbgemm")
qconfig_mapping = tq.QConfigMapping() \
.set_module_name("4", qconfig) \
qconfig_mapping = (
tq.QConfigMapping()
.set_module_name("4", qconfig)
.set_module_name("5", qconfig)
)
mod = prepare_fx(mod, qconfig_mapping, (example,))
@ -374,7 +387,11 @@ class TestFxComposability(TestCase):
mod = convert_fx(mod)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
self.assertTrue(
isinstance(
fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
)
)
self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
# check that module was actually sparsified
@ -399,9 +416,11 @@ class TestFxComposability(TestCase):
example = torch.randn(1, 4, 4, 4)
qconfig = tq.get_default_qconfig("fbgemm")
qconfig_mapping = tq.QConfigMapping() \
.set_module_name("4", qconfig) \
qconfig_mapping = (
tq.QConfigMapping()
.set_module_name("4", qconfig)
.set_module_name("5", qconfig)
)
mod = prepare_fx(mod, qconfig_mapping, (example,))
@ -432,9 +451,15 @@ class TestFxComposability(TestCase):
mod = convert_to_reference_fx(mod)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU))
self.assertTrue(
isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
)
self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear))
self.assertTrue(
isinstance(
fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
)
)
# check that module was actually sparsified
cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
@ -459,9 +484,11 @@ class TestFxComposability(TestCase):
example = torch.randn(1, 4, 4, 4)
qconfig = tq.get_default_qconfig("fbgemm")
qconfig_mapping = tq.QConfigMapping() \
.set_module_name("4", qconfig) \
qconfig_mapping = (
tq.QConfigMapping()
.set_module_name("4", qconfig)
.set_module_name("5", qconfig)
)
mod = prepare_fx(mod, qconfig_mapping, (example,))
# check that correct modules had parametrizations added and
@ -478,7 +505,11 @@ class TestFxComposability(TestCase):
mod = convert_fx(mod)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
self.assertTrue(
isinstance(
fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
)
)
self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
# check that module was actually sparsified
@ -504,16 +535,20 @@ class TestFxComposability(TestCase):
example = torch.randn(1, 4, 4, 4)
qconfig = tq.get_default_qat_qconfig("fbgemm")
qconfig_mapping = tq.QConfigMapping() \
.set_module_name("4", qconfig) \
qconfig_mapping = (
tq.QConfigMapping()
.set_module_name("4", qconfig)
.set_module_name("5", qconfig)
)
mod = prepare_qat_fx(mod, qconfig_mapping, (example,))
# check that correct modules had parametrizations added and
# that none were lost during prepare
self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU))
self.assertTrue(
isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU)
)
# check that correct observers were inserted and that matching
# occurred successfully
@ -524,7 +559,11 @@ class TestFxComposability(TestCase):
mod = convert_fx(mod)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU))
self.assertTrue(
isinstance(
fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
)
)
self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
# check that module was actually sparsified
@ -550,9 +589,11 @@ class TestFxComposability(TestCase):
example = torch.randn(1, 4, 4, 4)
qconfig = tq.get_default_qconfig("fbgemm")
qconfig_mapping = tq.QConfigMapping() \
.set_module_name("4", qconfig) \
qconfig_mapping = (
tq.QConfigMapping()
.set_module_name("4", qconfig)
.set_module_name("5", qconfig)
)
mod = prepare_fx(mod, qconfig_mapping, (example,))
# check that correct modules had parametrizations added and
@ -569,9 +610,15 @@ class TestFxComposability(TestCase):
mod = convert_to_reference_fx(mod)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU))
self.assertTrue(
isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
)
self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
self.assertTrue(isinstance(fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear))
self.assertTrue(
isinstance(
fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
)
)
# check that module was actually sparsified
cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))

File: test/ao/sparsity/test_data_scheduler.py

@ -1,17 +1,20 @@
# Owner(s): ["module: unknown"]
import copy
import logging
import warnings
from torch.testing._internal.common_utils import TestCase
from torch import nn
import torch
from typing import Tuple
import copy
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
import torch
from torch import nn
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class ImplementedDataScheduler(BaseDataScheduler):
@ -20,27 +23,32 @@ class ImplementedDataScheduler(BaseDataScheduler):
def get_schedule_param(self):
if self.last_epoch > 0:
return {name: config['sparsity_level'] * 0.5
for name, config in self.data_sparsifier.data_groups.items()}
return {
name: config["sparsity_level"] * 0.5
for name, config in self.data_sparsifier.data_groups.items()
}
else:
return self.base_param
class TestBaseDataScheduler(TestCase):
def _get_data(self):
tensor1, param1, emb1 = torch.randn(5, 5), nn.Parameter(torch.randn(10, 10)), nn.Embedding(50, 5)
data_list = [
('tensor1', tensor1), ('param1', param1), ('emb1', emb1)
]
tensor1, param1, emb1 = (
torch.randn(5, 5),
nn.Parameter(torch.randn(10, 10)),
nn.Embedding(50, 5),
)
data_list = [("tensor1", tensor1), ("param1", param1), ("emb1", emb1)]
defaults = {
'sparsity_level': 0.7,
'sparse_block_shape': (1, 4),
'zeros_per_block': 2
"sparsity_level": 0.7,
"sparse_block_shape": (1, 4),
"zeros_per_block": 2,
}
data_with_config = [
{
'name': 'tensor2', 'data': torch.randn(4, 4),
'config': {'sparsity_level': 0.3}
"name": "tensor2",
"data": torch.randn(4, 4),
"config": {"sparsity_level": 0.3},
}
]
return data_list, data_with_config, defaults
@ -48,7 +56,11 @@ class TestBaseDataScheduler(TestCase):
def _get_sparsifier(self, data_list, data_with_config, defaults):
sparsifier = DataNormSparsifier(data_list, **defaults)
for data_config_dict in data_with_config:
name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config']
name, data, config = (
data_config_dict["name"],
data_config_dict["data"],
data_config_dict["config"],
)
sparsifier.add_data(name=name, data=data, **config)
return sparsifier
@ -57,7 +69,7 @@ class TestBaseDataScheduler(TestCase):
return scheduler
def _get_schedule_param(self):
return 'sparsity_level'
return "sparsity_level"
def _get_name_data_config(self, some_data, defaults):
config = copy.deepcopy(defaults)
@ -66,7 +78,11 @@ class TestBaseDataScheduler(TestCase):
name, data = some_data
else:
# dealing with data_with_config
name, data, new_config = some_data['name'], some_data['data'], some_data['config']
name, data, new_config = (
some_data["name"],
some_data["data"],
some_data["config"],
)
config.update(new_config)
return name, data, config
@ -102,8 +118,11 @@ class TestBaseDataScheduler(TestCase):
# Make sure there is no warning related to the base_data_scheduler
for warning in w:
fname = warning.filename
fname = '/'.join(fname.split('/')[-5:])
assert fname != 'torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py'
fname = "/".join(fname.split("/")[-5:])
assert (
fname
!= "torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py"
)
def test_step(self):
data_list, data_with_config, defaults = self._get_data()
@ -115,14 +134,19 @@ class TestBaseDataScheduler(TestCase):
for some_data in all_data:
name, _, config = self._get_name_data_config(some_data, defaults)
assert sparsifier.data_groups[name][schedule_param] == config[schedule_param]
assert (
sparsifier.data_groups[name][schedule_param] == config[schedule_param]
)
sparsifier.step()
scheduler.step()
for some_data in all_data:
name, _, config = self._get_name_data_config(some_data, defaults)
assert sparsifier.data_groups[name][schedule_param] == config[schedule_param] * 0.5
assert (
sparsifier.data_groups[name][schedule_param]
== config[schedule_param] * 0.5
)
# checking step count
step_cnt = 5
@ -130,7 +154,9 @@ class TestBaseDataScheduler(TestCase):
sparsifier.step()
scheduler.step()
assert scheduler._step_count == step_cnt + 2 # step_cnt + step above + 1 step in constructor
assert (
scheduler._step_count == step_cnt + 2
) # step_cnt + step above + 1 step in constructor
def test_state_dict(self):
data_list, data_with_config, defaults = self._get_data()

File: test/ao/sparsity/test_data_sparsifier.py

@ -1,20 +1,28 @@
# Owner(s): ["module: unknown"]
import copy
import itertools
import logging
import math
from typing import Tuple
import torch
from torch import nn
from torch.ao.pruning._experimental.data_sparsifier import (
BaseDataSparsifier,
DataNormSparsifier,
)
from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import (
post_training_sparse_quantize,
)
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_utils import TestCase
from typing import Tuple
from torch import nn
import itertools
import math
import copy
from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier, DataNormSparsifier
from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import post_training_sparse_quantize
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class ImplementedSparsifier(BaseDataSparsifier):
@ -25,17 +33,18 @@ class ImplementedSparsifier(BaseDataSparsifier):
mask = self.get_mask(name)
mask[0] = 0
linear_state = self.state[name]
linear_state['step_count'] = linear_state.get('step_count', 0) + 1
linear_state["step_count"] = linear_state.get("step_count", 0) + 1
class _BaseDataSparsiferTestCase(TestCase):
r"""This helper test class takes in any supported type of and runs some tests.
The user is required to pass in the data that needs to sparsified and the
runner will run some tests that needs to be passed in order for the data
type to be supported.
TODO: Change the structure by creating a separate test case class for each
member function
The user is required to pass in the data that needs to sparsified and the
runner will run some tests that needs to be passed in order for the data
type to be supported.
TODO: Change the structure by creating a separate test case class for each
member function
"""
def run_all_checks(self, data_list, data_with_config, defaults):
self.check_constructor(data_list, data_with_config, defaults)
self.check_squash_mask(data_list, data_with_config, defaults)
@ -52,32 +61,49 @@ class _BaseDataSparsiferTestCase(TestCase):
config = defaults
else:
# dealing with data_with_config
name, data, config = some_data['name'], some_data['data'], some_data['config']
name, data, config = (
some_data["name"],
some_data["data"],
some_data["config"],
)
return name, data, config
@staticmethod
def _make_sparsifier(data_list, data_with_config, defaults,
sparsifier_type=None, sparsifier_kwargs=None):
def _make_sparsifier(
data_list,
data_with_config,
defaults,
sparsifier_type=None,
sparsifier_kwargs=None,
):
if sparsifier_type is None:
sparsifier = ImplementedSparsifier(data_list=data_list, **defaults)
else:
kwargs = copy.deepcopy(defaults)
kwargs.update(sparsifier_kwargs)
kwargs['data_list'] = data_list
kwargs["data_list"] = data_list
sparsifier = sparsifier_type(**kwargs)
assert len(sparsifier.data_groups) == len(data_list)
for data_config_dict in data_with_config:
name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config']
name, data, config = (
data_config_dict["name"],
data_config_dict["data"],
data_config_dict["config"],
)
sparsifier.add_data(name=name, data=data, **config)
return sparsifier
def check_constructor(self, data_list, data_with_config, defaults, **kwargs):
sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
self.assertEqual(len(sparsifier.data_groups),
len(data_list) + len(data_with_config),
msg="Sparsifier data groups don't match the input "
f"({len(sparsifier.data_groups)} vs. "
f"{len(data_list) + len(data_with_config)}).")
sparsifier = self._make_sparsifier(
data_list, data_with_config, defaults=defaults, **kwargs
)
self.assertEqual(
len(sparsifier.data_groups),
len(data_list) + len(data_with_config),
msg="Sparsifier data groups don't match the input "
f"({len(sparsifier.data_groups)} vs. "
f"{len(data_list) + len(data_with_config)}).",
)
all_data = data_list + data_with_config
@ -87,7 +113,9 @@ class _BaseDataSparsiferTestCase(TestCase):
self.assertEqual(sparsifier.data_groups[name], config)
def check_step(self, data_list, data_with_config, defaults, **kwargs):
sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
sparsifier = self._make_sparsifier(
data_list, data_with_config, defaults=defaults, **kwargs
)
all_data = data_list + data_with_config
# Check data and mask before doing the step
@ -114,11 +142,13 @@ class _BaseDataSparsiferTestCase(TestCase):
self.assertEqualBroadcasting(sparsified_data[0], 0)
self.assertEqual(original_data, data)
self.assertEqualBroadcasting(mask[0], 0)
assert 'step_count' in sparsifier.state[name]
assert sparsifier.state[name]['step_count'] == 3
assert "step_count" in sparsifier.state[name]
assert sparsifier.state[name]["step_count"] == 3
def check_squash_mask(self, data_list, data_with_config, defaults, **kwargs):
sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
sparsifier = self._make_sparsifier(
data_list, data_with_config, defaults=defaults, **kwargs
)
all_data = data_list + data_with_config
for some_data in all_data:
name, _, _ = self._get_name_data_config(some_data)
@ -129,15 +159,21 @@ class _BaseDataSparsiferTestCase(TestCase):
for some_data in all_data:
name, _, _ = self._get_name_data_config(some_data)
assert not is_parametrized(sparsifier._container, name) # not parametrized anymore
assert not is_parametrized(
sparsifier._container, name
) # not parametrized anymore
with self.assertRaises(ValueError):
sparsifier.get_data(name, return_original=True)
def check_add_data(self, data_list, data_with_config, defaults, **kwargs):
sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
sparsifier = self._make_sparsifier(
data_list, data_with_config, defaults=defaults, **kwargs
)
all_data = data_list + data_with_config
for some_data in all_data:
name1, data1, config = self._get_name_data_config(some_data, defaults=defaults)
name1, data1, config = self._get_name_data_config(
some_data, defaults=defaults
)
data1 = sparsifier._extract_weight(data1)
data1_old = copy.deepcopy(data1)
assert torch.all(data1 == sparsifier.get_data(name=name1))
@ -145,18 +181,28 @@ class _BaseDataSparsiferTestCase(TestCase):
sparsifier.step()
mask = sparsifier.get_mask(name1)
data2 = torch.randn(data1.shape) # add another data with the same shape as original data
data2 = torch.randn(
data1.shape
) # add another data with the same shape as original data
sparsifier.add_data(name=name1, data=data2)
assert torch.all(data2 == sparsifier.get_data(name=name1))
assert torch.all(sparsifier.get_mask(name1) == mask) # mask should not change
assert torch.all(
sparsifier.get_mask(name1) == mask
) # mask should not change
assert torch.all(data1_old == data1)
assert sparsifier.data_groups[name1] == config # if replaced old_config should match new config
assert (
sparsifier.data_groups[name1] == config
) # if replaced old_config should match new config
def check_state_dict(self, data_list, data_with_config, defaults, **kwargs):
sparsifier1 = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
sparsifier2 = self._make_sparsifier(data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs)
sparsifier1 = self._make_sparsifier(
data_list, data_with_config, defaults=defaults, **kwargs
)
sparsifier2 = self._make_sparsifier(
data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs
)
sparsifier1.step()
state_dict1 = sparsifier1.state_dict()
@ -169,15 +215,17 @@ class _BaseDataSparsiferTestCase(TestCase):
assert len(sparsifier1.state) == len(sparsifier2.state)
assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)
state1 = state_dict1['state']
state1 = state_dict1["state"]
for name in state1.keys():
# compare mask
assert name in sparsifier2.state
assert 'mask' in sparsifier2.state[name]
assert 'mask' in sparsifier1.state[name]
mask1, mask2 = state1[name]['mask'], sparsifier2.state[name]['mask']
assert "mask" in sparsifier2.state[name]
assert "mask" in sparsifier1.state[name]
mask1, mask2 = state1[name]["mask"], sparsifier2.state[name]["mask"]
assert mask1.is_sparse and not mask2.is_sparse
assert torch.all(mask1.to_dense() == mask2) # mask1 is stored as sparse coo now
assert torch.all(
mask1.to_dense() == mask2
) # mask1 is stored as sparse coo now
# compare data_groups
dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups
@ -187,12 +235,14 @@ class _BaseDataSparsiferTestCase(TestCase):
# compare container
container1, container2 = sparsifier1._container, sparsifier2._container
assert torch.all(getattr(container1, name) == getattr(container2, name))
assert is_parametrized(container1, name) == is_parametrized(container2, name)
assert is_parametrized(container1, name) == is_parametrized(
container2, name
)
if is_parametrized(container1, name):
param1 = getattr(container1.parametrizations, name)[0]
param2 = getattr(container2.parametrizations, name)[0]
assert hasattr(param1, 'mask')
assert hasattr(param2, 'mask')
assert hasattr(param1, "mask")
assert hasattr(param2, "mask")
self.assertEqual(param1.__dict__, param2.__dict__)
def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs):
@ -204,29 +254,35 @@ class _BaseDataSparsiferTestCase(TestCase):
This test modifies the data and asserts that data in the sparsifier is changed as well
"""
sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
sparsifier = self._make_sparsifier(
data_list, data_with_config, defaults=defaults, **kwargs
)
all_data = data_list + data_with_config
for some_data in all_data:
name, data, _ = self._get_name_data_config(some_data)
weight = sparsifier._extract_weight(data)
weight.data = weight + torch.randn(*weight.shape)
contained_data = sparsifier.get_data(name=name)
assert weight.data.storage().data_ptr() == contained_data.data.storage().data_ptr()
assert (
weight.data.storage().data_ptr()
== contained_data.data.storage().data_ptr()
)
assert torch.all(contained_data == weight)
class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
r"""This helper test class takes in any supported type of and runs some tests.
This inherits the TestBaseDataSparsifierRuner wherein some functions are
over-ridden to take accomodate the specific sparsifier.
TODO: Change the structure by creating a separate test case class for each
member function
This inherits the TestBaseDataSparsifierRuner wherein some functions are
over-ridden to take accomodate the specific sparsifier.
TODO: Change the structure by creating a separate test case class for each
member function
"""
def run_all_checks(self, data_list, defaults, data_with_config, norm_type='L1'):
assert norm_type in ['L1', 'L2']
def run_all_checks(self, data_list, defaults, data_with_config, norm_type="L1"):
assert norm_type in ["L1", "L2"]
kwargs = {
'sparsifier_type': DataNormSparsifier,
'sparsifier_kwargs': {'norm': norm_type}
"sparsifier_type": DataNormSparsifier,
"sparsifier_kwargs": {"norm": norm_type},
}
self.check_constructor(data_list, data_with_config, defaults, **kwargs)
self.check_squash_mask(data_list, data_with_config, defaults, **kwargs)
@ -234,24 +290,28 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
self.check_state_dict(data_list, data_with_config, defaults, **kwargs)
self.check_step(data_list, data_with_config, defaults, norm_type=norm_type)
self.check_step_2_of_4(norm_type=norm_type)
self.check_sparsity_level(data_list, data_with_config, defaults, norm_type=norm_type)
self.check_sparsity_level(
data_list, data_with_config, defaults, norm_type=norm_type
)
self.check_memory_reference(data_list, data_with_config, defaults, **kwargs)
@staticmethod
def _get_bounds_on_actual_sparsity(config, tensor_shape):
r"""This function gets the bounds on actual sparsity.
Note::
Although we specify the sparsity_level parameter, this does not mean that
the actual sparsity obtained after sparsification is the same as sparsity_level.
The actual sparsity depends largely on the shape and the data itself.
"""
sparsity_level = config['sparsity_level']
zeros_per_block = config['zeros_per_block']
sparse_block_shape = config['sparse_block_shape']
sparsity_level = config["sparsity_level"]
zeros_per_block = config["zeros_per_block"]
sparse_block_shape = config["sparse_block_shape"]
height, width = tensor_shape[-2], tensor_shape[-1]
block_height, block_width = sparse_block_shape
number_blocks = math.ceil(height / block_height) * math.ceil(width / block_width)
number_blocks = math.ceil(height / block_height) * math.ceil(
width / block_width
)
values_per_block = block_height * block_width
if zeros_per_block == 0:
@ -260,17 +320,23 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
# min value assumes zeros_per_block is 1
min_values_sparsified = round(number_blocks * sparsity_level)
# max value assumes actual zeros_per_block
max_values_sparsified = min_values_sparsified * min(values_per_block, zeros_per_block)
max_values_sparsified = min_values_sparsified * min(
values_per_block, zeros_per_block
)
lower_bound = min_values_sparsified / (height * width)
upper_bound = min(1.0, max_values_sparsified / (height * width))
lower_bound, upper_bound = round(lower_bound, 3), round(upper_bound, 3)
return lower_bound, upper_bound
def check_step(self, data_list, data_with_config, defaults, norm_type='L1'):
sparsifier = self._make_sparsifier(data_list, data_with_config, defaults,
sparsifier_type=DataNormSparsifier,
sparsifier_kwargs={'norm': norm_type})
def check_step(self, data_list, data_with_config, defaults, norm_type="L1"):
sparsifier = self._make_sparsifier(
data_list,
data_with_config,
defaults,
sparsifier_type=DataNormSparsifier,
sparsifier_kwargs={"norm": norm_type},
)
all_data = data_list + data_with_config
# mask before step() should not be sparsified
@ -289,29 +355,39 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
mask = mask.to(torch.float)
actual_sparsity = round(1 - mask.mean().item(), 3)
assert actual_sparsity >= lb and actual_sparsity <= ub
assert actual_sparsity > 0.0 # exact sparsity level cannot be achieved due to size of tensor
assert (
actual_sparsity > 0.0
) # exact sparsity level cannot be achieved due to size of tensor
iters_before_collapse = 100
test_sparsifier = DataNormSparsifier(sparsity_level=0.5,
sparse_block_shape=(1, 4),
zeros_per_block=4,
norm=norm_type)
test_sparsifier = DataNormSparsifier(
sparsity_level=0.5,
sparse_block_shape=(1, 4),
zeros_per_block=4,
norm=norm_type,
)
for _ in range(iters_before_collapse):
new_data = torch.randn(20, 20)
test_sparsifier.add_data(name='test_data', data=new_data)
test_sparsifier.add_data(name="test_data", data=new_data)
test_sparsifier.step()
mask = test_sparsifier.get_mask(name='test_data')
mask = test_sparsifier.get_mask(name="test_data")
mask = mask.to(torch.float)
assert (1.0 - mask.mean().item()) > 0 # some sparsity achieved
def check_step_2_of_4(self, norm_type):
# overriding default config for test purposes
default_config = {'sparsity_level': 1.0, 'zeros_per_block': 2, 'sparse_block_shape': (1, 4)}
data_list = [('test_data', torch.randn(4, 4))]
default_config = {
"sparsity_level": 1.0,
"zeros_per_block": 2,
"sparse_block_shape": (1, 4),
}
data_list = [("test_data", torch.randn(4, 4))]
sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type, **default_config)
sparsifier = DataNormSparsifier(
data_list=data_list, norm=norm_type, **default_config
)
sparsifier.step()
for some_data in data_list:
@ -321,29 +397,39 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
self.assertAlmostEqual(1.0 - mask.mean().item(), 0.5, places=2)
for row in mask:
for idx in range(0, len(row), 4):
block = row[idx:idx + 4]
block = row[idx : idx + 4]
block, _ = block.sort()
assert (block[:2] == 0).all()
assert (block[2:] != 0).all()
def check_sparsity_level(self, data_list, data_with_config, defaults, norm_type='L1'):
def check_sparsity_level(
self, data_list, data_with_config, defaults, norm_type="L1"
):
sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
zeros_per_blocks = [0, 1, 2, 3, 4]
sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type)
testcases = itertools.tee(itertools.product(sparsity_levels,
sparse_block_shapes,
zeros_per_blocks))
testcases = itertools.tee(
itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)
)
assert len(data_with_config) > 0 and 'name' in data_with_config[0] and 'data' in data_with_config[0]
assert (
len(data_with_config) > 0
and "name" in data_with_config[0]
and "data" in data_with_config[0]
)
# get some data
name, data = data_with_config[0]['name'], data_with_config[0]['data']
name, data = data_with_config[0]["name"], data_with_config[0]["data"]
for idx, (sl, sbs, zpb) in enumerate(testcases[0]):
new_name = f'{name}_{idx}'
new_name = f"{name}_{idx}"
if zpb > sbs[0] * sbs[1]:
continue
current_config = {'sparsity_level': sl, 'sparse_block_shape': sbs, 'zeros_per_block': zpb}
current_config = {
"sparsity_level": sl,
"sparse_block_shape": sbs,
"zeros_per_block": zpb,
}
sparsifier.add_data(name=new_name, data=data, **current_config)
if zpb > sbs[0] * sbs[1]:
continue
@ -351,7 +437,7 @@ class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase):
sparsifier.step()
sparsifier.squash_mask()
for idx, (sl, sbs, zpb) in enumerate(testcases[0]):
new_name = f'{name}_{idx}'
new_name = f"{name}_{idx}"
sparsified_data = sparsifier.get_data(name=new_name, original=False)
# sparse mask
sparse_mask = (sparsified_data == 0).float()
@ -372,130 +458,232 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
Once the above is done, create an instance of TestBaseDataSparsifierType and call all the run_tests()
"""
def test_tensors(self):
tensor1, tensor2, tensor3 = torch.randn(3, 3), torch.randn(4, 4), torch.randn(5, 5)
tensor1, tensor2, tensor3 = (
torch.randn(3, 3),
torch.randn(4, 4),
torch.randn(5, 5),
)
tensor4, tensor5 = torch.randn(1, 1), torch.randn(4, 4)
data_list = [('tensor1', tensor1), ('tensor2', tensor2), ('tensor3', tensor3)]
defaults = {'test': 3}
data_list = [("tensor1", tensor1), ("tensor2", tensor2), ("tensor3", tensor3)]
defaults = {"test": 3}
data_with_config = [
{
'name': 'tensor4', 'data': tensor4, 'config': {'test': 7}
},
{
'name': 'tensor5', 'data': tensor5, 'config': {'test': 8}
},
{"name": "tensor4", "data": tensor4, "config": {"test": 7}},
{"name": "tensor5", "data": tensor5, "config": {"test": 8}},
]
self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config)
self.run_all_checks(
data_list=data_list, defaults=defaults, data_with_config=data_with_config
)
def test_nn_parameters(self):
param1, param2, param3 = nn.Parameter(torch.randn(3, 3)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5))
param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(torch.randn(4, 4))
data_list = [('param1', param1), ('param2', param2), ('param3', param3)]
defaults = {'test': 3}
param1, param2, param3 = (
nn.Parameter(torch.randn(3, 3)),
nn.Parameter(torch.randn(4, 4)),
nn.Parameter(torch.randn(5, 5)),
)
param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter(
torch.randn(4, 4)
)
data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
defaults = {"test": 3}
data_with_config = [
{
'name': 'param4', 'data': param4, 'config': {'test': 7}
},
{
'name': 'param5', 'data': param5, 'config': {'test': 8}
},
{"name": "param4", "data": param4, "config": {"test": 7}},
{"name": "param5", "data": param5, "config": {"test": 8}},
]
self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config)
self.run_all_checks(
data_list=data_list, defaults=defaults, data_with_config=data_with_config
)
def test_nn_embeddings(self):
emb1, emb2, = nn.Embedding(10, 3), nn.Embedding(20, 3)
(
emb1,
emb2,
) = nn.Embedding(
10, 3
), nn.Embedding(20, 3)
emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
data_list = [('emb1', emb1), ('emb1_bag', emb1_bag), ('emb2', emb2), ('emb2_bag', emb2_bag)]
defaults = {'test': 3}
data_list = [
("emb1", emb1),
("emb1_bag", emb1_bag),
("emb2", emb2),
("emb2_bag", emb2_bag),
]
defaults = {"test": 3}
data_with_config = [
{
'name': 'emb3', 'data': emb3, 'config': {'test': 7}
},
{
'name': 'emb3_bag', 'data': emb3_bag, 'config': {'test': 8}
},
{"name": "emb3", "data": emb3, "config": {"test": 7}},
{"name": "emb3_bag", "data": emb3_bag, "config": {"test": 8}},
]
self.run_all_checks(data_list=data_list, defaults=defaults, data_with_config=data_with_config)
self.run_all_checks(
data_list=data_list, defaults=defaults, data_with_config=data_with_config
)
class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
"""To add unit tests to support new data types for the NormDataSparsifier, create the following
data_list: List of tuples of name, data to be added to the constructor
defaults: default config for the above data in data_list
data_with_config: list of dictionaries defining name, data and config (see test_tensors())
Once the above is done, create an instance of _NormDataSparsifierTestRunner and call run_tests()
"""
def test_tensors(self):
tensor1, tensor2, tensor3 = torch.randn(1, 10), torch.randn(4, 4), torch.randn(1, 5)
tensor1, tensor2, tensor3 = (
torch.randn(1, 10),
torch.randn(4, 4),
torch.randn(1, 5),
)
tensor4, tensor5 = torch.randn(1, 2), torch.randn(4, 4)
data_list = [('tensor1', tensor1), ('tensor2', tensor2), ('tensor3', tensor3)]
defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4}
data_list = [("tensor1", tensor1), ("tensor2", tensor2), ("tensor3", tensor3)]
defaults = {
"sparsity_level": 0.5,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
}
data_with_config = [
{
'name': 'tensor4', 'data': tensor4,
'config': {'sparsity_level': 0.7, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
"name": "tensor4",
"data": tensor4,
"config": {
"sparsity_level": 0.7,
"sparse_block_shape": (2, 3),
"zeros_per_block": 6,
},
},
{
'name': 'tensor5', 'data': tensor5,
'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
"name": "tensor5",
"data": tensor5,
"config": {
"sparsity_level": 0.3,
"sparse_block_shape": (2, 3),
"zeros_per_block": 6,
},
},
]
self.run_all_checks(data_list=data_list, defaults=defaults,
data_with_config=data_with_config, norm_type='L1')
self.run_all_checks(data_list=data_list, defaults=defaults,
data_with_config=data_with_config, norm_type='L2')
self.run_all_checks(
data_list=data_list,
defaults=defaults,
data_with_config=data_with_config,
norm_type="L1",
)
self.run_all_checks(
data_list=data_list,
defaults=defaults,
data_with_config=data_with_config,
norm_type="L2",
)
def test_nn_parameters(self):
param1, param2, param3 = nn.Parameter(torch.randn(1, 8)), nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5))
param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(torch.randn(4, 4))
data_list = [('param1', param1), ('param2', param2), ('param3', param3)]
defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4}
param1, param2, param3 = (
nn.Parameter(torch.randn(1, 8)),
nn.Parameter(torch.randn(4, 4)),
nn.Parameter(torch.randn(5, 5)),
)
param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter(
torch.randn(4, 4)
)
data_list = [("param1", param1), ("param2", param2), ("param3", param3)]
defaults = {
"sparsity_level": 0.5,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
}
data_with_config = [
{
'name': 'param4', 'data': param4,
'config': {'sparsity_level': 0.7, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
"name": "param4",
"data": param4,
"config": {
"sparsity_level": 0.7,
"sparse_block_shape": (2, 3),
"zeros_per_block": 6,
},
},
{
'name': 'param5', 'data': param5,
'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
"name": "param5",
"data": param5,
"config": {
"sparsity_level": 0.3,
"sparse_block_shape": (2, 3),
"zeros_per_block": 6,
},
},
]
self.run_all_checks(data_list=data_list, defaults=defaults,
data_with_config=data_with_config, norm_type='L1')
self.run_all_checks(data_list=data_list, defaults=defaults,
data_with_config=data_with_config, norm_type='L2')
self.run_all_checks(
data_list=data_list,
defaults=defaults,
data_with_config=data_with_config,
norm_type="L1",
)
self.run_all_checks(
data_list=data_list,
defaults=defaults,
data_with_config=data_with_config,
norm_type="L2",
)
def test_nn_embeddings(self):
emb1, emb2, = nn.Embedding(10, 3), nn.Embedding(20, 3)
(
emb1,
emb2,
) = nn.Embedding(
10, 3
), nn.Embedding(20, 3)
emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
data_list = [('emb1', emb1), ('emb1_bag', emb1_bag), ('emb2', emb2), ('emb2_bag', emb2_bag)]
defaults = {'sparsity_level': 0.5, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4}
data_list = [
("emb1", emb1),
("emb1_bag", emb1_bag),
("emb2", emb2),
("emb2_bag", emb2_bag),
]
defaults = {
"sparsity_level": 0.5,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
}
data_with_config = [
{
'name': 'emb3', 'data': emb3,
'config': {'sparsity_level': 0.7, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
"name": "emb3",
"data": emb3,
"config": {
"sparsity_level": 0.7,
"sparse_block_shape": (2, 3),
"zeros_per_block": 6,
},
},
{
'name': 'emb3_bag', 'data': emb3_bag,
'config': {'sparsity_level': 0.3, 'sparse_block_shape': (2, 3), 'zeros_per_block': 6}
"name": "emb3_bag",
"data": emb3_bag,
"config": {
"sparsity_level": 0.3,
"sparse_block_shape": (2, 3),
"zeros_per_block": 6,
},
},
]
self.run_all_checks(data_list=data_list, defaults=defaults,
data_with_config=data_with_config, norm_type='L1')
self.run_all_checks(
data_list=data_list,
defaults=defaults,
data_with_config=data_with_config,
norm_type="L1",
)
self.run_all_checks(data_list=data_list, defaults=defaults,
data_with_config=data_with_config, norm_type='L2')
self.run_all_checks(
data_list=data_list,
defaults=defaults,
data_with_config=data_with_config,
norm_type="L2",
)
class Model(nn.Module):
@ -522,16 +710,21 @@ class TestQuantizationUtils(TestCase):
"""
model = Model()
sparse_config = {'sparsity_level': 0.80, 'sparse_block_shape': (1, 1)}
sparse_config = {"sparsity_level": 0.80, "sparse_block_shape": (1, 1)}
select_embeddings = [model.embbag1, model.emb1]
post_training_sparse_quantize(model,
data_sparsifier_class=DataNormSparsifier,
sparsify_first=True,
select_embeddings=select_embeddings,
**sparse_config)
post_training_sparse_quantize(
model,
data_sparsifier_class=DataNormSparsifier,
sparsify_first=True,
select_embeddings=select_embeddings,
**sparse_config,
)
assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert type(model.embbag1) == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
assert (
type(model.embbag1)
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert type(model.emb_seq[0] == nn.Embedding)
assert type(model.emb_seq[1] == nn.EmbeddingBag)
assert type(model.linear1) == nn.Linear
@ -561,24 +754,34 @@ class TestQuantizationUtils(TestCase):
"""
model = Model()
sparse_config = {'sparsity_level': 0.8, 'sparse_block_shape': (1, 1)}
post_training_sparse_quantize(model, DataNormSparsifier, sparsify_first=False, **sparse_config)
sparse_config = {"sparsity_level": 0.8, "sparse_block_shape": (1, 1)}
post_training_sparse_quantize(
model, DataNormSparsifier, sparsify_first=False, **sparse_config
)
assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert type(model.embbag1) == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
assert type(model.emb_seq[0]) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert type(model.emb_seq[1]) == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
assert (
type(model.embbag1)
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert (
type(model.emb_seq[0])
== torch.ao.nn.quantized.modules.embedding_ops.Embedding
)
assert (
type(model.emb_seq[1])
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert type(model.linear1) == nn.Linear # not quantized
assert type(model.linear2) == nn.Linear # not quantized
dequant_emb1 = torch.dequantize(model.emb1.weight())
dequant_embbag1 = torch.dequantize(model.embbag1.weight())
dequant_emb_seq_0 = torch.dequantize(model.emb_seq[0].weight())
dequant_emb_seq_1 = torch.dequantize(model.emb_seq[1].weight())
# higher threshold as quantization occurs before sparsity
threshold = 1  # zero points seem to have higher magnitude with sparsity occurring after
threshold = (
1  # zero points seem to have higher magnitude with sparsity occurring after
)
sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean()
sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean()
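As an aside for readers of this diff, the sparsity-level check used above can be isolated into a small runnable sketch (the helper name is ours, not part of the test suite): the measured level is simply the fraction of entries whose magnitude falls below a threshold.

import torch

def measured_sparsity(t: torch.Tensor, threshold: float = 1e-5) -> float:
    # Fraction of (near-)zero entries, mirroring the
    # (torch.abs(x) < threshold).float().mean() pattern above.
    return (t.abs() < threshold).float().mean().item()

w = torch.tensor([0.0, 0.0, 1.5, -2.0])
assert measured_sparsity(w) == 0.5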


@ -1,32 +1,34 @@
# Owner(s): ["module: unknown"]
from torch.testing._internal.common_utils import run_tests
import copy
import numpy as np
import io
import logging
from itertools import product
import numpy as np
import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao.pruning.sparsifier.utils import fqn_to_module
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
from torch.testing._internal.common_quantized import (
override_cpu_allocator_for_qnnpack,
override_qengines,
qengine_is_qnnpack,
qengine_is_fbgemm,
qengine_is_onednn,
qengine_is_qnnpack,
qengine_is_x86,
)
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
# TODO: Once more test files are created, move the contents to a ao folder.
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class TestQuantizedSparseKernels(TestCase):
@skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
@ -84,29 +86,42 @@ class TestQuantizedSparseKernels(TestCase):
continue
if use_channelwise:
W_q = torch.quantize_per_channel(
W_fp32, scales=W_scales, zero_points=W_zps, axis=0, dtype=torch.qint8
W_fp32,
scales=W_scales,
zero_points=W_zps,
axis=0,
dtype=torch.qint8,
)
else:
W_q = torch.quantize_per_tensor(
W_fp32, scale=W_scales[0], zero_point=W_zps[0], dtype=torch.qint8
W_fp32,
scale=W_scales[0],
zero_point=W_zps[0],
dtype=torch.qint8,
)
Y_scale = 1.1234
Y_zp = 5
W_prepack_dense = dense_prepack(W_q, float_bias)
W_prepack_sparse = sparse_prepack(W_q, float_bias, row_block_size, col_block_size)
W_prepack_sparse = sparse_prepack(
W_q, float_bias, row_block_size, col_block_size
)
if dynamic_mode:
Y = sparse_qlinear_dynamic(X_fp32, W_prepack_sparse)
Y_ref = dense_qlinear_dynamic(X_fp32, W_prepack_dense)
np.testing.assert_array_almost_equal(Y_ref.numpy(), Y.numpy(), decimal=decimal_val)
np.testing.assert_array_almost_equal(
Y_ref.numpy(), Y.numpy(), decimal=decimal_val
)
else:
Y_q = sparse_qlinear(X_q, W_prepack_sparse, Y_scale, Y_zp)
Y_q_ref = dense_qlinear(X_q, W_prepack_dense, Y_scale, Y_zp)
np.testing.assert_array_almost_equal(
Y_q_ref.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=decimal_val
Y_q_ref.int_repr().numpy(),
Y_q.int_repr().numpy(),
decimal=decimal_val,
)
@ -235,6 +250,7 @@ def _sparse_layer_test_helper(
Y_hat = sqmodel(X_q)
test_class.assertEqual(Y_ref.dequantize(), Y_hat.dequantize())
class SparseQuantizedModel(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
@ -243,6 +259,7 @@ class SparseQuantizedModel(nn.Module):
def forward(self, x):
return self.linear(x)
class TestQuantizedSparseLayers(TestCase):
@override_qengines
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")


@ -3,22 +3,24 @@
import logging
import torch
from torch import nn
from torch.ao.pruning.sparsifier import utils
from torch.nn.utils import parametrize
import torch
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class ModelUnderTest(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.linear = nn.Linear(16, 16, bias=bias)
self.seq = nn.Sequential(
nn.Linear(16, 16, bias=bias),
nn.Linear(16, 16, bias=bias)
nn.Linear(16, 16, bias=bias), nn.Linear(16, 16, bias=bias)
)
# Make sure the weights are not random
@ -45,7 +47,7 @@ class TestFakeSparsity(TestCase):
mask = torch.zeros(16, 16)
sparsity = utils.FakeSparsity(mask)
parametrize.register_parametrization(model, 'weight', sparsity)
parametrize.register_parametrization(model, "weight", sparsity)
x = torch.randn(3, 16)
self.assertEqual(torch.zeros(3, 16), model(x))
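A minimal standalone sketch of what this test exercises, assuming only that FakeSparsity multiplies the parametrized weight by its mask: with an all-zero mask, a bias-free linear layer maps every input to zero.

import torch
from torch import nn
from torch.ao.pruning.sparsifier import utils
from torch.nn.utils import parametrize

linear = nn.Linear(16, 16, bias=False)
mask = torch.zeros(16, 16)
parametrize.register_parametrization(linear, "weight", utils.FakeSparsity(mask))
assert torch.equal(linear(torch.randn(3, 16)), torch.zeros(3, 16))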
@ -53,95 +55,119 @@ class TestFakeSparsity(TestCase):
def test_weights_parametrized(self):
model = ModelUnderTest(bias=False)
assert not hasattr(model.linear, 'parametrizations')
assert not hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.seq[1], 'parametrizations')
assert not hasattr(model.linear, "parametrizations")
assert not hasattr(model.seq[0], "parametrizations")
assert not hasattr(model.seq[1], "parametrizations")
mask = torch.eye(16)
parametrize.register_parametrization(model.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[1], "weight", utils.FakeSparsity(mask)
)
assert hasattr(model.linear, 'parametrizations')
assert parametrize.is_parametrized(model.linear, 'weight')
assert hasattr(model.seq[0], 'parametrizations')
assert parametrize.is_parametrized(model.linear, 'weight')
assert hasattr(model.seq[1], 'parametrizations')
assert parametrize.is_parametrized(model.linear, 'weight')
assert hasattr(model.linear, "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")
assert hasattr(model.seq[0], "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")
assert hasattr(model.seq[1], "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")
def test_state_dict_preserved(self):
model_save = ModelUnderTest(bias=False)
mask = torch.eye(16)
parametrize.register_parametrization(model_save.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_save.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model_save.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_save.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model_save.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_save.seq[1], "weight", utils.FakeSparsity(mask)
)
state_dict = model_save.state_dict()
model_load = ModelUnderTest(bias=False)
mask = torch.zeros(model_load.linear.weight.shape)
parametrize.register_parametrization(model_load.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_load.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.zeros(model_load.seq[0].weight.shape)
parametrize.register_parametrization(model_load.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_load.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.zeros(model_load.seq[1].weight.shape)
parametrize.register_parametrization(model_load.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model_load.seq[1], "weight", utils.FakeSparsity(mask)
)
# Not strict, as we are not loading the 'mask'
model_load.load_state_dict(state_dict, strict=False)
# Check the parametrizations are preserved
assert hasattr(model_load.linear, 'parametrizations')
assert parametrize.is_parametrized(model_load.linear, 'weight')
assert hasattr(model_load.seq[0], 'parametrizations')
assert parametrize.is_parametrized(model_load.linear, 'weight')
assert hasattr(model_load.seq[1], 'parametrizations')
assert parametrize.is_parametrized(model_load.linear, 'weight')
assert hasattr(model_load.linear, "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")
assert hasattr(model_load.seq[0], "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")
assert hasattr(model_load.seq[1], "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")
# Check the weights are preserved
self.assertEqual(model_save.linear.parametrizations['weight'].original,
model_load.linear.parametrizations['weight'].original)
self.assertEqual(model_save.seq[0].parametrizations['weight'].original,
model_load.seq[0].parametrizations['weight'].original)
self.assertEqual(model_save.seq[1].parametrizations['weight'].original,
model_load.seq[1].parametrizations['weight'].original)
self.assertEqual(
model_save.linear.parametrizations["weight"].original,
model_load.linear.parametrizations["weight"].original,
)
self.assertEqual(
model_save.seq[0].parametrizations["weight"].original,
model_load.seq[0].parametrizations["weight"].original,
)
self.assertEqual(
model_save.seq[1].parametrizations["weight"].original,
model_load.seq[1].parametrizations["weight"].original,
)
# Check the masks are not preserved in the state_dict
# We store the state_dicts in the sparsifier, not in the model itself.
# TODO: Need to find a clean way of exporting the parametrized model
self.assertNotEqual(model_save.linear.parametrizations['weight'][0].mask,
model_load.linear.parametrizations['weight'][0].mask)
self.assertNotEqual(model_save.seq[0].parametrizations['weight'][0].mask,
model_load.seq[0].parametrizations['weight'][0].mask)
self.assertNotEqual(model_save.seq[1].parametrizations['weight'][0].mask,
model_load.seq[1].parametrizations['weight'][0].mask)
self.assertNotEqual(
model_save.linear.parametrizations["weight"][0].mask,
model_load.linear.parametrizations["weight"][0].mask,
)
self.assertNotEqual(
model_save.seq[0].parametrizations["weight"][0].mask,
model_load.seq[0].parametrizations["weight"][0].mask,
)
self.assertNotEqual(
model_save.seq[1].parametrizations["weight"][0].mask,
model_load.seq[1].parametrizations["weight"][0].mask,
)
def test_jit_trace(self):
model = ModelUnderTest(bias=False)
mask = torch.eye(16)
parametrize.register_parametrization(model.linear, 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[0], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(model.seq[1], 'weight',
utils.FakeSparsity(mask))
parametrize.register_parametrization(
model.seq[1], "weight", utils.FakeSparsity(mask)
)
# Tracing
example_x = torch.ones(3, 16)
model_trace = torch.jit.trace_module(model, {'forward': example_x})
model_trace = torch.jit.trace_module(model, {"forward": example_x})
x = torch.randn(3, 16)
y = model(x)


@ -2,29 +2,30 @@
# Owner(s): ["oncall: mobile"]
import tempfile
import torch
from torch.ao.nn.sparse.quantized.dynamic.linear import Linear
from torch.testing._internal.common_quantization import (
skipIfNoFBGEMM,
skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK
from torch.testing._internal.common_quantized import (
qengine_is_qnnpack,
override_cpu_allocator_for_qnnpack,
override_quantized_engine,
override_cpu_allocator_for_qnnpack
qengine_is_qnnpack,
)
from torch.testing._internal.common_utils import TestCase
class TestQlinearPackedParams(TestCase):
def qlinear_packed_params_test(self, allow_non_zero_zero_points=False):
# copied from https://pytorch.org/docs/stable/sparse.html#csr-tensor-operations,
# so row/col block indices match that example, but with blocks and
# scaled rows
weight_fp32 = torch.Tensor([
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
])
weight_fp32 = torch.Tensor(
[
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
)
row_block_size = 1
col_block_size = 4
@ -43,18 +44,14 @@ class TestQlinearPackedParams(TestCase):
wide_weight_fp32[1][0] = 8
per_tensor_small = (
torch.quantize_per_tensor(
weight_fp32,
scales[0],
zero_points[0],
dtype
),
torch.quantize_per_tensor(weight_fp32, scales[0], zero_points[0], dtype),
True,
[0, 1, 3, 3],
[2, 0, 1],
[x + (1 if allow_non_zero_zero_points else 0) for x in [
1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6
]],
[
x + (1 if allow_non_zero_zero_points else 0)
for x in [1, 1, 1, 1, 3, 3, 3, 3, 6, 6, 6, 6]
],
)
per_channel_small = (
@ -68,9 +65,10 @@ class TestQlinearPackedParams(TestCase):
False,
[0, 1, 3, 3],
[2, 0, 1],
[x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0) for (i, x) in enumerate([
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2
])],
[
x + ([1, 2, 2][i // 4] if allow_non_zero_zero_points else 0)
for (i, x) in enumerate([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2])
],
)
per_tensor_large = (
@ -83,14 +81,19 @@ class TestQlinearPackedParams(TestCase):
True,
[0, 2, 3, 3],
[0, 1001, 0],
[x + (1 if allow_non_zero_zero_points else 0) for x in [
2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0
]],
[
x + (1 if allow_non_zero_zero_points else 0)
for x in [2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]
],
)
for (weight, is_per_tensor_quantized, expected_row_block_indices, expected_col_block_indices, expected_weights) in [
per_tensor_small, per_channel_small, per_tensor_large
]:
for (
weight,
is_per_tensor_quantized,
expected_row_block_indices,
expected_col_block_indices,
expected_weights,
) in [per_tensor_small, per_channel_small, per_tensor_large]:
lin = Linear(
out_features=weight.shape[0],
in_features=weight.shape[1],
@ -118,24 +121,36 @@ class TestQlinearPackedParams(TestCase):
col_block_indices_,
weights_,
output_channels_,
input_channels_
input_channels_,
) = serialized[0]
# Test Serialization
self.assertEqual(bias_, bias)
self.assertEqual(out_features_block_size_, row_block_size)
self.assertEqual(in_features_block_size_, col_block_size)
self.assertEqual(weight_scales_, [scales[0]] if is_per_tensor_quantized else scales)
self.assertEqual(weight_zero_points_, [zero_points[0]] if is_per_tensor_quantized else zero_points)
self.assertEqual(
weight_scales_, [scales[0]] if is_per_tensor_quantized else scales
)
self.assertEqual(
weight_zero_points_,
[zero_points[0]] if is_per_tensor_quantized else zero_points,
)
self.assertEqual(quantization_scheme_, is_per_tensor_quantized)
self.assertEqual(row_block_indices_, expected_row_block_indices)
self.assertEqual(col_block_indices_, expected_col_block_indices)
self.assertEqual(weights_.tolist(), [v + 128 for v in expected_weights]) # weights are serialized as +128
self.assertEqual(
weights_.tolist(), [v + 128 for v in expected_weights]
) # weights are serialized as +128
self.assertEqual(output_channels_, weight.shape[0])
self.assertEqual(input_channels_, weight.shape[1])
# Test Unpacking
(weights_, bias_, out_features_block_size_, in_features_block_size_) = lin._weight_bias()
(
weights_,
bias_,
out_features_block_size_,
in_features_block_size_,
) = lin._weight_bias()
self.assertEqual(torch.dequantize(weights_), torch.dequantize(weight))
self.assertEqual(bias_, bias)
self.assertEqual(out_features_block_size_, row_block_size)
@ -148,7 +163,9 @@ class TestQlinearPackedParams(TestCase):
lin2 = torch.load(file_buff)
self.assertEqual(lin._weight_bias(), lin2._weight_bias())
# Serialize -> Deserialize -> Serialize should match Serialize
self.assertEqual(serialized, lin2._packed_params._packed_params.__getstate__())
self.assertEqual(
serialized, lin2._packed_params._packed_params.__getstate__()
)
# Test that op output is preserved by serialize -> deserialize
if qengine_is_qnnpack():
@ -157,29 +174,29 @@ class TestQlinearPackedParams(TestCase):
y2 = lin2(x)
self.assertEqual(y1, y2)
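The save/load round-trip pattern above, reduced to a standalone sketch with a plain nn.Linear standing in for the sparse quantized Linear:

import io

import torch
from torch import nn

lin = nn.Linear(4, 4)
buf = io.BytesIO()
torch.save(lin, buf)  # full-module pickle, as in the test
buf.seek(0)
lin2 = torch.load(buf, weights_only=False)  # needed to unpickle whole modules on newer releases
x = torch.randn(2, 4)
assert torch.equal(lin(x), lin2(x))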
@skipIfNoFBGEMM
def test_qlinear_packed_params_fbgemm(self):
torch.manual_seed(0)
with override_quantized_engine('fbgemm'):
with override_quantized_engine("fbgemm"):
self.qlinear_packed_params_test(allow_non_zero_zero_points=False)
@skipIfNoQNNPACK
def test_qlinear_packed_params_qnnpack(self):
torch.manual_seed(0)
with override_quantized_engine('qnnpack'):
with override_quantized_engine("qnnpack"):
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
self.qlinear_packed_params_test(allow_non_zero_zero_points=True)
def test_qlinear_packed_params_fbgemm_qnnpack_cross_compatibility(self):
torch.manual_seed(0)
weight_fp32 = torch.Tensor([
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
])
weight_fp32 = torch.Tensor(
[
[0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0],
[6, 6, 6, 6, 12, 12, 12, 12, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
)
row_block_size = 1
col_block_size = 4
@ -229,17 +246,23 @@ class TestQlinearPackedParams(TestCase):
def packed_params_data_with_int32_indices(data_as_state_and_weight_bias):
(st, weight_bias) = data_as_state_and_weight_bias
(s0, s1) = st
s0_updated = tuple([
# 7 and 8 are row and col block indices respectively
v if (i != 7 and i != 8) else v.to(torch.int32) for (i, v) in enumerate(list(s0))
])
s0_updated = tuple(
[
# 7 and 8 are row and col block indices respectively
v if (i != 7 and i != 8) else v.to(torch.int32)
for (i, v) in enumerate(list(s0))
]
)
return ((s0_updated, s1), weight_bias)
# Test Fbgemm -> Qnnpack
with override_quantized_engine('fbgemm'):
packed_params_data_1a, file_buff_1 = make_lin_get_state_weight_bias_and_save()
with override_quantized_engine("fbgemm"):
(
packed_params_data_1a,
file_buff_1,
) = make_lin_get_state_weight_bias_and_save()
with override_quantized_engine('qnnpack'):
with override_quantized_engine("qnnpack"):
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
packed_params_data_1b = load_get_state_weight_bias(file_buff_1)
@ -249,11 +272,14 @@ class TestQlinearPackedParams(TestCase):
)
# Test Qnnpack -> Fbgemm
with override_quantized_engine('qnnpack'):
with override_quantized_engine("qnnpack"):
with override_cpu_allocator_for_qnnpack(qengine_is_qnnpack()):
packed_params_data_2a, file_buff_2 = make_lin_get_state_weight_bias_and_save()
(
packed_params_data_2a,
file_buff_2,
) = make_lin_get_state_weight_bias_and_save()
with override_quantized_engine('fbgemm'):
with override_quantized_engine("fbgemm"):
packed_params_data_2b = load_get_state_weight_bias(file_buff_2)
self.assertEqual(


@ -1,43 +1,38 @@
# Owner(s): ["module: unknown"]
import warnings
from torch import nn
from torch.ao.pruning import WeightNormSparsifier
from torch.ao.pruning import BaseScheduler, LambdaSL, CubicSL
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
from torch.testing._internal.common_utils import TestCase
import warnings
class ImplementedScheduler(BaseScheduler):
def get_sl(self):
if self.last_epoch > 0:
return [group['sparsity_level'] * 0.5
for group in self.sparsifier.groups]
return [group["sparsity_level"] * 0.5 for group in self.sparsifier.groups]
else:
return list(self.base_sl)
class TestScheduler(TestCase):
def test_constructor(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)
assert scheduler.sparsifier is sparsifier
assert scheduler._step_count == 1
assert scheduler.base_sl == [sparsifier.groups[0]['sparsity_level']]
assert scheduler.base_sl == [sparsifier.groups[0]["sparsity_level"]]
def test_order_of_steps(self):
"""Checks if the warning is thrown if the scheduler step is called
before the sparsifier step"""
model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)
@ -54,43 +49,41 @@ class TestScheduler(TestCase):
# Make sure there is no warning related to the base_scheduler
for warning in w:
fname = warning.filename
fname = '/'.join(fname.split('/')[-5:])
assert fname != 'torch/ao/sparsity/scheduler/base_scheduler.py'
fname = "/".join(fname.split("/")[-5:])
assert fname != "torch/ao/sparsity/scheduler/base_scheduler.py"
def test_step(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
assert sparsifier.groups[0]['sparsity_level'] == 0.5
assert sparsifier.groups[0]["sparsity_level"] == 0.5
scheduler = ImplementedScheduler(sparsifier)
assert sparsifier.groups[0]['sparsity_level'] == 0.5
assert sparsifier.groups[0]["sparsity_level"] == 0.5
sparsifier.step()
scheduler.step()
assert sparsifier.groups[0]['sparsity_level'] == 0.25
assert sparsifier.groups[0]["sparsity_level"] == 0.25
def test_lambda_scheduler(self):
model = nn.Sequential(
nn.Linear(16, 16)
)
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
assert sparsifier.groups[0]['sparsity_level'] == 0.5
assert sparsifier.groups[0]["sparsity_level"] == 0.5
scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
assert sparsifier.groups[0]['sparsity_level'] == 0.0 # Epoch 0
assert sparsifier.groups[0]["sparsity_level"] == 0.0 # Epoch 0
scheduler.step()
assert sparsifier.groups[0]['sparsity_level'] == 5.0 # Epoch 1
assert sparsifier.groups[0]["sparsity_level"] == 5.0 # Epoch 1
class TestCubicScheduler(TestCase):
def setUp(self):
self.model_sparse_config = [
{'tensor_fqn': '0.weight', 'sparsity_level': 0.8},
{'tensor_fqn': '2.weight', 'sparsity_level': 0.4},
{"tensor_fqn": "0.weight", "sparsity_level": 0.8},
{"tensor_fqn": "2.weight", "sparsity_level": 0.4},
]
self.sorted_sparse_levels = [
conf["sparsity_level"] for conf in self.model_sparse_config
]
self.sorted_sparse_levels = [conf['sparsity_level'] for conf in self.model_sparse_config]
self.initial_sparsity = 0.1
self.initial_step = 3
@ -107,8 +100,8 @@ class TestCubicScheduler(TestCase):
sparsifier.prepare(model, config=self.model_sparse_config)
scheduler_args = {
'init_sl': self.initial_sparsity,
'init_t': self.initial_step,
"init_sl": self.initial_sparsity,
"init_t": self.initial_step,
}
scheduler_args.update(kwargs)
@ -118,25 +111,33 @@ class TestCubicScheduler(TestCase):
@staticmethod
def _get_sparsity_levels(sparsifier, precision=32):
r"""Gets the current levels of sparsity in a sparsifier."""
return [round(group['sparsity_level'], precision) for group in sparsifier.groups]
return [
round(group["sparsity_level"], precision) for group in sparsifier.groups
]
def test_constructor(self):
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
self.assertIs(
scheduler.sparsifier, sparsifier,
msg="Sparsifier is not properly attached")
scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached"
)
self.assertEqual(
scheduler._step_count, 1,
msg="Scheduler is initialized with incorrect step count")
scheduler._step_count,
1,
msg="Scheduler is initialized with incorrect step count",
)
self.assertEqual(
scheduler.base_sl, self.sorted_sparse_levels,
msg="Scheduler did not store the target sparsity levels correctly")
scheduler.base_sl,
self.sorted_sparse_levels,
msg="Scheduler did not store the target sparsity levels correctly",
)
# Value before t_0 is 0
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(0.0),
msg="Sparsifier is not reset correctly after attaching to the Scheduler",
)
# Value before t_0 is s_0
model = self._make_model()
@ -144,36 +145,50 @@ class TestCubicScheduler(TestCase):
self.assertEqual(
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
msg="Sparsifier is not reset correctly after attaching to the Scheduler",
)
def test_step(self):
# For n=5, dt=2, there are 10 steps in total between s_0 and s_f, starting from t_0
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5)
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5
)
scheduler.step()
scheduler.step()
self.assertEqual(scheduler._step_count, 3, msg="Scheduler step_count is expected to increment")
self.assertEqual(
scheduler._step_count,
3,
msg="Scheduler step_count is expected to increment",
)
# Value before t_0 is supposed to be 0
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
msg="Scheduler step updating the sparsity level before t_0")
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(0.0),
msg="Scheduler step updating the sparsity level before t_0",
)
scheduler.step() # Step = 3 => sparsity = initial_sparsity
self.assertEqual(
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset to initial sparsity at the first step")
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset to initial sparsity at the first step",
)
scheduler.step() # Step = 4 => sparsity ~ [0.3, 0.2]
self.assertEqual(
self._get_sparsity_levels(sparsifier, 1), [0.3, 0.2],
msg="Sparsity level is not set correctly after the first step")
self._get_sparsity_levels(sparsifier, 1),
[0.3, 0.2],
msg="Sparsity level is not set correctly after the first step",
)
current_step = scheduler._step_count - scheduler.init_t[0] - 1
more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
for _ in range(more_steps_needed): # More steps needed to final sparsity level
scheduler.step()
self.assertEqual(
self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels,
msg="Sparsity level is not reaching the target level afer delta_t * n steps ")
self._get_sparsity_levels(sparsifier),
self.sorted_sparse_levels,
msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
)
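For context, the values checked in this test are consistent with the cubic sparsity schedule of Zhu & Gupta (2017), which CubicSL appears to implement:

s_t = s_f + (s_0 - s_f) * (1 - (t - t_0) / (delta_t * n))^3,  for t_0 <= t <= t_0 + delta_t * n

With s_0 = 0.1, t_0 = 3, delta_t = 2, n = 5, the step at t = 4 yields a factor of (1 - 1/10)^3 = 0.729, giving 0.8 + (0.1 - 0.8) * 0.729 ≈ 0.29 and 0.4 + (0.1 - 0.4) * 0.729 ≈ 0.18, matching the asserted [0.3, 0.2] at one decimal place.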


@ -6,13 +6,24 @@ import re
import torch
from torch import nn
from torch.ao.pruning import BaseSparsifier, WeightNormSparsifier, FakeSparsity, NearlyDiagonalSparsifier
from torch.ao.pruning import (
BaseSparsifier,
FakeSparsity,
NearlyDiagonalSparsifier,
WeightNormSparsifier,
)
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import (
ImplementedSparsifier,
MockSparseLinear,
SimpleLinear,
)
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_pruning import SimpleLinear, MockSparseLinear, ImplementedSparsifier
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class TestBaseSparsifier(TestCase):
@ -27,41 +38,44 @@ class TestBaseSparsifier(TestCase):
sparsifier.step()
# Can instantiate the model with configs
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
assert len(sparsifier.groups) == 1
assert sparsifier.groups[0]['tensor_fqn'] == 'linear1.weight'
assert 'test' in sparsifier.groups[0]
assert sparsifier.groups[0]['test'] == 3
assert sparsifier.groups[0]["tensor_fqn"] == "linear1.weight"
assert "test" in sparsifier.groups[0]
assert sparsifier.groups[0]["test"] == 3
def test_prepare_config(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
# Make sure there are no parametrizations before `prepare`
assert not hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.linear1, 'parametrizations')
assert not hasattr(model.linear2, 'parametrizations')
sparsifier.prepare(model, config=[
{'tensor_fqn': 'seq.0.weight', 'test': 42},
# No 'linear1' to make sure it will be skipped in the sparsification
{'tensor_fqn': 'linear2.weight'}
])
assert not hasattr(model.seq[0], "parametrizations")
assert not hasattr(model.linear1, "parametrizations")
assert not hasattr(model.linear2, "parametrizations")
sparsifier.prepare(
model,
config=[
{"tensor_fqn": "seq.0.weight", "test": 42},
# No 'linear1' to make sure it will be skipped in the sparsification
{"tensor_fqn": "linear2.weight"},
],
)
assert len(sparsifier.groups) == 2
# Check if default argument is not assigned if explicit
assert sparsifier.groups[0]['tensor_fqn'] == 'seq.0.weight'
assert sparsifier.groups[0]['test'] == 42
assert sparsifier.groups[0]["tensor_fqn"] == "seq.0.weight"
assert sparsifier.groups[0]["test"] == 42
# Check if FQN and module are pointing to the same location
assert sparsifier.groups[1]['tensor_fqn'] == 'linear2.weight'
assert sparsifier.groups[1]['module'] == model.linear2
assert sparsifier.groups[1]["tensor_fqn"] == "linear2.weight"
assert sparsifier.groups[1]["module"] == model.linear2
# Check if parameterizations are attached
assert hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.linear1, 'parametrizations')
assert hasattr(model.linear2, 'parametrizations')
assert hasattr(model.seq[0], "parametrizations")
assert not hasattr(model.linear1, "parametrizations")
assert hasattr(model.linear2, "parametrizations")
def test_step(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.enable_mask_update = True
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
sparsifier.step()
assert torch.all(model.linear1.parametrizations.weight[0].mask[0] == 0)
@ -69,21 +83,21 @@ class TestBaseSparsifier(TestCase):
step_count = 3
model0 = SimpleLinear()
sparsifier0 = ImplementedSparsifier(test=3)
sparsifier0.prepare(model0, [{'tensor_fqn': 'linear1.weight'}])
mask = model0.linear1.parametrizations['weight'][0].mask
sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}])
mask = model0.linear1.parametrizations["weight"][0].mask
mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
for step in range(step_count):
sparsifier0.step()
state_dict = sparsifier0.state_dict()
# Check the expected keys in the state_dict
assert 'state' in state_dict
assert 'step_count' in state_dict['state']['linear1.weight']
assert state_dict['state']['linear1.weight']['step_count'] == 3
assert 'groups' in state_dict
assert 'test' in state_dict['groups'][0]
assert 'tensor_fqn' in state_dict['groups'][0]
assert state_dict['groups'][0]['tensor_fqn'] == 'linear1.weight'
assert "state" in state_dict
assert "step_count" in state_dict["state"]["linear1.weight"]
assert state_dict["state"]["linear1.weight"]["step_count"] == 3
assert "groups" in state_dict
assert "test" in state_dict["groups"][0]
assert "tensor_fqn" in state_dict["groups"][0]
assert state_dict["groups"][0]["tensor_fqn"] == "linear1.weight"
# Check loading static_dict creates an equivalent model
model1 = SimpleLinear()
@ -94,11 +108,11 @@ class TestBaseSparsifier(TestCase):
# Make sure the masks are different in the beginning
for mg in sparsifier0.groups:
if mg['tensor_fqn'] == 'linear1.weight':
mask0 = mg['module'].parametrizations.weight[0].mask
if mg["tensor_fqn"] == "linear1.weight":
mask0 = mg["module"].parametrizations.weight[0].mask
for mg in sparsifier1.groups:
if mg['tensor_fqn'] == 'linear1.weight':
mask1 = mg['module'].parametrizations.weight[0].mask
if mg["tensor_fqn"] == "linear1.weight":
mask1 = mg["module"].parametrizations.weight[0].mask
self.assertNotEqual(mask0, mask1)
sparsifier1.load_state_dict(state_dict)
@ -113,12 +127,12 @@ class TestBaseSparsifier(TestCase):
mg1 = sparsifier1.groups[idx]
for key in mg0.keys():
assert key in mg1
if key == 'module':
if key == "module":
# We cannot compare modules as they are different
param0 = mg0[key].parametrizations.weight[0]
param1 = mg1[key].parametrizations.weight[0]
assert hasattr(param0, 'mask')
assert hasattr(param1, 'mask')
assert hasattr(param0, "mask")
assert hasattr(param1, "mask")
self.assertEqual(param0.__dict__, param1.__dict__)
else:
assert mg0[key] == mg1[key]
@ -126,80 +140,84 @@ class TestBaseSparsifier(TestCase):
def test_convert(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
new_model = sparsifier.convert(model, mapping={nn.Linear: MockSparseLinear}, inplace=False)
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
new_model = sparsifier.convert(
model, mapping={nn.Linear: MockSparseLinear}, inplace=False
)
assert isinstance(new_model.linear1, MockSparseLinear)
assert isinstance(new_model.seq[0], nn.Linear)
assert isinstance(new_model.linear2, nn.Linear)
def test_mask_squash(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}])
assert hasattr(model.linear1.parametrizations.weight[0], 'mask')
assert is_parametrized(model.linear1, 'weight')
assert not is_parametrized(model.seq[0], 'weight')
sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
assert hasattr(model.linear1.parametrizations.weight[0], "mask")
assert is_parametrized(model.linear1, "weight")
assert not is_parametrized(model.seq[0], "weight")
sparsifier.squash_mask()
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")
def test_mask_squash_with_params1(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.prepare(
model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
)
sparsifier.squash_mask(
params_to_keep_per_layer={
'linear1': ('foo', 'bar'),
'seq.0': ('baz',)
})
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear1, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) is None
assert model.seq[0].sparse_params.get('bar', None) is None
assert model.seq[0].sparse_params.get('baz', None) == 1
assert model.linear1.sparse_params.get('foo', None) == 3
assert model.linear1.sparse_params.get('bar', None) == 2
assert model.linear1.sparse_params.get('baz', None) is None
params_to_keep_per_layer={"linear1": ("foo", "bar"), "seq.0": ("baz",)}
)
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")
assert hasattr(model.seq[0], "sparse_params")
assert hasattr(model.linear1, "sparse_params")
assert model.seq[0].sparse_params.get("foo", None) is None
assert model.seq[0].sparse_params.get("bar", None) is None
assert model.seq[0].sparse_params.get("baz", None) == 1
assert model.linear1.sparse_params.get("foo", None) == 3
assert model.linear1.sparse_params.get("bar", None) == 2
assert model.linear1.sparse_params.get("baz", None) is None
def test_mask_squash_with_params2(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear1, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) == 3
assert model.seq[0].sparse_params.get('bar', None) == 2
assert model.seq[0].sparse_params.get('baz', None) is None
assert model.linear1.sparse_params.get('foo', None) == 3
assert model.linear1.sparse_params.get('bar', None) == 2
assert model.linear1.sparse_params.get('baz', None) is None
sparsifier.prepare(
model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
)
sparsifier.squash_mask(params_to_keep=("foo", "bar"))
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")
assert hasattr(model.seq[0], "sparse_params")
assert hasattr(model.linear1, "sparse_params")
assert model.seq[0].sparse_params.get("foo", None) == 3
assert model.seq[0].sparse_params.get("bar", None) == 2
assert model.seq[0].sparse_params.get("baz", None) is None
assert model.linear1.sparse_params.get("foo", None) == 3
assert model.linear1.sparse_params.get("bar", None) == 2
assert model.linear1.sparse_params.get("baz", None) is None
def test_mask_squash_with_params3(self):
model = SimpleLinear()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear1.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.prepare(
model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
)
sparsifier.squash_mask(
params_to_keep=('foo', 'bar'),
params_to_keep_per_layer={'seq.0': ('baz',)})
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear1, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear1, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) == 3
assert model.seq[0].sparse_params.get('bar', None) == 2
assert model.seq[0].sparse_params.get('baz', None) == 1
assert model.linear1.sparse_params.get('foo', None) == 3
assert model.linear1.sparse_params.get('bar', None) == 2
assert model.linear1.sparse_params.get('baz', None) is None
params_to_keep=("foo", "bar"), params_to_keep_per_layer={"seq.0": ("baz",)}
)
assert not is_parametrized(model.seq[0], "weight")
assert not is_parametrized(model.linear1, "weight")
assert hasattr(model.seq[0], "sparse_params")
assert hasattr(model.linear1, "sparse_params")
assert model.seq[0].sparse_params.get("foo", None) == 3
assert model.seq[0].sparse_params.get("bar", None) == 2
assert model.seq[0].sparse_params.get("baz", None) == 1
assert model.linear1.sparse_params.get("foo", None) == 3
assert model.linear1.sparse_params.get("bar", None) == 2
assert model.linear1.sparse_params.get("baz", None) is None
class TestWeightNormSparsifier(TestCase):
@ -208,25 +226,33 @@ class TestWeightNormSparsifier(TestCase):
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
assert isinstance(g['module'], nn.Linear)
assert isinstance(g["module"], nn.Linear)
# The groups are unordered
assert g['module_fqn'] in ('seq.0', 'seq.1', 'seq.2', 'linear1', 'linear2')
assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")
def test_step(self):
model = SimpleLinear()
sparsifier = WeightNormSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear1.weight'}])
sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
for g in sparsifier.groups:
# Before step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) == 0 # checking sparsity level is 0
module = g["module"]
assert (
1.0 - module.parametrizations["weight"][0].mask.mean()
) == 0 # checking sparsity level is 0
sparsifier.enable_mask_update = True
sparsifier.step()
self.assertAlmostEqual(model.linear1.parametrizations['weight'][0].mask.mean().item(), 0.5, places=2)
self.assertAlmostEqual(
model.linear1.parametrizations["weight"][0].mask.mean().item(),
0.5,
places=2,
)
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level has increased
module = g["module"]
assert (
1.0 - module.parametrizations["weight"][0].mask.mean()
) > 0 # checking sparsity level has increased
# Test if the mask collapses to all zeros if the weights are randomized
iters_before_collapse = 1000
for _ in range(iters_before_collapse):
@ -234,25 +260,29 @@ class TestWeightNormSparsifier(TestCase):
sparsifier.step()
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level did not collapse
module = g["module"]
assert (
1.0 - module.parametrizations["weight"][0].mask.mean()
) > 0 # checking sparsity level did not collapse
def test_step_2_of_4(self):
model = SimpleLinear()
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
sparse_block_shape=(1, 4),
zeros_per_block=2)
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear1.weight'}])
sparsifier = WeightNormSparsifier(
sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
)
sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
sparsifier.step()
# make sure the sparsity level is approximately 50%
mask = model.linear1.parametrizations['weight'][0].mask.to(torch.float) # mean works on float only
mask = model.linear1.parametrizations["weight"][0].mask.to(
torch.float
) # mean works on float only
self.assertAlmostEqual(mask.mean().item(), 0.5, places=2)
# Make sure each block has exactly 50% zeros
module = sparsifier.groups[0]['module']
mask = module.parametrizations['weight'][0].mask
module = sparsifier.groups[0]["module"]
mask = module.parametrizations["weight"][0].mask
for row in mask:
for idx in range(0, len(row), 4):
block = row[idx:idx + 4]
block = row[idx : idx + 4]
block, _ = block.sort()
assert (block[:2] == 0).all()
assert (block[2:] != 0).all()
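The block property verified above, as a self-contained sketch: with zeros_per_block=2 and a (1, 4) block shape at full sparsity_level, every contiguous group of four mask entries holds exactly two zeros, so overall density is 50%.

import torch

mask = torch.tensor([[1.0, 0.0, 0.0, 1.0] * 4] * 3)  # a hypothetical 2-in-4 mask
for row in mask:
    for idx in range(0, len(row), 4):
        block = row[idx : idx + 4]
        assert (block == 0).sum() == 2  # exactly two zeros per block
assert mask.mean().item() == 0.5  # hence 50% sparsity overall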
@ -262,11 +292,11 @@ class TestWeightNormSparsifier(TestCase):
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
module = g['module']
module = g["module"]
# Check mask exists
assert hasattr(module.parametrizations['weight'][0], 'mask')
assert hasattr(module.parametrizations["weight"][0], "mask")
# Check parametrization exists and is correct
assert is_parametrized(module, 'weight')
assert is_parametrized(module, "weight")
assert type(module.parametrizations.weight[0]) == FakeSparsity
def test_mask_squash(self):
@ -275,39 +305,39 @@ class TestWeightNormSparsifier(TestCase):
sparsifier.prepare(model, config=None)
sparsifier.squash_mask()
for g in sparsifier.groups:
module = g['module']
assert not is_parametrized(module, 'weight')
assert not hasattr(module, 'mask')
module = g["module"]
assert not is_parametrized(module, "weight")
assert not hasattr(module, "mask")
def test_sparsity_levels(self):
sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
zeros_per_blocks = [0, 1, 2, 3, 4]
testcases = itertools.tee(itertools.product(sparsity_levels,
sparse_block_shapes,
zeros_per_blocks))
testcases = itertools.tee(
itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)
)
# Create a config and model with all the testcases
model = nn.Sequential()
sparsifier = WeightNormSparsifier()
sparsity_per_layer_config = []
p = re.compile(r'[-\.\s]')
p = re.compile(r"[-\.\s]")
for sl, sbs, zpb in testcases[0]:
# Make sure the number of zeros is not > values in a block
if zpb > sbs[0] * sbs[1]:
continue
layer_name = f'{sl}_{sbs}_{zpb}'
layer_name = p.sub('_', layer_name)
layer_name = f"{sl}_{sbs}_{zpb}"
layer_name = p.sub("_", layer_name)
layer = nn.Linear(12, 12, bias=False)
layer.weight = nn.Parameter(torch.ones(12, 12))
model.add_module(layer_name, layer)
config = {
'tensor_fqn': layer_name + ".weight",
'sparsity_level': sl,
'sparse_block_shape': sbs,
'zeros_per_block': zpb
"tensor_fqn": layer_name + ".weight",
"sparsity_level": sl,
"sparse_block_shape": sbs,
"zeros_per_block": zpb,
}
sparsity_per_layer_config.append(config)
@ -319,8 +349,8 @@ class TestWeightNormSparsifier(TestCase):
for sl, sbs, zpb in testcases[1]:
if zpb > sbs[0] * sbs[1]:
continue
layer_name = f'{sl}_{sbs}_{zpb}'
layer_name = p.sub('_', layer_name)
layer_name = f"{sl}_{sbs}_{zpb}"
layer_name = p.sub("_", layer_name)
layer = getattr(model, layer_name)
# Level of sparsity is achieved
@ -340,30 +370,34 @@ class TestNearlyDiagonalSparsifier(TestCase):
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
assert isinstance(g['module'], nn.Linear)
assert isinstance(g["module"], nn.Linear)
# The groups are unordered
assert g['module_fqn'] in ('seq.0', 'seq.1', 'seq.2', 'linear1', 'linear2')
assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")
def test_step(self):
model = SimpleLinear()
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear1.weight'}])
sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
for g in sparsifier.groups:
# Before step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) == 0 # checking sparsity level is 0
module = g["module"]
assert (
1.0 - module.parametrizations["weight"][0].mask.mean()
) == 0 # checking sparsity level is 0
sparsifier.enable_mask_update = True
sparsifier.step()
mask = module.parametrizations['weight'][0].mask
mask = module.parametrizations["weight"][0].mask
height, width = mask.shape
assert torch.all(mask == torch.eye(height, width))
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level has increased
module = g["module"]
assert (
1.0 - module.parametrizations["weight"][0].mask.mean()
) > 0 # checking sparsity level has increased
# Test if the mask collapses to all zeros if the weights are randomized
iters_before_collapse = 1000
@ -372,19 +406,21 @@ class TestNearlyDiagonalSparsifier(TestCase):
sparsifier.step()
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level did not collapse
module = g["module"]
assert (
1.0 - module.parametrizations["weight"][0].mask.mean()
) > 0 # checking sparsity level did not collapse
def test_prepare(self):
model = SimpleLinear()
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
module = g['module']
module = g["module"]
# Check mask exists
assert hasattr(module.parametrizations['weight'][0], 'mask')
assert hasattr(module.parametrizations["weight"][0], "mask")
# Check parametrization exists and is correct
assert is_parametrized(module, 'weight')
assert is_parametrized(module, "weight")
assert type(module.parametrizations.weight[0]) == FakeSparsity
def test_mask_squash(self):
@ -394,36 +430,36 @@ class TestNearlyDiagonalSparsifier(TestCase):
sparsifier.step()
sparsifier.squash_mask()
for g in sparsifier.groups:
module = g['module']
assert not is_parametrized(module, 'weight')
assert not hasattr(module, 'mask')
module = g["module"]
assert not is_parametrized(module, "weight")
assert not hasattr(module, "mask")
weights = module.weight
height, width = weights.shape
assert torch.all(weights == torch.eye(height, width) * weights) # only diagonal to be present
assert torch.all(
weights == torch.eye(height, width) * weights
) # only diagonal to be present
def test_sparsity_levels(self):
nearliness_levels = list(range(-1, 100))
model = nn.Sequential()
p = re.compile(r'[-\.\s]')
p = re.compile(r"[-\.\s]")
for nearliness in nearliness_levels:
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
layer_name = f'{nearliness}'
layer_name = p.sub('_', layer_name)
layer_name = f"{nearliness}"
layer_name = p.sub("_", layer_name)
layer = nn.Linear(32, 32, bias=False)
layer.weight = nn.Parameter(torch.ones(32, 32))
width, height = layer.weight.shape
model.add_module(layer_name, layer)
config = {
'tensor_fqn': layer_name + ".weight",
'nearliness': nearliness
}
config = {"tensor_fqn": layer_name + ".weight", "nearliness": nearliness}
sparsifier.prepare(model, [config])
# should raise a ValueError when nearliness arg is illegal
if (nearliness > 0 and nearliness % 2 == 0) or (nearliness // 2 >= min(width, height)):
if (nearliness > 0 and nearliness % 2 == 0) or (
nearliness // 2 >= min(width, height)
):
with self.assertRaises(ValueError):
sparsifier.step()
else:


@ -6,32 +6,32 @@ import random
import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import (
SaliencyPruner,
LSTMSaliencyPruner,
BaseStructuredSparsifier,
FakeStructuredSparsity,
FPGMPruner
FPGMPruner,
LSTMSaliencyPruner,
SaliencyPruner,
)
from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
from torch.testing._internal.common_pruning import (
SimpleLinear,
LinearBias,
LinearActivation,
LinearActivationFunctional,
SimpleConv2d,
Conv2dBias,
Conv2dActivation,
Conv2dBias,
Conv2dPadBias,
Conv2dPool,
Conv2dPoolFlatten,
Conv2dPoolFlattenFunctional,
LSTMLinearModel,
LinearActivation,
LinearActivationFunctional,
LinearBias,
LSTMLayerNormLinearModel,
LSTMLinearModel,
rows_are_subset,
SimpleConv2d,
SimpleLinear,
)
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@ -73,6 +73,7 @@ class BottomHalfLSTMPruner(BaseStructuredSparsifier):
new_mask = torch.cat(masks)
mask.data = new_mask.data
class TestSaliencyPruner(TestCase):
def test_saliency_pruner_update_mask(self):
"""Test that we prune out the row with the lowest saliency (first row)"""
@ -103,14 +104,9 @@ class TestSaliencyPruner(TestCase):
num_layers=1,
)
manual_weights = torch.Tensor([[1, 1],
[2, 2],
[2, 2],
[1, 1],
[-1, -1],
[-2, -2],
[-2, -2],
[-1, -1]])
manual_weights = torch.Tensor(
[[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]]
)
with torch.no_grad():
model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
@ -137,18 +133,12 @@ class TestSaliencyPruner(TestCase):
pruned_model(lstm_input)
# make sure lowest saliency rows are pruned
expected = torch.Tensor([[2, 2],
[2, 2],
[-2, -2],
[-2, -2]])
expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]])
pruned = model.lstm.weight_ih_l0
assert expected.shape == pruned.shape
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
expected = torch.Tensor([[2],
[2],
[-2],
[-2]])
expected = torch.Tensor([[2], [2], [-2], [-2]])
pruned = model.lstm.weight_hh_l0
assert expected.shape == pruned.shape
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
@ -159,7 +149,6 @@ class TestSaliencyPruner(TestCase):
assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()
class TestBaseStructuredSparsifier(TestCase):
def _check_pruner_prepared(self, model, pruner, device):
for config in pruner.groups:
@ -916,15 +905,19 @@ class TestBaseStructuredSparsifier(TestCase):
# linear columns correctly.
assert out_expected.shape == out_pruned.shape
class TestFPGMPruner(TestCase):
"""
Test case for the implementation of paper:
`Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
"""
class SimpleConvFPGM(nn.Module):
def __init__(self):
super().__init__()
self.conv2d1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False)
self.conv2d1 = nn.Conv2d(
in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
)
# Manually set the filter weights for demonstration purposes
"""
The three filters' weights are manually set to the values 3.0, 2.0, and 0.1.
@ -933,13 +926,19 @@ class TestFPGMPruner(TestCase):
"""
weights = torch.tensor([3.0, 2.0, 0.1])  # weights for each filter
weights = weights[:, None, None, None] # broadcasting
self.conv2d1.weight.data.copy_(torch.ones(self.conv2d1.weight.shape) * weights)
self.conv2d1.weight.data.copy_(
torch.ones(self.conv2d1.weight.shape) * weights
)
# Second Convolutional Layer
self.conv2d2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False)
self.conv2d2 = nn.Conv2d(
in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False
)
weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
weights = weights[:, None, None, None]
self.conv2d2.weight.data.copy_(torch.ones(self.conv2d2.weight.shape) * weights)
self.conv2d2.weight.data.copy_(
torch.ones(self.conv2d2.weight.shape) * weights
)
def forward(self, x):
x = self.conv2d1(x)
@ -953,11 +952,43 @@ class TestFPGMPruner(TestCase):
dist_conv1 = pruner._compute_distance(model.conv2d1.weight)
# compute the distance matrix using torch.cdist
flattened_filters = torch.Tensor([
[3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],
[2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000]
])
flattened_filters = torch.Tensor(
[
[
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
3.0000,
],
[
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
2.0000,
],
[
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
0.1000,
],
]
)
"""
Expected distance matrix should have the following values:
@ -967,9 +998,13 @@ class TestFPGMPruner(TestCase):
the distance should therefore be:
[11.7000, 8.7000, 14.4000]
"""
expected_dist_matrix_conv1 = torch.cdist(flattened_filters, flattened_filters, p=2)
expected_dist_matrix_conv1 = torch.cdist(
flattened_filters, flattened_filters, p=2
)
expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
assert torch.isclose(dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07).all()
assert torch.isclose(
dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07
).all()
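A worked check of the distances used here: each flattened filter is constant across its nine entries, so the L2 distance between two filters with values a and b is sqrt(9 * (a - b)^2) = 3 * |a - b|. That gives pairwise distances 3.0, 8.7, and 5.7, and row sums [11.7, 8.7, 14.4]:

import torch

vals = torch.tensor([3.0, 2.0, 0.1])
filters = vals[:, None].repeat(1, 9)  # three constant 9-element filters
dist = torch.cdist(filters, filters, p=2).abs().sum(1)
assert torch.allclose(dist, torch.tensor([11.7, 8.7, 14.4]))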
def _test_update_mask_on_single_layer(self, expected_conv1, device):
"""Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value"""
@ -981,8 +1016,10 @@ class TestFPGMPruner(TestCase):
pruner.prepare(model, config)
pruner.enable_mask_update = True
pruner.step()
assert pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() is not False, \
"do not prune the least-norm filter"
assert (
pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item()
is not False
), "do not prune the least-norm filter"
# fusion step
pruned_model = pruner.prune()
@ -992,27 +1029,38 @@ class TestFPGMPruner(TestCase):
expected_conv1 = expected_conv1.to(device)
assert pruned_y.shape == (1, 4, 32, 32)
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
assert pruned_model.conv2d2.weight.shape == (4, 2, 3, 3), "conv2d2 should have input channel pruned"
assert pruned_model.conv2d2.weight.shape == (
4,
2,
3,
3,
), "conv2d2 should have input channel pruned"
# assert value
assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()
assert torch.isclose(
pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
).all()
def _test_update_mask_on_multiple_layer(self, expected_conv1, expected_conv2, device):
def _test_update_mask_on_multiple_layer(
self, expected_conv1, expected_conv2, device
):
# the second setting
model = TestFPGMPruner.SimpleConvFPGM().to(device)
x = torch.ones((1, 1, 32, 32), device=device)
pruner = FPGMPruner(0.3)
config = [
{"tensor_fqn": "conv2d1.weight"},
{"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5}
{"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
]
pruner.prepare(model, config)
pruner.enable_mask_update = True
pruner.step()
# Get the masks for the two least-norm filters
mask1 = pruner.groups[0]['module'].parametrizations.weight[0].mask[-1]
mask2 = pruner.groups[0]['module'].parametrizations.weight[0].mask[-2]
mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
# Check if either of the least-norm filters is not pruned
assert mask1.item() is not False or mask2.item() is not False, "Do not prune all least-norm filters"
assert (
mask1.item() is not False or mask2.item() is not False
), "Do not prune all least-norm filters"
# fusion step
pruned_model = pruner.prune()
@ -1024,8 +1072,12 @@ class TestFPGMPruner(TestCase):
assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
# assert values
assert torch.isclose(pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07).all()
assert torch.isclose(pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07).all()
assert torch.isclose(
pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
).all()
assert torch.isclose(
pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07
).all()
def test_update_mask(self):
weights = torch.tensor([3.0, 0.1])
@ -1036,4 +1088,6 @@ class TestFPGMPruner(TestCase):
for device in DEVICES:
self._test_update_mask_on_single_layer(expected_conv1, device)
self._test_update_mask_on_multiple_layer(expected_conv1, expected_conv2, device)
self._test_update_mask_on_multiple_layer(
expected_conv1, expected_conv2, device
)
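As a standalone illustration of the distance check above (plain torch only; not part of the test file), the three constant 3x3 filters give pairwise L2 distances of 3 * |w_i - w_j|, and the row sums reproduce the [11.7, 8.7, 14.4] expectation from the docstring:

import torch

w = torch.tensor([3.0, 2.0, 0.1])
flat = w[:, None].repeat(1, 9)       # three flattened, constant 3x3 filters
dist = torch.cdist(flat, flat, p=2)  # d(i, j) = 3 * |w_i - w_j|
print(dist.abs().sum(dim=1))         # tensor([11.7000,  8.7000, 14.4000])
# The 2.0 filter minimizes the total distance (it sits nearest the geometric
# median), so FPGM prunes it rather than the least-norm 0.1 filter.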

View File

@ -2,7 +2,7 @@
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck
from torch.testing._internal.common_utils import gradcheck, run_tests, TestCase
class TestAutogradComplex(TestCase):
@ -71,7 +71,9 @@ class TestAutogradComplex(TestCase):
# modified inplace
res = x1.unbind(0)
with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
with self.assertRaisesRegex(
RuntimeError, "output of a function that returns multiple views"
):
res[0] += torch.rand(2, requires_grad=True)
x.requires_grad_(True)
@ -80,7 +82,9 @@ class TestAutogradComplex(TestCase):
# modified inplace
res = x1.unbind(0)
with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
with self.assertRaisesRegex(
RuntimeError, "output of a function that returns multiple views"
):
res[0] += torch.rand(2, requires_grad=True)
def as_identity(self):
@ -101,5 +105,5 @@ class TestAutogradComplex(TestCase):
self.assertEqual(z.grad, z1.grad)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()
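A minimal standalone reproduction of the behavior guarded above (stock torch, no test harness): in-place writes to one output of a multi-view op such as unbind() are rejected by autograd.

import torch

x = torch.rand(2, 2, requires_grad=True).clone()
views = x.unbind(0)  # several views sharing one storage
try:
    views[0] += 1.0  # in-place write through one of the views
except RuntimeError as e:
    print(e)  # "... output of a function that returns multiple views ..."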

File diff suppressed because it is too large

View File

@ -1,11 +1,13 @@
# Owner(s): ["module: intel"]
from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX
import shutil
import subprocess
import tempfile
import unittest
from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase
@unittest.skipIf(not IS_LINUX, "Only works on linux")
class TestTorchrun(TestCase):
def setUp(self):
@ -37,6 +39,7 @@ class TestTorchrun(TestCase):
15,7,1,1
"""
from torch.backends.xeon.run_cpu import _CPUinfo
cpuinfo = _CPUinfo(lscpu_info)
assert cpuinfo._physical_core_nums() == 8
assert cpuinfo._logical_core_nums() == 16
@ -45,21 +48,43 @@ class TestTorchrun(TestCase):
assert cpuinfo.get_node_logical_cores(0) == [0, 1, 2, 3, 8, 9, 10, 11]
assert cpuinfo.get_node_logical_cores(1) == [4, 5, 6, 7, 12, 13, 14, 15]
assert cpuinfo.get_all_physical_cores() == [0, 1, 2, 3, 4, 5, 6, 7]
assert cpuinfo.get_all_logical_cores() == [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]
assert cpuinfo.get_all_logical_cores() == [
0,
1,
2,
3,
8,
9,
10,
11,
4,
5,
6,
7,
12,
13,
14,
15,
]
assert cpuinfo.numa_aware_check([0, 1, 2, 3]) == [0]
assert cpuinfo.numa_aware_check([4, 5, 6, 7]) == [1]
assert cpuinfo.numa_aware_check([2, 3, 4, 5]) == [0, 1]
def test_multi_threads(self):
num = 0
with subprocess.Popen(f"python -m torch.backends.xeon.run_cpu --ninstances 4 --use-default-allocator \
with subprocess.Popen(
f"python -m torch.backends.xeon.run_cpu --ninstances 4 --use-default-allocator \
--disable-iomp --disable-numactl --disable-taskset --log-path {self._test_dir} --no-python pwd",
shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as p:
for line in p.stdout.readlines():
segs = str(line, "utf-8").strip().split("-")
if segs[-1].strip() == "pwd":
num += 1
assert num == 4, "Failed to launch multiple instances for inference"
if __name__ == "__main__":
run_tests()
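For orientation, a rough sketch of the bookkeeping behind those assertions, using the same CPU,Core,Socket,Node field order as the fake lscpu fixture (the helper name below is illustrative, not the run_cpu API):

node_logical = {}
for line in ["0,0,0,0", "1,1,0,0", "4,4,1,1", "5,5,1,1"]:
    cpu, _core, _socket, node = (int(v) for v in line.split(","))
    node_logical.setdefault(node, []).append(cpu)

def numa_nodes_of(cores):  # mirrors numa_aware_check: nodes a core set touches
    return sorted({n for n, cpus in node_logical.items() if set(cores) & set(cpus)})

print(numa_nodes_of([0, 1]))  # [0]
print(numa_nodes_of([1, 4]))  # [0, 1], i.e. the set spans two NUMA nodes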

View File

@ -6,19 +6,26 @@ import os
import re
import textwrap
import timeit
from typing import Any, List, Tuple
import unittest
from typing import Any, List, Tuple
import torch
import torch.utils.benchmark as benchmark_utils
from torch.testing._internal.common_utils import TestCase, run_tests, IS_SANDCASTLE, IS_WINDOWS, slowTest, TEST_WITH_ASAN
import expecttest
import numpy as np
import torch
import torch.utils.benchmark as benchmark_utils
from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
slowTest,
TEST_WITH_ASAN,
TestCase,
)
CALLGRIND_ARTIFACTS: str = os.path.join(
os.path.split(os.path.abspath(__file__))[0],
"callgrind_artifacts.json"
os.path.split(os.path.abspath(__file__))[0], "callgrind_artifacts.json"
)
@ -33,13 +40,13 @@ def generate_callgrind_artifacts() -> None:
"""
print("Regenerating callgrind artifact.")
stats_no_data = benchmark_utils.Timer(
"y = torch.ones(())"
).collect_callgrind(number=1000)
stats_no_data = benchmark_utils.Timer("y = torch.ones(())").collect_callgrind(
number=1000
)
stats_with_data = benchmark_utils.Timer(
"y = torch.ones((1,))"
).collect_callgrind(number=1000)
stats_with_data = benchmark_utils.Timer("y = torch.ones((1,))").collect_callgrind(
number=1000
)
user = os.getenv("USER")
@ -59,7 +66,9 @@ def generate_callgrind_artifacts() -> None:
json.dump(artifacts, f, indent=4)
def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]:
def load_callgrind_artifacts() -> (
Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
):
"""Hermetic artifact to unit test Callgrind wrapper.
In addition to collecting counts, this wrapper provides some facilities for
@ -76,8 +85,7 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
pattern = re.compile(r"^\s*([0-9]+)\s(.+)$")
def to_function_counts(
count_strings: List[str],
inclusive: bool
count_strings: List[str], inclusive: bool
) -> benchmark_utils.FunctionCounts:
data: List[benchmark_utils.FunctionCount] = []
for cs in count_strings:
@ -89,8 +97,8 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
data.append(benchmark_utils.FunctionCount(count=int(c), function=fn))
return benchmark_utils.FunctionCounts(
tuple(sorted(data, reverse=True)),
inclusive=inclusive)
tuple(sorted(data, reverse=True)), inclusive=inclusive
)
baseline_inclusive = to_function_counts(artifacts["baseline_inclusive"], True)
baseline_exclusive = to_function_counts(artifacts["baseline_exclusive"], False)
@ -101,8 +109,12 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
built_with_debug_symbols=True,
baseline_inclusive_stats=baseline_inclusive,
baseline_exclusive_stats=baseline_exclusive,
stmt_inclusive_stats=to_function_counts(artifacts["ones_no_data_inclusive"], True),
stmt_exclusive_stats=to_function_counts(artifacts["ones_no_data_exclusive"], False),
stmt_inclusive_stats=to_function_counts(
artifacts["ones_no_data_inclusive"], True
),
stmt_exclusive_stats=to_function_counts(
artifacts["ones_no_data_exclusive"], False
),
stmt_callgrind_out=None,
)
@ -112,8 +124,12 @@ def load_callgrind_artifacts() -> Tuple[benchmark_utils.CallgrindStats, benchmar
built_with_debug_symbols=True,
baseline_inclusive_stats=baseline_inclusive,
baseline_exclusive_stats=baseline_exclusive,
stmt_inclusive_stats=to_function_counts(artifacts["ones_with_data_inclusive"], True),
stmt_exclusive_stats=to_function_counts(artifacts["ones_with_data_exclusive"], False),
stmt_inclusive_stats=to_function_counts(
artifacts["ones_with_data_inclusive"], True
),
stmt_exclusive_stats=to_function_counts(
artifacts["ones_with_data_exclusive"], False
),
stmt_callgrind_out=None,
)
@ -127,14 +143,12 @@ class MyModule(torch.nn.Module):
class TestBenchmarkUtils(TestCase):
def regularizeAndAssertExpectedInline(
self, x: Any,
expect: str,
indent: int = 12
self, x: Any, expect: str, indent: int = 12
) -> None:
x_str: str = re.sub(
"object at 0x[0-9a-fA-F]+>",
"object at 0xXXXXXXXXXXXX>",
x if isinstance(x, str) else repr(x)
x if isinstance(x, str) else repr(x),
)
if "\n" in x_str:
# Indent makes the reference align at the call site.
@ -157,15 +171,19 @@ class TestBenchmarkUtils(TestCase):
median = timer.adaptive_autorange(threshold=0.5).median
# Test that multi-line statements work properly.
median = benchmark_utils.Timer(
stmt="""
median = (
benchmark_utils.Timer(
stmt="""
with torch.no_grad():
y = x + 1""",
setup="""
setup="""
x = torch.ones((1,), requires_grad=True)
for _ in range(5):
x = x + 1.0""",
).timeit(5).median
)
.timeit(5)
.median
)
self.assertIsInstance(sample, float)
@slowTest
@ -173,7 +191,7 @@ class TestBenchmarkUtils(TestCase):
@unittest.skipIf(True, "Failing on clang, see 74398")
def test_timer_tiny_fast_snippet(self):
timer = benchmark_utils.Timer(
'auto x = 1;(void)x;',
"auto x = 1;(void)x;",
timer=timeit.default_timer,
language=benchmark_utils.Language.CPP,
)
@ -222,16 +240,16 @@ class TestBenchmarkUtils(TestCase):
return max(self._random_state.normal(mean, mean * noise_level), 5e-9)
def timeit(self, number):
return sum([
# First timer invocation
self.sample(self._timer_cost, self._timer_noise_level),
# Stmt body
self.sample(self._mean_cost * number, self._function_noise_level),
# Second timer invocation
self.sample(self._timer_cost, self._timer_noise_level),
])
return sum(
[
# First timer invocation
self.sample(self._timer_cost, self._timer_noise_level),
# Stmt body
self.sample(self._mean_cost * number, self._function_noise_level),
# Second timer invocation
self.sample(self._timer_cost, self._timer_noise_level),
]
)
def test_adaptive_timer(self):
class MockTimer(benchmark_utils.Timer):
@ -245,7 +263,6 @@ class TestBenchmarkUtils(TestCase):
_function_costs = (
self._MockTimer._function_costs[0],
self._MockTimer._function_costs[1],
# GPU should be faster once there is enough work.
("expensive_fn()", 5e-6),
)
@ -261,7 +278,7 @@ class TestBenchmarkUtils(TestCase):
pass
Median: 7.98 ns
IQR: 0.52 ns (7.74 to 8.26)
125 measurements, 10000000 runs per measurement, 1 thread"""
125 measurements, 10000000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -271,7 +288,7 @@ class TestBenchmarkUtils(TestCase):
pass
Median: 7.86 ns
IQR: 0.71 ns (7.63 to 8.34)
6 measurements, 1000000 runs per measurement, 1 thread"""
6 measurements, 1000000 runs per measurement, 1 thread""",
)
# Check against strings so we can reuse expect infra.
@ -287,7 +304,7 @@ class TestBenchmarkUtils(TestCase):
cheap_fn()
Median: 3.98 us
IQR: 0.27 us (3.85 to 4.12)
252 measurements, 10000 runs per measurement, 1 thread"""
252 measurements, 10000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -297,7 +314,7 @@ class TestBenchmarkUtils(TestCase):
cheap_fn()
Median: 4.16 us
IQR: 0.22 us (4.04 to 4.26)
4 measurements, 1000 runs per measurement, 1 thread"""
4 measurements, 1000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -307,7 +324,7 @@ class TestBenchmarkUtils(TestCase):
expensive_fn()
Median: 19.97 us
IQR: 1.35 us (19.31 to 20.65)
501 measurements, 1000 runs per measurement, 1 thread"""
501 measurements, 1000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -317,7 +334,7 @@ class TestBenchmarkUtils(TestCase):
expensive_fn()
Median: 20.79 us
IQR: 1.09 us (20.20 to 21.29)
4 measurements, 1000 runs per measurement, 1 thread"""
4 measurements, 1000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -327,7 +344,7 @@ class TestBenchmarkUtils(TestCase):
pass
Median: 7.92 ns
IQR: 0.43 ns (7.75 to 8.17)
13 measurements, 100000000 runs per measurement, 1 thread"""
13 measurements, 100000000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -337,7 +354,7 @@ class TestBenchmarkUtils(TestCase):
pass
Median: 7.75 ns
IQR: 0.57 ns (7.56 to 8.13)
4 measurements, 10000000 runs per measurement, 1 thread"""
4 measurements, 10000000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -347,7 +364,7 @@ class TestBenchmarkUtils(TestCase):
cheap_fn()
Median: 4.04 us
IQR: 0.30 us (3.90 to 4.19)
25 measurements, 100000 runs per measurement, 1 thread"""
25 measurements, 100000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -357,7 +374,7 @@ class TestBenchmarkUtils(TestCase):
cheap_fn()
Median: 4.09 us
IQR: 0.38 us (3.90 to 4.28)
4 measurements, 100000 runs per measurement, 1 thread"""
4 measurements, 100000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -367,7 +384,7 @@ class TestBenchmarkUtils(TestCase):
expensive_fn()
Median: 4.98 us
IQR: 0.31 us (4.83 to 5.13)
20 measurements, 100000 runs per measurement, 1 thread"""
20 measurements, 100000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -377,7 +394,7 @@ class TestBenchmarkUtils(TestCase):
expensive_fn()
Median: 5.01 us
IQR: 0.28 us (4.87 to 5.15)
4 measurements, 10000 runs per measurement, 1 thread"""
4 measurements, 10000 runs per measurement, 1 thread""",
)
# Make sure __repr__ is reasonable for
@ -398,7 +415,7 @@ class TestBenchmarkUtils(TestCase):
Median: 10.06 us
IQR: 0.54 us (9.73 to 10.27)
20 measurements, 1000 runs per measurement, 1 thread"""
20 measurements, 1000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -411,7 +428,7 @@ class TestBenchmarkUtils(TestCase):
Median: 10.06 us
IQR: 0.54 us (9.73 to 10.27)
20 measurements, 1000 runs per measurement, 1 thread"""
20 measurements, 1000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -425,7 +442,7 @@ class TestBenchmarkUtils(TestCase):
x + 1 (no grad): scalar_add
Median: 10.06 us
IQR: 0.54 us (9.73 to 10.27)
20 measurements, 1000 runs per measurement, 1 thread"""
20 measurements, 1000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -443,7 +460,7 @@ class TestBenchmarkUtils(TestCase):
setup: setup_fn()
Median: 10.06 us
IQR: 0.54 us (9.73 to 10.27)
20 measurements, 1000 runs per measurement, 1 thread"""
20 measurements, 1000 runs per measurement, 1 thread""",
)
self.regularizeAndAssertExpectedInline(
@ -471,7 +488,7 @@ class TestBenchmarkUtils(TestCase):
Median: 10.06 us
IQR: 0.54 us (9.73 to 10.27)
20 measurements, 1000 runs per measurement, 16 threads"""
20 measurements, 1000 runs per measurement, 16 threads""",
)
@slowTest
@ -482,24 +499,23 @@ class TestBenchmarkUtils(TestCase):
with self.assertRaisesRegex(
ValueError,
r"`collect_callgrind` requires that globals be wrapped "
r"in `CopyIfCallgrind` so that serialization is explicit."
r"in `CopyIfCallgrind` so that serialization is explicit.",
):
benchmark_utils.Timer(
"pass",
globals={"x": 1}
).collect_callgrind(collect_baseline=False)
benchmark_utils.Timer("pass", globals={"x": 1}).collect_callgrind(
collect_baseline=False
)
with self.assertRaisesRegex(
# Subprocess raises AttributeError (from pickle),
# _ValgrindWrapper re-raises as generic OSError.
OSError, "AttributeError: Can't get attribute 'MyModule'"
OSError,
"AttributeError: Can't get attribute 'MyModule'",
):
benchmark_utils.Timer(
"model(1)",
globals={"model": benchmark_utils.CopyIfCallgrind(MyModule())}
globals={"model": benchmark_utils.CopyIfCallgrind(MyModule())},
).collect_callgrind(collect_baseline=False)
@torch.jit.script
def add_one(x):
return x + 1
@ -516,9 +532,9 @@ class TestBenchmarkUtils(TestCase):
import sys
sys.path.append({repr(os.path.split(os.path.abspath(__file__))[0])})
from test_benchmark_utils import MyModule
"""
)
}
""",
),
},
)
stats = timer.collect_callgrind(number=1000)
@ -538,13 +554,22 @@ class TestBenchmarkUtils(TestCase):
assert isinstance(stats, tuple)
# Check that the repeats are at least somewhat repeatable. (within 10 instructions per iter)
counts = collections.Counter([s.counts(denoise=True) // 10_000 * 10_000 for s in stats])
self.assertGreater(max(counts.values()), 1, f"Every instruction count total was unique: {counts}")
counts = collections.Counter(
[s.counts(denoise=True) // 10_000 * 10_000 for s in stats]
)
self.assertGreater(
max(counts.values()),
1,
f"Every instruction count total was unique: {counts}",
)
from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import (
wrapper_singleton,
)
from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import wrapper_singleton
self.assertIsNone(
wrapper_singleton()._bindings_module,
"JIT'd bindings are only for back testing."
"JIT'd bindings are only for back testing.",
)
@slowTest
@ -558,28 +583,29 @@ class TestBenchmarkUtils(TestCase):
timer=timeit.default_timer,
language="c++",
)
stats = [
timer.collect_callgrind()
for _ in range(3)
]
stats = [timer.collect_callgrind() for _ in range(3)]
counts = [s.counts() for s in stats]
self.assertGreater(
min(counts), 0, "No stats were collected")
self.assertGreater(min(counts), 0, "No stats were collected")
self.assertEqual(
min(counts), max(counts), "C++ Callgrind should be deterministic")
min(counts), max(counts), "C++ Callgrind should be deterministic"
)
for s in stats:
self.assertEqual(
s.counts(denoise=True), s.counts(denoise=False),
"De-noising should not apply to C++.")
s.counts(denoise=True),
s.counts(denoise=False),
"De-noising should not apply to C++.",
)
stats = timer.collect_callgrind(number=1000, repeats=20)
assert isinstance(stats, tuple)
# NB: Unlike the example above, there is no expectation that all
# repeats will be identical.
counts = collections.Counter([s.counts(denoise=True) // 10_000 * 10_000 for s in stats])
counts = collections.Counter(
[s.counts(denoise=True) // 10_000 * 10_000 for s in stats]
)
self.assertGreater(max(counts.values()), 1, repr(counts))
def test_manipulate_callgrind_stats(self):
@ -587,7 +613,8 @@ class TestBenchmarkUtils(TestCase):
# Mock `torch.set_printoptions(linewidth=160)`
wide_linewidth = benchmark_utils.FunctionCounts(
stats_no_data.stats(inclusive=False)._data, False, _linewidth=160)
stats_no_data.stats(inclusive=False)._data, False, _linewidth=160
)
for l in repr(wide_linewidth).splitlines(keepends=False):
self.assertLessEqual(len(l), 160)
@ -595,10 +622,12 @@ class TestBenchmarkUtils(TestCase):
self.assertEqual(
# `delta` is just a convenience method.
stats_with_data.delta(stats_no_data)._data,
(stats_with_data.stats() - stats_no_data.stats())._data
(stats_with_data.stats() - stats_no_data.stats())._data,
)
deltas = stats_with_data.as_standardized().delta(stats_no_data.as_standardized())
deltas = stats_with_data.as_standardized().delta(
stats_no_data.as_standardized()
)
def custom_transforms(fn: str):
fn = re.sub(re.escape("/usr/include/c++/8/bits/"), "", fn)
@ -703,7 +732,7 @@ class TestBenchmarkUtils(TestCase):
2000 /usr/include/c++/8/bits/atomic_base.h:at::Tensor at::detail::make_tensor ... t_null_type<c10::StorageImpl> >&&, c10::DispatchKey&&, caffe2::TypeMeta&)
2000 /usr/include/c++/8/array:at::Tensor& c10::Dispatcher::callWithDispatchKe ... , c10::Scalar)> const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const
Total: 8869966""" # noqa: B950
Total: 8869966""", # noqa: B950
)
self.regularizeAndAssertExpectedInline(
@ -821,10 +850,8 @@ class TestBenchmarkUtils(TestCase):
costs = (
# overhead_optimized_fn()
(1e-6, 1e-9),
# compute_optimized_fn()
(3e-6, 5e-10),
# special_case_fn() [square inputs only]
(1e-6, 4e-10),
)
@ -840,8 +867,7 @@ class TestBenchmarkUtils(TestCase):
# overhead_optimized_fn()
class _MockTimer_0(self._MockTimer):
_function_costs = tuple(
(f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j)
for i, j in sizes
(f"fn({i}, {j})", costs[0][0] + costs[0][1] * i * j) for i, j in sizes
)
class MockTimer_0(benchmark_utils.Timer):
@ -850,8 +876,7 @@ class TestBenchmarkUtils(TestCase):
# compute_optimized_fn()
class _MockTimer_1(self._MockTimer):
_function_costs = tuple(
(f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j)
for i, j in sizes
(f"fn({i}, {j})", costs[1][0] + costs[1][1] * i * j) for i, j in sizes
)
class MockTimer_1(benchmark_utils.Timer):
@ -861,7 +886,8 @@ class TestBenchmarkUtils(TestCase):
class _MockTimer_2(self._MockTimer):
_function_costs = tuple(
(f"fn({i}, {j})", costs[2][0] + costs[2][1] * i * j)
for i, j in sizes if i == j
for i, j in sizes
if i == j
)
class MockTimer_2(benchmark_utils.Timer):
@ -913,7 +939,7 @@ class TestBenchmarkUtils(TestCase):
compute_optimized | 3.1 | 4.0 | 11.2 | 2099.3 | 2099.3
special_case (square) | 1.1 | | 7.5 | | 1674.7
Times are in microseconds (us)."""
Times are in microseconds (us).""",
)
compare.trim_significant_figures()
@ -927,7 +953,7 @@ class TestBenchmarkUtils(TestCase):
compute_optimized | 3 | 4.0 | 11 | 2100 | 2100
special_case (square) | 1 | | 8 | | 1700
Times are in microseconds (us)."""
Times are in microseconds (us).""",
)
compare.colorize()
@ -978,12 +1004,16 @@ class TestBenchmarkUtils(TestCase):
self.assertEqual(columnwise_colored_actual, columnwise_colored_expected)
self.assertEqual(rowwise_colored_actual, rowwise_colored_expected)
@unittest.skipIf(IS_WINDOWS and os.getenv("VC_YEAR") == "2019", "Random seed only accepts int32")
@unittest.skipIf(
IS_WINDOWS and os.getenv("VC_YEAR") == "2019", "Random seed only accepts int32"
)
def test_fuzzer(self):
fuzzer = benchmark_utils.Fuzzer(
parameters=[
benchmark_utils.FuzzedParameter(
"n", minval=1, maxval=16, distribution="loguniform")],
"n", minval=1, maxval=16, distribution="loguniform"
)
],
tensors=[benchmark_utils.FuzzedTensor("x", size=("n",))],
seed=0,
)
@ -995,9 +1025,8 @@ class TestBenchmarkUtils(TestCase):
for i, (tensors, _, _) in enumerate(fuzzer.take(2)):
x = tensors["x"]
self.assertEqual(
x, torch.tensor(expected_results[i]), rtol=1e-3, atol=1e-3)
self.assertEqual(x, torch.tensor(expected_results[i]), rtol=1e-3, atol=1e-3)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()
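A minimal Timer round-trip of the kind these expect-tests pin down (stock torch.utils.benchmark API; the numbers vary by machine, which is why the tests above mock the timer):

import torch
import torch.utils.benchmark as benchmark_utils

timer = benchmark_utils.Timer(
    stmt="y = x + 1",
    setup="x = torch.ones((1,))",
    label="scalar_add",
)
m = timer.blocked_autorange(min_run_time=0.2)
print(m.median)  # median seconds per run, the figure the reprs above report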

View File

@ -1,14 +1,15 @@
# Owner(s): ["module: unknown"]
import argparse
import torch
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required args. Raises error if they aren't passed.
parser.add_argument('--foo', help='foo', required=True)
parser.add_argument('--bar', help='bar', required=True)
parser.add_argument("--foo", help="foo", required=True)
parser.add_argument("--bar", help="bar", required=True)
_ = parser.parse_args()
x = torch.ones((3, 3), requires_grad=True)
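For context, this fixture is meant to be driven through the bottleneck profiler, which forwards everything after the script path to the script itself; a typical invocation would be:

python -m torch.utils.bottleneck test/bottleneck_test/test_args.py --foo f --bar b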

View File

@ -25,5 +25,5 @@ def main():
optimizer.step()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -1,25 +1,24 @@
from _pytest.junitxml import LogXML, _NodeReporter, bin_xml_escape
from _pytest.terminal import _get_raw_skip_reason
from _pytest.stash import StashKey
from _pytest.reports import TestReport
from _pytest.config.argparsing import Parser
from _pytest.config import filename_arg
from _pytest.config import Config
from _pytest._code.code import ReprFileLocation
from _pytest.python import Module
from typing import Any, List, Union
from typing import Optional
from types import MethodType
import xml.etree.ElementTree as ET
import functools
import pytest
import sys
import os
import copy
import functools
import json
import os
import re
import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from pytest_shard_custom import PytestShardPlugin, pytest_addoptions as shard_addoptions
from types import MethodType
from typing import Any, List, Optional, Union
import pytest
from _pytest._code.code import ReprFileLocation
from _pytest.config import Config, filename_arg
from _pytest.config.argparsing import Parser
from _pytest.junitxml import _NodeReporter, bin_xml_escape, LogXML
from _pytest.python import Module
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from _pytest.terminal import _get_raw_skip_reason
from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin
# a lot of this file is copied from _pytest.junitxml and modified to get rerun info
@ -42,7 +41,7 @@ def pytest_addoption(parser: Parser) -> None:
dest="stepcurrent",
)
parser.addoption("--use-main-module", action='store_true')
parser.addoption("--use-main-module", action="store_true")
group = parser.getgroup("terminal reporting")
group.addoption(
"--junit-xml-reruns",
@ -143,11 +142,14 @@ class _NodeReporterReruns(_NodeReporter):
skipreason = skipreason[9:]
details = f"{filename}:{lineno}: {skipreason}"
skipped = ET.Element("skipped", type="pytest.skip", message=bin_xml_escape(skipreason))
skipped = ET.Element(
"skipped", type="pytest.skip", message=bin_xml_escape(skipreason)
)
skipped.text = bin_xml_escape(details)
self.append(skipped)
self.write_captured_output(report)
class LogXMLReruns(LogXML):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -223,7 +225,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
def pytest_pycollect_makemodule(module_path, path, parent) -> Module:
if parent.config.getoption("--use-main-module"):
mod = Module.from_parent(parent, path=module_path)
mod._getobj = MethodType(lambda x: sys.modules['__main__'], mod)
mod._getobj = MethodType(lambda x: sys.modules["__main__"], mod)
return mod
@ -275,7 +277,10 @@ def pytest_collection_modifyitems(items: List[Any]) -> None:
test_name = item.name
test_class = item.parent.name
if test_class not in disabled_tests or test_name not in disabled_tests[test_class]:
if (
test_class not in disabled_tests
or test_name not in disabled_tests[test_class]
):
continue
cpy = copy.copy(item)

View File

@ -1,10 +1,10 @@
import torch
from torch._export import aot_compile
from torch.export import Dim
torch.manual_seed(1337)
class Net(torch.nn.Module):
def __init__(self, device):
super().__init__()
@ -17,6 +17,7 @@ class Net(torch.nn.Module):
w = w_relu + self.w_add
return torch.matmul(x, w)
class NetWithTensorConstants(torch.nn.Module):
def __init__(self):
super().__init__()
@ -26,9 +27,11 @@ class NetWithTensorConstants(torch.nn.Module):
z = self.w * x * y
return z[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17]]
data = {}
data_with_tensor_constants = {}
# Basic AOTI model test generation.
def generate_basic_tests():
for device in ["cpu", "cuda"]:
@ -49,18 +52,24 @@ def generate_basic_tests():
model,
(x,),
dynamic_shapes=dynamic_shapes,
options={"aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding})
options={
"aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
},
)
suffix = f"{device}"
if use_runtime_constant_folding:
suffix += "_use_runtime_constant_folding"
data.update({
f"model_so_path_{suffix}": model_so_path,
f"inputs_{suffix}": [x],
f"outputs_{suffix}": [ref_output],
f"w_pre_{suffix}": model.w_pre,
f"w_add_{suffix}": model.w_add,
})
data.update(
{
f"model_so_path_{suffix}": model_so_path,
f"inputs_{suffix}": [x],
f"outputs_{suffix}": [ref_output],
f"w_pre_{suffix}": model.w_pre,
f"w_add_{suffix}": model.w_add,
}
)
# AOTI model which will create additional tensors during autograd.
def generate_test_with_additional_tensors():
@ -72,20 +81,22 @@ def generate_test_with_additional_tensors():
torch._dynamo.reset()
with torch.no_grad():
model_so_path = aot_compile(
model,
(x, y))
model_so_path = aot_compile(model, (x, y))
data_with_tensor_constants.update(
{
"model_so_path": model_so_path,
"inputs": [x, y],
"outputs": [ref_output],
"w": model.w,
}
)
data_with_tensor_constants.update({
"model_so_path": model_so_path,
"inputs": [x, y],
"outputs": [ref_output],
"w": model.w,
})
generate_basic_tests()
generate_test_with_additional_tensors()
# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
def __init__(self, data):
@ -93,5 +104,8 @@ class Serializer(torch.nn.Module):
for key in data:
setattr(self, key, data[key])
torch.jit.script(Serializer(data)).save("data.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save("data_with_tensor_constants.pt")
torch.jit.script(Serializer(data_with_tensor_constants)).save(
"data_with_tensor_constants.pt"
)
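Since the archives are written with torch.jit.script, they can be inspected back from Python as well; a quick sketch (the attribute name assumes the cpu variant was generated):

import torch

data = torch.jit.load("data.pt")
x = data.inputs_cpu[0]  # scripted attributes round-trip on load
print(x.shape)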

View File

@ -1,6 +1,7 @@
"""Script to generate baseline values from PyTorch initialization algorithms"""
import sys
import torch
HEADER = """
@ -19,13 +20,13 @@ INITIALIZERS = {
"Xavier_Uniform": lambda w: torch.nn.init.xavier_uniform(w),
"Xavier_Normal": lambda w: torch.nn.init.xavier_normal(w),
"Kaiming_Normal": lambda w: torch.nn.init.kaiming_normal(w),
"Kaiming_Uniform": lambda w: torch.nn.init.kaiming_uniform(w)
"Kaiming_Uniform": lambda w: torch.nn.init.kaiming_uniform(w),
}
def emit(initializer_parameter_map):
# Don't write generated with an @ in front, else this file is recognized as generated.
print("// @{} from {}".format('generated', __file__))
print("// @{} from {}".format("generated", __file__))
print(HEADER)
for initializer_name, weights in initializer_parameter_map.items():
print(PARAMETERS.format(initializer_name))
@ -63,10 +64,11 @@ def run(initializer):
def main():
initializer_parameter_map = {}
for initializer in INITIALIZERS.keys():
sys.stderr.write(f'Evaluating {initializer} ...\n')
sys.stderr.write(f"Evaluating {initializer} ...\n")
initializer_parameter_map[initializer] = run(initializer)
emit(initializer_parameter_map)
if __name__ == "__main__":
main()
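As a sanity check on what these baselines encode: kaiming_uniform draws from U(-b, b) with b = gain * sqrt(3 / fan_in) and gain = sqrt(2) for the default nonlinearity, so every sample should respect that bound (a sketch using the in-place torch.nn.init variant):

import math

import torch

w = torch.empty(38, 35)
torch.nn.init.kaiming_uniform_(w)  # defaults: fan_in mode, a=0
bound = math.sqrt(2.0) * math.sqrt(3.0 / w.shape[1])
assert w.abs().max().item() <= bound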

View File

@ -21,27 +21,43 @@ FOOTER = "} // namespace expected_parameters"
PARAMETERS = "inline std::vector<std::vector<torch::Tensor>> {}() {{"
OPTIMIZERS = {
"LBFGS" : lambda p: torch.optim.LBFGS(p, 1.0),
"LBFGS_with_line_search" : lambda p: torch.optim.LBFGS(p, 1.0, line_search_fn="strong_wolfe"),
"LBFGS": lambda p: torch.optim.LBFGS(p, 1.0),
"LBFGS_with_line_search": lambda p: torch.optim.LBFGS(
p, 1.0, line_search_fn="strong_wolfe"
),
"Adam": lambda p: torch.optim.Adam(p, 1.0),
"Adam_with_weight_decay": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-2),
"Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-6, amsgrad=True),
"Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(
p, 1.0, weight_decay=1e-6, amsgrad=True
),
"AdamW": lambda p: torch.optim.AdamW(p, 1.0),
"AdamW_without_weight_decay": lambda p: torch.optim.AdamW(p, 1.0, weight_decay=0),
"AdamW_with_amsgrad": lambda p: torch.optim.AdamW(p, 1.0, amsgrad=True),
"Adagrad": lambda p: torch.optim.Adagrad(p, 1.0),
"Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-2),
"Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(p, 1.0, weight_decay=1e-6, lr_decay=1e-3),
"Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(
p, 1.0, weight_decay=1e-2
),
"Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(
p, 1.0, weight_decay=1e-6, lr_decay=1e-3
),
"RMSprop": lambda p: torch.optim.RMSprop(p, 0.1),
"RMSprop_with_weight_decay": lambda p: torch.optim.RMSprop(p, 0.1, weight_decay=1e-2),
"RMSprop_with_weight_decay_and_centered": lambda p: torch.optim.RMSprop(p, 0.1, weight_decay=1e-6, centered=True),
"RMSprop_with_weight_decay_and_centered_and_momentum":
lambda p: torch.optim.RMSprop(p, 0.1, weight_decay=1e-6, centered=True, momentum=0.9),
"RMSprop_with_weight_decay": lambda p: torch.optim.RMSprop(
p, 0.1, weight_decay=1e-2
),
"RMSprop_with_weight_decay_and_centered": lambda p: torch.optim.RMSprop(
p, 0.1, weight_decay=1e-6, centered=True
),
"RMSprop_with_weight_decay_and_centered_and_momentum": lambda p: torch.optim.RMSprop(
p, 0.1, weight_decay=1e-6, centered=True, momentum=0.9
),
"SGD": lambda p: torch.optim.SGD(p, 0.1),
"SGD_with_weight_decay": lambda p: torch.optim.SGD(p, 0.1, weight_decay=1e-2),
"SGD_with_weight_decay_and_momentum": lambda p: torch.optim.SGD(p, 0.1, momentum=0.9, weight_decay=1e-2),
"SGD_with_weight_decay_and_nesterov_momentum":
lambda p: torch.optim.SGD(p, 0.1, momentum=0.9, weight_decay=1e-6, nesterov=True),
"SGD_with_weight_decay_and_momentum": lambda p: torch.optim.SGD(
p, 0.1, momentum=0.9, weight_decay=1e-2
),
"SGD_with_weight_decay_and_nesterov_momentum": lambda p: torch.optim.SGD(
p, 0.1, momentum=0.9, weight_decay=1e-6, nesterov=True
),
}
@ -75,11 +91,11 @@ def run(optimizer_name, iterations, sample_every):
loss.backward()
def closure():
return torch.tensor([10.])
return torch.tensor([10.0])
optimizer.step(closure)
if i % sample_every == 0:
values.append(
[p.clone().flatten().data.numpy() for p in model.parameters()]
)
@ -89,7 +105,7 @@ def run(optimizer_name, iterations, sample_every):
def emit(optimizer_parameter_map):
# Don't write generated with an @ in front, else this file is recognized as generated.
print("// @{} from {}".format('generated', __file__))
print("// @{} from {}".format("generated", __file__))
print(HEADER)
for optimizer_name, parameters in optimizer_parameter_map.items():
print(PARAMETERS.format(optimizer_name))
@ -115,7 +131,7 @@ def main():
optimizer_parameter_map = {}
for optimizer in OPTIMIZERS.keys():
sys.stderr.write(f'Evaluating {optimizer} ...\n')
sys.stderr.write(f"Evaluating {optimizer} ...\n")
optimizer_parameter_map[optimizer] = run(
optimizer, options.iterations, options.sample_every
)
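For reference when reading the generated expectations, the first step of the plain momentum entry above works out by hand as follows (a sketch of the standard update rule, not of torch internals):

import torch

p, grad = torch.tensor([1.0]), torch.tensor([0.5])
lr, momentum, weight_decay = 0.1, 0.9, 1e-2

g = grad + weight_decay * p  # L2 weight decay folds into the gradient
buf = g.clone()              # the momentum buffer starts at g on step one
p = p - lr * buf
print(p)                     # tensor([0.9490])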

View File

@ -1,5 +1,6 @@
import sys
import os
import sys
import torch
@ -21,7 +22,7 @@ class FileSetup:
class EvalModeForLoadedModule(FileSetup):
path = 'dropout_model.pt'
path = "dropout_model.pt"
def setup(self):
class Model(torch.jit.ScriptModule):
@ -40,7 +41,7 @@ class EvalModeForLoadedModule(FileSetup):
class SerializationInterop(FileSetup):
path = 'ivalue.pt'
path = "ivalue.pt"
def setup(self):
ones = torch.ones(2, 2)
@ -53,7 +54,7 @@ class SerializationInterop(FileSetup):
# See testTorchSaveError in test/cpp/jit/tests.h for usage
class TorchSaveError(FileSetup):
path = 'eager_value.pt'
path = "eager_value.pt"
def setup(self):
ones = torch.ones(2, 2)
@ -63,8 +64,9 @@ class TorchSaveError(FileSetup):
torch.save(value, self.path, _use_new_zipfile_serialization=False)
class TorchSaveJitStream_CUDA(FileSetup):
path = 'saved_stream_model.pt'
path = "saved_stream_model.pt"
def setup(self):
if not torch.cuda.is_available():
@ -77,7 +79,9 @@ class TorchSaveJitStream_CUDA(FileSetup):
b = torch.rand(3, 4, device="cuda")
with torch.cuda.stream(s):
is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
is_stream_s = (
torch.cuda.current_stream(s.device_index()).id() == s.id()
)
c = torch.cat((a, b), 0).to("cuda")
s.synchronize()
return is_stream_s, a, b, c
@ -93,9 +97,10 @@ tests = [
EvalModeForLoadedModule(),
SerializationInterop(),
TorchSaveError(),
TorchSaveJitStream_CUDA()
TorchSaveJitStream_CUDA(),
]
def setup():
for test in tests:
test.setup()
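The CUDA fixture above scripts a stream round-trip; a standalone eager sketch of the same pattern (guarded, since it needs a GPU):

import torch

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    a = torch.rand(3, 4, device="cuda")
    with torch.cuda.stream(s):
        b = torch.cat((a, a), 0)  # work queued on stream s
    s.synchronize()               # wait before the host touches b
    print(b.shape)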