# Owner(s): ["module: unknown"]
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
from torch.ao.pruning import BaseSparsifier


class ImplementedSparsifier(BaseSparsifier):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(defaults=kwargs)

    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None:
        # Zero out the first row of the weight mask, then count how many
        # times the sparsifier has stepped on 'linear1.weight'.
        module.parametrizations.weight[0].mask[0] = 0  # type: ignore[index, union-attr]
        linear_state = self.state['linear1.weight']
        linear_state['step_count'] = linear_state.get('step_count', 0) + 1


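# Illustrative usage (a sketch, not part of the fixtures): BaseSparsifier
# subclasses are driven through prepare() and step(). The kwarg and config
# below are assumptions chosen to match update_mask above, which tracks
# state for 'linear1.weight'.
#
#     model = SimpleLinear()  # defined later in this file
#     sparsifier = ImplementedSparsifier(sparsity_level=0.5)
#     sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
#     sparsifier.step()  # invokes update_mask on the configured tensor

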
class MockSparseLinear(nn.Linear):
    """
    MockSparseLinear is used to check convert functionality.
    It is the same as a normal Linear layer, except with a different type, as
    well as an additional from_dense method.
    """

    @classmethod
    def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear':
        """
        Create a MockSparseLinear with the same dimensions as the given dense
        nn.Linear module. Weights are not copied; only the shape carries over.
        """
        linear = cls(mod.in_features,
                     mod.out_features)
        return linear


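# Convert sketch (an assumption about how this mock is exercised, not a
# definitive API reference): something like
#
#     sparsifier.convert(model, mapping={nn.Linear: MockSparseLinear})
#
# would swap eligible nn.Linear modules for MockSparseLinear instances
# created via from_dense.

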
def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool:
    """
    Checks whether every row of subset_tensor appears, in order, among the
    rows of superset_tensor.
    """
    i = 0
    for row in subset_tensor:
        # Scan forward through the superset until this row is found.
        while i < len(superset_tensor):
            if not torch.equal(row, superset_tensor[i]):
                i += 1
            else:
                break
        else:
            # The superset was exhausted without finding the row.
            return False
    return True


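# Example (illustrative): the rows of `a` appear, in order, within `b`.
#
#     a = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
#     b = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#     rows_are_subset(a, b)  # True
#     rows_are_subset(b, a)  # False: [3., 4.] is not a row of `a`

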
class SimpleLinear(nn.Module):
    r"""Model with only Linear layers without biases, some wrapped in a Sequential,
    some following the Sequential. Used to test basic pruned Linear-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=False),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 4, bias=False),
        )
        self.linear1 = nn.Linear(4, 4, bias=False)
        self.linear2 = nn.Linear(4, 10, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


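# Shape sketch (illustrative): SimpleLinear maps a trailing feature dimension
# of 7 to 10, e.g. SimpleLinear()(torch.randn(16, 7)) has shape (16, 10).

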
class LinearBias(nn.Module):
    r"""Model with only Linear layers, alternating layers with biases,
    wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 3, bias=True),
            nn.Linear(3, 3, bias=True),
            nn.Linear(3, 10, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        return x


class LinearActivation(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules between each Linear in the Sequential, and after each
    outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.Tanh(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(3, 10, bias=False)
        self.act2 = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        x = self.act2(x)
        return x


class LinearActivationFunctional(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules between each Linear in the Sequential, and functional
    activations called after each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.ReLU(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.linear2 = nn.Linear(3, 8, bias=False)
        self.linear3 = nn.Linear(8, 10, bias=False)
        self.act1 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        return x


class SimpleConv2d(nn.Module):
    r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
    Used to test pruned Conv2d-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=False),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


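# Shape sketch (illustrative): each 3x3 convolution without padding trims 2
# pixels per spatial dim, so SimpleConv2d()(torch.randn(1, 1, 28, 28)) has
# shape (1, 52, 20, 20).

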
class Conv2dBias(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
    Used to test pruned Conv2d-Bias-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.Conv2d(32, 32, 3, 1, bias=True),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dActivation(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
    Activation function modules between each layer in the Sequential, functional
    activations called between the outside layers.
    Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
            nn.Conv2d(64, 64, 3, 1, bias=False),
            nn.ReLU(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.relu(x)
        x = self.conv2d2(x)
        x = F.hardtanh(x)
        return x


class Conv2dPadBias(nn.Module):
    r"""Model with only Conv2d layers, all with bias and some with padding > 0,
    some in a Sequential and some following. Activation function modules between each layer.
    Used to test that bias is propagated correctly in the special case of
    pruned Conv2d-Bias-(Activation-)Conv2d fusion, where the second Conv2d layer has padding > 0."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
        self.act1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
        self.act2 = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.act1(x)
        x = self.conv2d2(x)
        x = self.act2(x)
        return x


class Conv2dPool(nn.Module):
    r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
    Activation function and Pool2d modules between each layer.
    Used to test pruned Conv2d-Pool2d-Conv2d fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
        self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.maxpool(x)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = F.relu(x)
        x = self.conv2d3(x)
        return x


class Conv2dPoolFlattenFunctional(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following,
    then a Pool2d and a functional flatten followed by a Linear layer.
    Activation functions and Pool2d modules sit between the layers as well.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(11, 13, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # test functional flatten
        x = self.fc(x)
        return x


class Conv2dPoolFlatten(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following,
    then a Pool2d and a Flatten module followed by a Linear layer.
    Activation functions and Pool2d modules sit between the layers as well.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(44, 13, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


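# Shape note (illustrative): in Conv2dPoolFlatten, AdaptiveAvgPool2d((2, 2))
# over the 11 output channels of conv2d2 yields (N, 11, 2, 2), which
# nn.Flatten collapses to 11 * 2 * 2 = 44 features, matching self.fc.

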
class LSTMLinearModel(nn.Module):
    """Container module with a recurrent (LSTM) module followed by a linear decoder."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        output, _hidden = self.lstm(input)
        decoded = self.linear(output)
        return decoded, output


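# Shape sketch (illustrative): for LSTMLinearModel(8, 16, 4, 1) and input of
# shape (seq_len, batch, 8), `decoded` has shape (seq_len, batch, 4) and the
# raw LSTM `output` has shape (seq_len, batch, 16).

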
class LSTMLayerNormLinearModel(nn.Module):
    """Container module with an LSTM, a LayerNorm, and a linear."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        x, state = self.lstm(x)
        x = self.norm(x)
        x = self.linear(x)
        return x, state