# Owner(s): ["module: dynamo"]
|
|
|
|
import collections
|
|
import itertools
|
|
import traceback
|
|
import types
|
|
import unittest
|
|
from copy import deepcopy
|
|
from functools import partial
|
|
from typing import Tuple
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
import torch.nn.functional as F
|
|
from torch._dynamo.eval_frame import unsupported
|
|
from torch._dynamo.mutation_guard import GenerationTracker
|
|
from torch._dynamo.testing import expectedFailureDynamic, same
|
|
from torch.nn.modules.lazy import LazyModuleMixin
|
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
|
|
|
try:
|
|
from . import test_functions
|
|
except ImportError:
|
|
import test_functions
|
|
|
|
|
|
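# Linear layer + ReLU, scaled by a plain tensor attribute (not a registered buffer).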
class BasicModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
def forward(self, x):
|
|
return F.relu(self.linear1(x)) * self.scale
|
|
|
|
|
|
class FnMember(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
self.activation = F.relu
|
|
|
|
def forward(self, x):
|
|
x = self.linear1(x)
|
|
if self.activation:
|
|
x = self.activation(x)
|
|
return x
|
|
|
|
|
|
class FnMemberCmp(torch.nn.Module):
|
|
def __init__(self, activation):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
self.activation = activation
|
|
|
|
def forward(self, x):
|
|
x = self.linear1(x)
|
|
if self.activation is not None:
|
|
x = self.activation(x)
|
|
if self.activation is None:
|
|
x = torch.sigmoid(x)
|
|
return x
|
|
|
|
|
|
class SubmoduleExample(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.layer2 = BasicModule()
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
def forward(self, x):
|
|
x = self.layer1(x)
|
|
x = self.layer2(x)
|
|
return x * self.scale
|
|
|
|
|
|
class IsTrainingCheck(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
self.linear2 = torch.nn.Linear(10, 10)
|
|
self.train(True)
|
|
|
|
def forward(self, x):
|
|
if self.training:
|
|
mod = self.linear1
|
|
else:
|
|
mod = self.linear2
|
|
return F.relu(mod(x))
|
|
|
|
|
|
class IsEvalCheck(IsTrainingCheck):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.train(False)
|
|
|
|
|
|
class ModuleMethodCall(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.layer2 = BasicModule()
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
def call_and_scale(self, mod, x):
|
|
x = mod(x)
|
|
return x * self.scale
|
|
|
|
def forward(self, x):
|
|
x1 = self.call_and_scale(self.layer1, x)
|
|
x2 = self.call_and_scale(self.layer2, x)
|
|
return x1 + x2
|
|
|
|
|
|
class UnsupportedMethodCall(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
def call_and_scale(self, mod, x):
|
|
x = mod(x)
|
|
x = x * self.scale
|
|
return unsupported(x, x)
|
|
|
|
def forward(self, x):
|
|
x1 = self.call_and_scale(self.layer1, x)
|
|
return x + x1
|
|
|
|
|
|
class UnsupportedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
def forward(self, x):
|
|
x = self.layer1(x) * self.scale
|
|
return unsupported(x, x)
|
|
|
|
|
|
class UnsupportedModuleCall(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mod = UnsupportedModule()
|
|
|
|
def forward(self, x):
|
|
return 1 + self.mod(x * 1.5)
|
|
|
|
|
|
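# Module whose forward is a @staticmethod, i.e. it takes no self argument.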
class ModuleWithStaticForward(torch.nn.Module):
|
|
@staticmethod
|
|
def forward(x):
|
|
return x * torch.sigmoid(x)
|
|
|
|
|
|
class ModuleCallModuleWithStaticForward(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mod = ModuleWithStaticForward()
|
|
|
|
def forward(self, x):
|
|
return self.mod(x)
|
|
|
|
|
|
class ModuleStaticMethodCall(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.layer2 = BasicModule()
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
@staticmethod
|
|
def call_and_scale(scale, mod, x):
|
|
x = mod(x)
|
|
return x * scale
|
|
|
|
def forward(self, x):
|
|
x1 = self.call_and_scale(self.scale, self.layer1, x)
|
|
x2 = self.call_and_scale(self.scale, self.layer2, x)
|
|
return x1 + x2
|
|
|
|
|
|
class ModuleClassMethodCall(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.layer2 = BasicModule()
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
@classmethod
|
|
def call_and_scale(cls, scale, mod, x):
|
|
x = mod(x)
|
|
return x * scale
|
|
|
|
def forward(self, x):
|
|
x1 = self.call_and_scale(self.scale, self.layer1, x)
|
|
x2 = self.call_and_scale(self.scale, self.layer2, x)
|
|
return x1 + x2
|
|
|
|
|
|
class ModuleProperty(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
@property
|
|
def scale_alias(self):
|
|
return self.scale
|
|
|
|
def forward(self, x):
|
|
return x * self.scale_alias
|
|
|
|
|
|
class ConstLoop(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
self.count = 3
|
|
|
|
def forward(self, x):
|
|
for i in range(self.count):
|
|
x = torch.sigmoid(self.linear1(x))
|
|
return x
|
|
|
|
|
|
class ViaModuleCall(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
return test_functions.constant3(torch.sigmoid(self.linear1(x)), x)
|
|
|
|
|
|
class IsNoneLayer(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = torch.nn.Linear(10, 10)
|
|
self.layer2 = None
|
|
self.train(True)
|
|
|
|
def forward(self, x):
|
|
if self.layer1 is not None:
|
|
x = self.layer1(x)
|
|
if self.layer2 is not None:
|
|
x = self.layer2(x)
|
|
return x
|
|
|
|
|
|
class LayerList(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = [
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
]
|
|
|
|
def forward(self, x):
|
|
for layer in self.layers:
|
|
x = layer(x)
|
|
return x
|
|
|
|
|
|
class ModuleList(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = torch.nn.ModuleList(
|
|
[
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
]
|
|
)
|
|
|
|
def forward(self, x):
|
|
for i in range(len(self.layers)):
|
|
x = self.layers[i](x)
|
|
|
|
for layer in self.layers:
|
|
x = layer(x)
|
|
|
|
for layer, val in zip(self.layers, (x, x, x, x)):
|
|
x = layer(x) + val
|
|
|
|
for layer, val in zip(self.layers, (1, 2, 3, 4)):
|
|
x = layer(x) + val
|
|
|
|
for idx, layer in enumerate(self.layers):
|
|
x = layer(x) * idx
|
|
|
|
for idx, layer in enumerate(self.layers[::-1]):
|
|
x = layer(x) * idx
|
|
|
|
return x
|
|
|
|
|
|
class CustomGetItemModuleList(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = torch.nn.ModuleList(
|
|
[
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
]
|
|
)
|
|
|
|
def __getitem__(self, idx: int):
|
|
return self.layers[idx]
|
|
|
|
def __len__(self) -> int:
|
|
return len(self.layers)
|
|
|
|
def forward(self, x):
|
|
for i in range(len(self)):
|
|
x = self[i](x)
|
|
|
|
return x
|
|
|
|
|
|
class ModuleDict(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = torch.nn.ModuleDict(
|
|
{
|
|
"0": torch.nn.Linear(10, 10),
|
|
}
|
|
)
|
|
|
|
def forward(self, x):
|
|
# TODO(future PR): handle more logic
|
|
x = self.layers["0"](x)
|
|
return x
|
|
|
|
|
|
class ParameterDict(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = torch.nn.ParameterDict(
|
|
{
|
|
"0": torch.nn.Parameter(torch.randn(10, 10)),
|
|
}
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.layers["0"].mm(x)
|
|
return x
|
|
|
|
|
|
class CustomGetItemParameterDict(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = torch.nn.ParameterDict(
|
|
{
|
|
"0": torch.nn.Parameter(torch.randn(10, 10)),
|
|
}
|
|
)
|
|
|
|
def __getitem__(self, key: str) -> torch.nn.Module:
|
|
return self.layers[key]
|
|
|
|
def forward(self, x):
|
|
x = self["0"].mm(x)
|
|
return x
|
|
|
|
|
|
class CustomGetItemModuleDict(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = torch.nn.ModuleDict(
|
|
{
|
|
"0": torch.nn.Linear(10, 10),
|
|
}
|
|
)
|
|
|
|
def __getitem__(self, key: str) -> torch.nn.Module:
|
|
return self.layers[key]
|
|
|
|
def forward(self, x):
|
|
x = self["0"](x)
|
|
return x
|
|
|
|
|
|
class TensorList(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = (
|
|
torch.randn((1, 10)),
|
|
torch.randn((10, 1)),
|
|
torch.randn((1, 10)),
|
|
torch.randn((10, 1)),
|
|
)
|
|
|
|
def forward(self, x):
|
|
for layer in self.layers:
|
|
x = x * layer
|
|
return x
|
|
|
|
|
|
class Children(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.l1 = torch.nn.Linear(10, 10)
|
|
self.l2 = torch.nn.ReLU()
|
|
self.l3 = torch.nn.Linear(10, 10)
|
|
self.l4 = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
for block in self.children():
|
|
x = block(x)
|
|
return x
|
|
|
|
|
|
class NamedChildren(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.l1 = torch.nn.Linear(10, 10)
|
|
self.l2 = torch.nn.ReLU()
|
|
self.l3 = torch.nn.Linear(10, 10)
|
|
self.l4 = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
for _, block in self.named_children():
|
|
x = block(x)
|
|
return x
|
|
|
|
|
|
class IntArg(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x, offset=1):
|
|
x = F.relu(self.layer1(x)) + offset
|
|
return x
|
|
|
|
|
|
class Seq(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = torch.nn.Sequential(
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.layers(x)
|
|
|
|
|
|
class Cfg:
|
|
def __init__(self):
|
|
self.val = 0.5
|
|
self.count = 3
|
|
|
|
|
|
class CfgModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.cfg = Cfg()
|
|
self.layer = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
for i in range(self.cfg.count):
|
|
x = self.layer(x + self.cfg.val)
|
|
return x
|
|
|
|
|
|
class StringMember(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
self.mode = "some_string"
|
|
|
|
def forward(self, x):
|
|
if self.mode == "some_string":
|
|
return F.relu(self.linear1(x))
|
|
|
|
|
|
class _Block(torch.nn.Module):
|
|
def forward(self, x):
|
|
return 1.5 * torch.cat(x, 1)
|
|
|
|
|
|
class _DenseBlock(torch.nn.ModuleDict):
|
|
_version = 2
|
|
|
|
def __init__(
|
|
self,
|
|
num_layers: int = 3,
|
|
) -> None:
|
|
super().__init__()
|
|
for i in range(num_layers):
|
|
self.add_module("denselayer%d" % (i + 1), _Block())
|
|
|
|
def forward(self, init_features):
|
|
features = [init_features]
|
|
for layer in self.values():
|
|
new_features = layer(features)
|
|
features.append(new_features)
|
|
return torch.cat(features, 1)
|
|
|
|
|
|
class DenseNetBlocks(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = _DenseBlock()
|
|
|
|
def forward(self, x):
|
|
return self.layers(x)
|
|
|
|
|
|
class MaterializedModule(torch.nn.Module):
|
|
"""Once the below lazy module is initialized with its first input,
|
|
it is transformed into this module."""
|
|
|
|
param: Parameter
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_parameter("param", None)
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
|
|
class LazyModule(LazyModuleMixin, MaterializedModule):
|
|
param: UninitializedParameter
|
|
cls_to_become = MaterializedModule
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.param = UninitializedParameter()
|
|
|
|
def initialize_parameters(self, x):
|
|
# force graph break to ensure this was not inlined
|
|
torch._dynamo.graph_break()
|
|
self.param.materialize(x.shape)
|
|
|
|
|
|
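# Two-layer MLP built from LazyLinear layers; their weights are materialized on the first forward call.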
class LazyMLP(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fc1 = torch.nn.LazyLinear(10)
|
|
self.relu1 = torch.nn.ReLU()
|
|
self.fc2 = torch.nn.LazyLinear(1)
|
|
self.relu2 = torch.nn.ReLU()
|
|
|
|
def forward(self, input):
|
|
x = self.relu1(self.fc1(input))
|
|
y = self.relu2(self.fc2(x))
|
|
return y
|
|
|
|
|
|
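# Lazy layer that takes a list of tensors; initialize_parameters creates a parameter shaped like the first element.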
class LazyLayerWithListInput(LazyModuleMixin, torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def initialize_parameters(self, input):
|
|
with torch.no_grad():
|
|
self._param = torch.nn.Parameter(torch.empty(input[0].shape).fill_(0.5))
|
|
|
|
def forward(self, input):
|
|
x = 0
|
|
for i in range(len(input)):
|
|
x = x + input[i]
|
|
return x
|
|
|
|
|
|
class LazyModuleWithListInput(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer = LazyLayerWithListInput()
|
|
|
|
def forward(self, input):
|
|
return self.layer(input[:-1])
|
|
|
|
|
|
class LazyModuleWithLazySubmodule(LazyModuleMixin, torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def initialize_parameters(self, input):
|
|
with torch.no_grad():
|
|
self.layer = LazyLayerWithListInput()
|
|
|
|
def forward(self, x):
|
|
return self.layer(x)
|
|
|
|
|
|
class LazyParentModule(LazyModuleMixin, torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def impl(self, x):
|
|
return x.cos() + self._val
|
|
|
|
|
|
class LazyChildModuleNoClsToBecome(LazyParentModule):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return super().impl(x.sin())
|
|
|
|
def initialize_parameters(self, input):
|
|
self._val = torch.nn.Parameter(torch.ones(2, 2))
|
|
|
|
|
|
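# Returns True if any parameter of `module` requires grad. An identical copy follows;
# ParametersModule1 and ParametersModule2 each call a different one.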
def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
|
|
requires_grad = any(p.requires_grad for p in module.parameters(recurse))
|
|
return requires_grad
|
|
|
|
|
|
def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool:
|
|
requires_grad = any(p.requires_grad for p in module.parameters(recurse))
|
|
return requires_grad
|
|
|
|
|
|
class ParametersModule1(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
self.scale = torch.nn.Parameter(torch.randn(1, 10))
|
|
|
|
def forward(self, x):
|
|
if not requires_grad1(self):
|
|
return F.relu(self.linear1(x)) * self.scale
|
|
else:
|
|
return x + 1
|
|
|
|
|
|
class ParametersModule2(ParametersModule1):
|
|
def forward(self, x):
|
|
if not requires_grad2(self):
|
|
return F.relu(self.linear1(x)) * self.scale
|
|
else:
|
|
return x + 1
|
|
|
|
|
|
class ParametersModule3(ParametersModule1):
|
|
def forward(self, x):
|
|
ones = torch.ones(10, dtype=next(self.parameters()).dtype)
|
|
return F.relu(self.linear1(x)) * self.scale + ones
|
|
|
|
|
|
class SuperModule(BasicModule):
|
|
def forward(self, x):
|
|
x = super().forward(x)
|
|
return x + 10.0
|
|
|
|
|
|
class SuperModule2(BasicModule):
|
|
def forward(self, x):
|
|
return BasicModule.forward(self, x)
|
|
|
|
|
|
class ComplicatedSuperParent(torch.nn.Module):
|
|
@classmethod
|
|
def custom_add(cls, x):
|
|
x = x + x
|
|
return x
|
|
|
|
|
|
class SuperChildCallsClassMethod(ComplicatedSuperParent):
|
|
@classmethod
|
|
def child_func(cls, x):
|
|
x = super().custom_add(x)
|
|
return x
|
|
|
|
def forward(self, x):
|
|
x = self.child_func(x)
|
|
return x
|
|
|
|
|
|
class HasAttrModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.scale = torch.nn.Parameter(torch.randn(1, 10))
|
|
|
|
def forward(self, x):
|
|
x = F.relu(x)
|
|
if hasattr(self, "scale"):
|
|
x *= self.scale
|
|
if hasattr(self, "scale2"):
|
|
x *= self.scale2
|
|
return x
|
|
|
|
|
|
class EnumValues(torch.nn.ModuleDict):
|
|
def __init__(
|
|
self,
|
|
num_layers: int = 3,
|
|
) -> None:
|
|
super().__init__()
|
|
for i in range(num_layers):
|
|
self.add_module("denselayer%d" % (i + 1), _Block())
|
|
|
|
def forward(self, init_features):
|
|
features = [init_features]
|
|
for idx, layer in enumerate(self.values()):
|
|
new_features = layer(features)
|
|
features.append(new_features)
|
|
return torch.cat(features, 1)
|
|
|
|
|
|
class AccessByKeys(torch.nn.ModuleDict):
|
|
def __init__(
|
|
self,
|
|
num_layers: int = 3,
|
|
) -> None:
|
|
super().__init__()
|
|
for i in range(num_layers):
|
|
self.add_module("denselayer%d" % (i + 1), _Block())
|
|
|
|
def forward(self, init_features):
|
|
features = [init_features]
|
|
for k in self.keys():
|
|
new_features = self[k](features)
|
|
features.append(new_features)
|
|
return torch.cat(features, 1)
|
|
|
|
|
|
class CallForwardDirectly(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.layer2 = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
x = self.layer1.forward(x)
|
|
x = self.layer2.forward(x)
|
|
return x
|
|
|
|
|
|
class ConvCallForwardDirectly(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)
|
|
|
|
def forward(self, x):
|
|
return self.layer.forward(x)
|
|
|
|
|
|
class ConvTransposeCallForwardDirectly(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer = torch.nn.ConvTranspose2d(4, 4, 4)
|
|
|
|
def forward(self, x):
|
|
return self.layer.forward(x)
|
|
|
|
|
|
class ConvCallSuperForwardDirectly(torch.nn.Conv1d):
|
|
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
|
|
super().__init__(
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
kernel_size=kernel_size,
|
|
**kwargs,
|
|
)
|
|
|
|
def forward(self, inputs, mask=None):
|
|
outputs = super().forward(inputs)
|
|
return outputs
|
|
|
|
|
|
class ConvTransposeCallSuperForwardDirectly(torch.nn.ConvTranspose2d):
|
|
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
|
|
super().__init__(
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
kernel_size=kernel_size,
|
|
**kwargs,
|
|
)
|
|
|
|
def forward(self, x):
|
|
if x.numel() > 0:
|
|
return super().forward(x)
|
|
output_shape = [
|
|
((i - 1) * d - 2 * p + (di * (k - 1) + 1) + op)
|
|
for i, p, di, k, d, op in zip(
|
|
x.shape[-2:],
|
|
self.padding,
|
|
self.dilation,
|
|
self.kernel_size,
|
|
self.stride,
|
|
self.output_padding,
|
|
)
|
|
]
|
|
output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
|
|
return _NewEmptyTensorOp.apply(x, output_shape)
|
|
|
|
|
|
class ModuleNameString(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
if self.__class__.__name__ == "ABC":
|
|
return 10
|
|
if self.linear1.__class__.__name__ == "Linear":
|
|
return F.relu(self.linear1(x) + 10)
|
|
return 11
|
|
|
|
|
|
class SelfMutatingModule(torch.nn.Module):
|
|
def __init__(self, layer):
|
|
super().__init__()
|
|
self.layer = layer
|
|
self.counter = 0
|
|
|
|
def forward(self, x):
|
|
result = self.layer(x) + self.counter
|
|
self.counter += 1
|
|
return F.relu(result)
|
|
|
|
|
|
class ModuleAttributePrecedenceBase(torch.nn.Module):
|
|
def linear(self, x):
|
|
return x * 2.0
|
|
|
|
|
|
class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.activation = torch.nn.ReLU()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
self.initializer = torch.ones([10, 10])
|
|
self.scale = 0.5
|
|
|
|
def activation(self, x):
|
|
return x * 1.2
|
|
|
|
def initializer(self):
|
|
return torch.zeros([10, 10])
|
|
|
|
def scale(self):
|
|
return 2.0
|
|
|
|
def forward(self, x):
|
|
        # object attributes take precedence unless the attribute is an nn.Module
|
|
return self.activation(self.linear(self.initializer + x)) * self.scale
|
|
|
|
|
|
class ModuleForwardHasGraphBreak(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer1 = BasicModule()
|
|
self.layer2 = BasicModule()
|
|
self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
|
|
self.layer4 = torch.nn.ModuleList(
|
|
[
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
]
|
|
)
|
|
self.layer5 = torch.nn.ModuleDict(
|
|
{
|
|
"0": torch.nn.Linear(10, 10),
|
|
}
|
|
)
|
|
self.scale = torch.randn(1, 10)
|
|
|
|
def forward(self, x):
|
|
"""
|
|
This is used to test if the results of functions like `named_parameters`
|
|
        can be reconstructed correctly after a graph break.
|
|
|
|
https://github.com/pytorch/torchdynamo/issues/1931
|
|
"""
|
|
x = self.layer1(x)
|
|
params1 = dict(self.named_parameters())
|
|
params2 = list(self.parameters())
|
|
buffers1 = dict(self.named_buffers())
|
|
buffers2 = list(self.buffers())
|
|
modules1 = dict(self.named_modules())
|
|
modules2 = list(self.modules())
|
|
torch._dynamo.graph_break()
|
|
y = modules2
|
|
y = modules1
|
|
y = buffers2
|
|
y = buffers1
|
|
y = params2
|
|
y = params1
|
|
x = (
|
|
self.layer2(x)
|
|
+ y["layer3.1.linear1.weight"]
|
|
+ y["layer4.2.weight"]
|
|
+ y["layer5.0.weight"]
|
|
)
|
|
return x * self.scale
|
|
|
|
|
|
class ModuleGuardNameIsValid(torch.nn.ModuleDict):
|
|
    # Guard names should be valid Python identifiers because we use eval() to get
|
|
    # the corresponding guard value. Some guard names come from the source (module path),
|
|
    # where special symbols are allowed but are not valid Python identifiers;
|
|
    # we should identify these patterns and rewrite them with getattr.
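    # For example, the key "l@yer-1" registered below is not a valid identifier, so a
    # guard expression referencing that submodule must go through getattr(module, "l@yer-1")
    # rather than dotted attribute access.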
|
|
def __init__(self):
|
|
super().__init__()
|
|
for i in range(2):
|
|
self.add_module("l@yer-%d" % (i + 1), BasicModule())
|
|
|
|
def forward(self, x):
|
|
for layer in self.values():
|
|
x = layer(x)
|
|
return x
|
|
|
|
|
|
class SequentialWithDuplicatedModule(torch.nn.Module):
|
|
    # The Sequential module (self.layer) contains the same ReLU module three times.
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
self.layer = torch.nn.Sequential(
|
|
torch.nn.Linear(10, 20),
|
|
self.relu,
|
|
torch.nn.Linear(20, 20),
|
|
self.relu,
|
|
torch.nn.Linear(20, 10),
|
|
self.relu,
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.layer(x)
|
|
|
|
|
|
class SequentialWithDuplicatedModule2(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
self.layer = torch.nn.Sequential(
|
|
collections.OrderedDict(
|
|
[
|
|
("linear1", torch.nn.Linear(10, 20)),
|
|
("relu1", self.relu),
|
|
("linear2", torch.nn.Linear(20, 20)),
|
|
("relu2", self.relu),
|
|
("linear3", torch.nn.Linear(20, 10)),
|
|
("relu3", self.relu),
|
|
]
|
|
)
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.layer(x)
|
|
|
|
|
|
class ModuleComparison(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer0 = torch.nn.Linear(10, 10)
|
|
self.layer1 = torch.nn.Linear(10, 10)
|
|
self.layer2 = torch.nn.Linear(10, 10)
|
|
|
|
@property
|
|
def encoder_layers(self):
|
|
return [self.layer0, self.layer1, self.layer2]
|
|
|
|
def forward(self, x):
|
|
for layer in self.encoder_layers:
|
|
output = layer(x)
|
|
if layer is None or layer == self.layer0:
|
|
output = F.relu6(output)
|
|
else:
|
|
output = F.relu(output)
|
|
return output
|
|
|
|
|
|
class ModulePatch1(torch.nn.Module):
|
|
pass
|
|
|
|
|
|
class ModulePatch2(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x - 1
|
|
|
|
|
|
class UnspecNonInlinableModule(torch.nn.Module):
|
|
    torchdynamo_force_dynamic = True  # forced to be an UnspecializedNNModule
|
|
|
|
def forward(self, x):
|
|
if x.sum() > 0:
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
|
|
class UnspecNonInlinableToplevelModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.m = UnspecNonInlinableModule()
|
|
|
|
def forward(self, x):
|
|
return self.m(x)
|
|
|
|
|
|
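# Wraps a module instance in a test method that runs torch._dynamo.testing.standard_test
# on it with a single tensor argument; the module is switched to eval mode up front.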
def make_test(fn, expected_ops=None):
|
|
def test_fn(self):
|
|
return torch._dynamo.testing.standard_test(
|
|
self, fn=fn, nargs=1, expected_ops=expected_ops
|
|
)
|
|
|
|
fn.eval()
|
|
return test_fn
|
|
|
|
|
|
class NNModuleTests(torch._dynamo.test_case.TestCase):
|
|
test_seq = make_test(Seq())
|
|
test_basicmodule1 = make_test(BasicModule())
|
|
test_basicmodule2 = make_test(BasicModule())
|
|
test_submodules1 = make_test(SubmoduleExample())
|
|
test_submodules2 = make_test(SubmoduleExample())
|
|
test_modulemethod1 = make_test(ModuleMethodCall())
|
|
test_modulemethod2 = make_test(ModuleMethodCall())
|
|
test_module_call_module_with_static_forward = make_test(
|
|
ModuleCallModuleWithStaticForward()
|
|
)
|
|
test_module_static_method = make_test(ModuleStaticMethodCall())
|
|
test_fnmember = make_test(FnMember())
|
|
test_fnmembercmp1 = make_test(FnMemberCmp(F.relu))
|
|
test_fnmembercmp2 = make_test(FnMemberCmp(None))
|
|
test_constloop = make_test(ConstLoop())
|
|
test_istraining1 = make_test(IsTrainingCheck())
|
|
test_istraining2 = make_test(IsTrainingCheck())
|
|
test_iseval1 = make_test(IsEvalCheck())
|
|
test_iseval2 = make_test(IsEvalCheck())
|
|
test_viamodulecall = make_test(ViaModuleCall())
|
|
test_isnonelayer = make_test(IsNoneLayer())
|
|
test_layerlist = make_test(LayerList())
|
|
test_tensorlist = make_test(TensorList())
|
|
test_intarg = make_test(IntArg())
|
|
test_cfgmod = make_test(CfgModule())
|
|
test_stringmember = make_test(StringMember())
|
|
test_modulelist = make_test(ModuleList())
|
|
    test_custom_getitem_modulelist = make_test(CustomGetItemModuleList())
|
|
test_moduledict = make_test(ModuleDict())
|
|
    test_custom_getitem_moduledict = make_test(CustomGetItemModuleDict())
|
|
test_parameterdict = make_test(ParameterDict())
|
|
    test_custom_getitem_parameterdict = make_test(CustomGetItemParameterDict())
|
|
test_super1 = make_test(SuperModule())
|
|
test_super2 = make_test(SuperModule2())
|
|
test_super_class_method = make_test(SuperChildCallsClassMethod())
|
|
test_children = make_test(Children())
|
|
test_named_children = make_test(NamedChildren())
|
|
test_densenet = make_test(DenseNetBlocks())
|
|
test_parameters1 = make_test(ParametersModule1())
|
|
test_parameters2 = make_test(ParametersModule2())
|
|
test_parameters3 = make_test(ParametersModule3(), expected_ops=5)
|
|
test_hasattr = make_test(HasAttrModule())
|
|
test_enumvalues = make_test(EnumValues())
|
|
test_access_by_keys = make_test(AccessByKeys())
|
|
test_module_class_method = make_test(ModuleClassMethodCall())
|
|
test_module_property = make_test(ModuleProperty())
|
|
test_forward_directly = make_test(CallForwardDirectly())
|
|
test_module_name_string = make_test(ModuleNameString())
|
|
test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
|
|
test_module_guard_name_is_valid = make_test(ModuleGuardNameIsValid())
|
|
test_sequential_with_duplicated_module = make_test(SequentialWithDuplicatedModule())
|
|
test_sequential_with_duplicated_module2 = make_test(
|
|
SequentialWithDuplicatedModule2()
|
|
)
|
|
test_module_comparison = make_test(ModuleComparison())
|
|
|
|
def test_module_forward_has_graph_break(self):
|
|
m = ModuleForwardHasGraphBreak()
|
|
x = torch.rand([10, 10])
|
|
ref = m(x)
|
|
opt_m = torch._dynamo.optimize("eager")(m)
|
|
res = opt_m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_unsupportedmethod(self):
|
|
m = UnsupportedMethodCall()
|
|
i = torch.randn(10)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m = torch._dynamo.optimize(cnt)(m)
|
|
r = opt_m(i)
|
|
self.assertTrue(torch._dynamo.testing.same(r, m(i)))
|
|
self.assertEqual(cnt.op_count, 5)
|
|
|
|
def test_unsupportedmodule(self):
|
|
m = UnsupportedModuleCall()
|
|
i = torch.randn(10)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m = torch._dynamo.optimize(cnt)(m)
|
|
r = opt_m(i)
|
|
self.assertTrue(torch._dynamo.testing.same(r, m(i)))
|
|
self.assertEqual(cnt.op_count, 6)
|
|
|
|
def test_self_mutating1(self):
|
|
m1 = torch.nn.Linear(10, 10)
|
|
m2 = SelfMutatingModule(m1)
|
|
m3 = SelfMutatingModule(m1)
|
|
m4 = SelfMutatingModule(m1)
|
|
i = torch.randn(10)
|
|
out2 = [m2(i), m2(i), m2(i)]
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m3 = torch._dynamo.optimize_assert(cnt)(m3)
|
|
opt_m4 = torch._dynamo.optimize_assert(cnt)(m4)
|
|
out3 = [opt_m3(i), opt_m3(i), opt_m3(i)]
|
|
out4 = [opt_m4(i), opt_m4(i), opt_m4(i)]
|
|
self.assertTrue(torch._dynamo.testing.same(out2, out3))
|
|
self.assertTrue(torch._dynamo.testing.same(out2, out4))
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
|
|
def test_generation_tag(self):
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
# guarantee that we have installed
|
|
# the generation tagging function
|
|
with torch._dynamo.optimize_assert(cnt):
|
|
pass
|
|
|
|
m1 = torch.nn.Linear(10, 10)
|
|
prev_generation = GenerationTracker.get_generation_value(m1)
|
|
cur_generation = prev_generation + 1
|
|
|
|
with torch._dynamo.optimize_assert(cnt):
|
|
m2 = torch.nn.Linear(10, 10)
|
|
|
|
self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation)
|
|
self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation)
|
|
# check that newly constructed instances
|
|
# also have the same generation (even if copied from an old instance)
|
|
m3 = deepcopy(m1)
|
|
self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation)
|
|
|
|
def test_simple_torch_function(self):
|
|
def foo(x):
|
|
# function call, twice to test wrapping
|
|
x = F.sigmoid(x)
|
|
x = F.sigmoid(x)
|
|
# method call, twice to test wrapping
|
|
x = x.sigmoid()
|
|
x = x.sigmoid()
|
|
return x
|
|
|
|
class TensorProxy(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
|
|
|
|
try:
|
|
x = torch.randn(1).as_subclass(TensorProxy)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
out1 = foo(x)
|
|
opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
|
|
out2 = opt_foo(x)
|
|
|
|
self.assertEqual(cnt.op_count, 4)
|
|
self.assertTrue(torch._dynamo.testing.same(out1, out2))
|
|
|
|
finally:
|
|
torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
|
|
|
|
def test_torch_function_with_closure(self):
|
|
def run():
|
|
counter = 0
|
|
|
|
def foo(x):
|
|
# function call, twice to test wrapping
|
|
x = F.sigmoid(x)
|
|
x = F.sigmoid(x)
|
|
# method call, twice to test wrapping
|
|
x = x.sigmoid()
|
|
x = x.sigmoid()
|
|
return x
|
|
|
|
class TensorProxy(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
nonlocal counter
|
|
# for now, only support reads from closure cells
|
|
# TODO(future PR): support writes as well
|
|
counter + 1
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
|
|
|
|
try:
|
|
x = torch.randn(1).as_subclass(TensorProxy)
|
|
x = torch.randn(1)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
out1 = foo(x)
|
|
opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
|
|
out2 = opt_foo(x)
|
|
|
|
self.assertEqual(cnt.op_count, 4)
|
|
self.assertTrue(torch._dynamo.testing.same(out1, out2))
|
|
finally:
|
|
torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)
|
|
|
|
run()
|
|
|
|
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
|
|
def test_nn_moduledict_contains(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, module_dict):
|
|
super().__init__()
|
|
self.module_dict = module_dict
|
|
|
|
def forward(self, x):
|
|
if "foo" in self.module_dict:
|
|
x = torch.mul(x, 1.0)
|
|
x = torch.add(x, 1.0)
|
|
return x
|
|
|
|
module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})
|
|
m = M(module_dict)
|
|
data = torch.randn(1)
|
|
out1 = m(data)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
|
|
out2 = opt_m(data)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
self.assertTrue(torch._dynamo.testing.same(out1, out2))
|
|
|
|
module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
|
|
m = M(module_dict)
|
|
data = torch.randn(1)
|
|
out1 = m(data)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
torch._dynamo.reset()
|
|
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
|
|
out2 = opt_m(data)
|
|
|
|
self.assertEqual(cnt.op_count, 1)
|
|
self.assertTrue(torch._dynamo.testing.same(out1, out2))
|
|
|
|
module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
|
|
pre = m(data)
|
|
cnt.clear()
|
|
|
|
with torch._dynamo.optimize(cnt, nopython=False):
|
|
opt_pre = m(data)
|
|
m = M(module_dict)
|
|
data = torch.randn(1)
|
|
out1 = m(data)
|
|
|
|
out_post = m(data)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
|
|
self.assertTrue(torch._dynamo.testing.same(out1, out_post))
|
|
|
|
# RuntimeError: SymIntArrayRef expected to contain only concrete integers
|
|
@expectedFailureDynamic
|
|
def test_lazy_module1(self):
|
|
input_shape = (16, 3, 6, 7, 8)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
module = LazyModule()
|
|
|
|
def test_static_module():
|
|
input = torch.ones(*input_shape)
|
|
module(input)
|
|
|
|
# test no graph break
|
|
opt_test_static_module = torch._dynamo.optimize(cnt, nopython=True)(
|
|
test_static_module
|
|
)
|
|
opt_test_static_module()
|
|
|
|
self.assertTrue(
|
|
isinstance(module, MaterializedModule),
|
|
"Module should be transformed to an instance of MaterializedModule.",
|
|
)
|
|
self.assertEqual(module.param.shape, input_shape)
|
|
|
|
# test when mapped to UnspecializedNNModule
|
|
module = LazyModule()
|
|
|
|
def test_unspecialized():
|
|
nonlocal module
|
|
module = LazyModule()
|
|
input = torch.ones(*input_shape)
|
|
module(input)
|
|
|
|
opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
|
|
opt_test_unspecialized()
|
|
|
|
self.assertTrue(
|
|
isinstance(module, MaterializedModule),
|
|
"Module should be transformed to an instance of MaterializedModule.",
|
|
)
|
|
self.assertEqual(module.param.shape, input_shape)
|
|
|
|
# test with a static module in torch.*
|
|
module = torch.nn.modules.LazyBatchNorm3d(
|
|
affine=False, track_running_stats=False
|
|
)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
torch._dynamo.reset()
|
|
|
|
def test_torch_static():
|
|
input = torch.ones(*input_shape)
|
|
return module(input) # fully materialized
|
|
|
|
# test no graph break
|
|
opt_test_torch_static = torch._dynamo.optimize(cnt, nopython=True)(
|
|
test_torch_static
|
|
)
|
|
opt_test_torch_static()
|
|
out = opt_test_torch_static()
|
|
|
|
self.assertTrue(same(out, module(torch.ones(*input_shape))))
|
|
|
|
self.assertTrue(
|
|
isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d),
|
|
"Module should be transformed to an instance of BatchNorm3d.",
|
|
)
|
|
self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")
|
|
|
|
# RuntimeError: SymIntArrayRef expected to contain only concrete integers
|
|
@expectedFailureDynamic
|
|
def test_lazy_module2(self):
|
|
# Test FX graph 'call_module' works well if argument is lazy module
|
|
m = LazyMLP()
|
|
x = torch.rand([10, 10])
|
|
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
|
|
        # We should run the compiled module first; otherwise the module
|
|
        # would already be initialized by the eager-mode run.
|
|
res = opt_m(x)
|
|
ref = m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
# RuntimeError: SymIntArrayRef expected to contain only concrete integers
|
|
@expectedFailureDynamic
|
|
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
|
def test_lazy_module3(self):
|
|
m = LazyMLP()
|
|
x = torch.rand([10, 10])
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
|
|
# first iteration
|
|
res = opt_m(x)
|
|
ref = m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
# move to cuda and second iteration
|
|
m = m.to("cuda")
|
|
x = x.to("cuda")
|
|
res = opt_m(x)
|
|
ref = m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
# RuntimeError: SymIntArrayRef expected to contain only concrete integers
|
|
@expectedFailureDynamic
|
|
def test_lazy_module4(self):
|
|
m = LazyMLP()
|
|
x = torch.rand([10, 10])
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
|
|
# first iteration
|
|
res = opt_m(x)
|
|
ref = m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
# input shape changed and second iteration
|
|
x = torch.rand([20, 20])
|
|
try:
|
|
opt_m(x)
|
|
except RuntimeError:
|
|
self.assertIn("must have same reduction dim", traceback.format_exc())
|
|
|
|
# RuntimeError: SymIntArrayRef expected to contain only concrete integers
|
|
@expectedFailureDynamic
|
|
def test_lazy_module5(self):
|
|
# Test lazy module works well with list/tuple input
|
|
m = LazyModuleWithListInput()
|
|
x = [torch.rand([5, 5])] * 3 + [None]
|
|
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
|
|
res = opt_m(x)
|
|
ref = m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
# RuntimeError: SymIntArrayRef expected to contain only concrete integers
|
|
@expectedFailureDynamic
|
|
def test_lazy_module6(self):
|
|
# Test new lazy submodule in lazy module's initialize_parameters
|
|
m = LazyModuleWithLazySubmodule()
|
|
x = [torch.rand([5, 5])] * 3
|
|
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
|
|
res = opt_m(x)
|
|
ref = m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_lazy_module_no_cls_to_become(self):
|
|
# make sure super() works in the case where cls_to_become is None
|
|
m = LazyChildModuleNoClsToBecome()
|
|
x = torch.rand(2, 2)
|
|
opt_m = torch._dynamo.optimize("eager", nopython=True)(m)
|
|
res = opt_m(x)
|
|
ref = m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_call_fn_with_non_const_inputs_safe(self):
|
|
class ModuleSpecialFwd(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=20, kernel_size=(5, 5)
|
|
)
|
|
|
|
def _conv_forward(self, x):
|
|
return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)
|
|
|
|
def forward(self, x):
|
|
return self._conv_forward(x)
|
|
|
|
mod = ModuleSpecialFwd()
|
|
rx = torch.randn([3, 10, 10])
|
|
real = mod(rx)
|
|
graph, _ = torch._dynamo.export(mod)(rx)
|
|
self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
|
|
|
|
def test_conv_call_forward_directly(self):
|
|
m = ConvCallForwardDirectly()
|
|
x = torch.rand([4, 3, 9, 9])
|
|
ref = m(x)
|
|
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
|
|
res = opt_m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_conv_transpose_call_forward_directly(self):
|
|
m = ConvTransposeCallForwardDirectly()
|
|
x = torch.rand([4, 4, 4, 4])
|
|
ref = m(x)
|
|
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
|
|
res = opt_m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_conv_call_super_forward_directly(self):
|
|
x = torch.randn(4, 4)
|
|
m = ConvCallSuperForwardDirectly(4, 4, 4)
|
|
ref = m(x)
|
|
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
|
|
res = opt_m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_conv_transpose_call_super_forward_directly(self):
|
|
x = torch.randn(4, 4, 4)
|
|
m = ConvTransposeCallSuperForwardDirectly(4, 4, 4)
|
|
ref = m(x)
|
|
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
|
|
res = opt_m(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
|
|
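# Simple module (Linear + ReLU + registered buffer) reused by OptimizedModuleTest below.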
class MockModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
self.register_buffer("buf0", torch.randn(10, 10))
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.linear(x) + self.buf0)
|
|
|
|
|
|
class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|
def test_nn_module(self):
|
|
mod = MockModule()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_mod = torch._dynamo.optimize(cnt)(mod)
|
|
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
|
|
|
|
x = torch.randn(10, 10)
|
|
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_to(self):
|
|
mod = MockModule()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_mod = torch._dynamo.optimize(cnt)(mod)
|
|
x = torch.randn(10, 10)
|
|
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
# Ensure that there is no recompilation
|
|
opt_mod(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
|
|
self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
|
|
x = torch.randn(10, 10).to(dtype=torch.float64)
|
|
opt_mod(x)
|
|
# Ensure that there is a recompilation
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
# Ensure that there is no recompilation
|
|
opt_mod(x)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
torch._dynamo.reset()
|
|
opt_mod(x)
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
def test_attr(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
self.register_buffer("buf0", torch.randn(10, 10))
|
|
|
|
def forward(self, x):
|
|
return self.r(torch.sin(x)) + self.buf0
|
|
|
|
mod = MockModule()
|
|
opt_mod = torch._dynamo.optimize("eager")(mod)
|
|
|
|
        # Check parameters and buffers
|
|
for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
|
|
self.assertTrue(id(p1) == id(p2))
|
|
for b1, b2 in zip(mod.buffers(), opt_mod.buffers()):
|
|
self.assertTrue(id(b1) == id(b2))
|
|
|
|
def get_parameter_dtype(mod: torch.nn.Module):
|
|
parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers())
|
|
return next(parameters_and_buffers).dtype
|
|
|
|
opt_mod = torch._dynamo.optimize("eager")(get_parameter_dtype)
|
|
out_dtype = opt_mod(mod)
|
|
self.assertEqual(out_dtype, torch.float32)
|
|
|
|
def test_dir(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
self.register_buffer("buf0", torch.randn(10, 10))
|
|
self.register_parameter(
|
|
name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.r(torch.sin(x)) + self.buf0
|
|
|
|
mod = MockModule()
|
|
mod_keys = dir(mod)
|
|
opt_mod = torch._dynamo.optimize("eager")(mod)
|
|
opt_mod_keys = dir(opt_mod)
|
|
|
|
# Check user-defined attributes, parameters and buffers
|
|
self.assertIn("linear", opt_mod_keys)
|
|
self.assertIn("buf0", opt_mod_keys)
|
|
self.assertIn("param0", opt_mod_keys)
|
|
|
|
# Check all attributes, parameters and buffers
|
|
self.assertTrue(len(set(mod_keys).difference(opt_mod_keys)) == 0)
|
|
|
|
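    # Each submodule is compiled separately; the test expects exactly one frame per
    # submodule and no recompiles, even with cache_size_limit set to 1.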
def test_no_recompile_on_nn_guarded_modules(self):
|
|
size = (10, 10)
|
|
cache_size_limit = 1
|
|
num_submodules = 4
|
|
cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
|
|
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(*size)
|
|
|
|
def forward(self, x):
|
|
a = torch.sin(torch.cos(x))
|
|
return self.linear(a)
|
|
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mods = [SubModule() for _ in range(num_submodules)]
|
|
self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
|
|
|
|
def forward(self, x):
|
|
for mod in self.mods:
|
|
x = mod(x)
|
|
return x
|
|
|
|
mod = MockModule()
|
|
# Each submod is compiled separately and has a different nn module
|
|
        # guard. Ensure that the recompilation logic is handled correctly.
|
|
with unittest.mock.patch(
|
|
"torch._dynamo.config.error_on_recompile", True
|
|
), unittest.mock.patch(
|
|
"torch._dynamo.config.cache_size_limit",
|
|
cache_size_limit,
|
|
):
|
|
x = torch.randn(*size)
|
|
mod(x)
|
|
self.assertEqual(cnts.frame_count, num_submodules)
|
|
|
|
def test_cache_size_limit_on_guarded_nn_modules(self):
|
|
cache_size_limit = 2
|
|
num_submodules = 4
|
|
cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
|
|
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
a = torch.sin(torch.cos(x))
|
|
return self.relu(a)
|
|
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mods = [SubModule() for _ in range(num_submodules)]
|
|
self.mods = [torch.compile(mod, backend=cnts) for mod in self.mods]
|
|
|
|
def forward(self, x):
|
|
for mod in self.mods:
|
|
x = mod(x)
|
|
return x
|
|
|
|
mod = MockModule()
|
|
# For the third iteration, we would reach the cache size limit, and
|
|
        # therefore the expected total frame count is 2 *
|
|
# num_submodules.
|
|
with unittest.mock.patch(
|
|
"torch._dynamo.config.cache_size_limit",
|
|
cache_size_limit,
|
|
):
|
|
for size in [
|
|
(4,),
|
|
(4, 4),
|
|
(4, 4, 4),
|
|
]:
|
|
x = torch.randn(size)
|
|
mod(x)
|
|
self.assertEqual(cnts.frame_count, 2 * num_submodules)
|
|
|
|
def test_recursion(self):
|
|
mod = MockModule()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_mod = torch._dynamo.optimize(cnt)(mod)
|
|
|
|
for _ in range(5):
|
|
opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
|
|
opt_mod(torch.randn(10, 10))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_composition(self):
|
|
class InnerModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(torch.sin(x))
|
|
|
|
opt_inner_mod = InnerModule()
|
|
|
|
class OuterModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mod = opt_inner_mod
|
|
|
|
def forward(self, x):
|
|
return self.mod(torch.cos(x))
|
|
|
|
outer_mod = OuterModule()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
|
|
|
|
x = torch.randn(4)
|
|
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
|
|
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_composition_with_opt_mod(self):
|
|
class InnerModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(torch.sin(x))
|
|
|
|
inner_mod = InnerModule()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)
|
|
|
|
class OuterModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mod = opt_inner_mod
|
|
|
|
def forward(self, x):
|
|
return self.mod(torch.cos(x))
|
|
|
|
outer_mod = OuterModule()
|
|
opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
|
|
|
|
x = torch.randn(4)
|
|
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
|
|
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
|
|
# There will be a graph break for the inner mod being OptimizedModule
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_module_patch(self):
|
|
mod = ModulePatch1()
|
|
mod.forward = types.MethodType(ModulePatch2.forward, mod)
|
|
|
|
def fn(x):
|
|
return mod(x)
|
|
|
|
self.assertTrue(
|
|
torch.allclose(
|
|
torch._dynamo.optimize("eager", nopython=True)(fn)(torch.ones(10)),
|
|
torch.zeros(1),
|
|
)
|
|
)
|
|
|
|
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
|
|
def test_hooks_outer(self):
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return 2 * x + 1
|
|
|
|
m = TestModule()
|
|
|
|
def forward_hook(
|
|
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
|
) -> torch.Tensor:
|
|
return 2 * output + 1
|
|
|
|
handle = m.register_forward_hook(forward_hook)
|
|
inp = torch.tensor(1.0, requires_grad=True)
|
|
|
|
failure_reason = None
|
|
|
|
def guard_fail_fn(failure):
|
|
nonlocal failure_reason
|
|
failure_reason = failure[0]
|
|
|
|
compiled_m = torch._dynamo.optimize(
|
|
guard_fail_fn=guard_fail_fn, backend="eager"
|
|
)(m)
|
|
|
|
self.assertEqual(compiled_m(inp), m(inp))
|
|
self.assertEqual(compiled_m(inp).item(), 7)
|
|
self.assertTrue(failure_reason is None)
|
|
|
|
# what if we remove our hook? we should recompile?
|
|
handle.remove()
|
|
self.assertEqual(compiled_m(inp), m(inp))
|
|
self.assertEqual(compiled_m(inp).item(), 3)
|
|
# self.assertTrue(failure_reason == "hook")
|
|
|
|
"""
|
|
Summary:
|
|
        - removing a hook doesn't fail a guard, because we weren't compiling the hook
|
|
(at least into the same graph) as forward in the first place! We do correctly
|
|
omit calling the removed hook, but since this hook is a post forward hook,
|
|
the 'RETURN' from forward is breaking the graph.
|
|
|
|
Why is 'forward' the entrypoint to an InstructionTranslator, after I changed
|
|
the eval_frame entrypoint to Module.__call__?
|
|
"""
|
|
|
|
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
|
|
def test_hooks_inner(self):
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return 2 * x + 1
|
|
|
|
m = TestModule()
|
|
|
|
def forward_hook(
|
|
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
|
) -> torch.Tensor:
|
|
return 2 * output + 1
|
|
|
|
handle = m.register_forward_hook(forward_hook)
|
|
|
|
def outer_func(tensor):
|
|
x = tensor * 2 + 1
|
|
y = m(x)
|
|
return y
|
|
|
|
inp = torch.tensor(1.0, requires_grad=True)
|
|
|
|
failure_reason = None
|
|
|
|
def guard_fail_fn(failure):
|
|
nonlocal failure_reason
|
|
failure_reason = failure[0]
|
|
|
|
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
|
compiled_func = torch._dynamo.optimize(
|
|
guard_fail_fn=guard_fail_fn,
|
|
backend=cc,
|
|
)(outer_func)
|
|
|
|
self.assertEqual(compiled_func(inp), outer_func(inp))
|
|
self.assertEqual(compiled_func(inp).item(), 15)
|
|
|
|
# We are compiling 1 big graph for all 3 functions including the hook.
|
|
self.assertEqual(cc.frame_count, 1)
|
|
self.assertEqual(cc.op_count, 6)
|
|
|
|
# If we remove the hook, we should recompile
|
|
handle.remove()
|
|
self.assertEqual(compiled_func(inp), outer_func(inp))
|
|
self.assertEqual(compiled_func(inp).item(), 7)
|
|
self.assertTrue("forward_hooks.keys" in failure_reason)
|
|
self.assertEqual(cc.frame_count, 1 + 1)
|
|
self.assertEqual(cc.op_count, 6 + 4)
|
|
|
|
# what if instead of removing, we alter our hook?
|
|
torch._dynamo.reset()
|
|
m = TestModule()
|
|
handle = m.register_forward_hook(forward_hook)
|
|
failure_reason = None
|
|
self.assertEqual(compiled_func(inp), outer_func(inp))
|
|
self.assertEqual(compiled_func(inp).item(), 15)
|
|
|
|
def new_forward_hook(
|
|
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
|
) -> torch.Tensor:
|
|
return 2 * output + 2
|
|
|
|
m._forward_hooks[handle.id] = new_forward_hook
|
|
self.assertEqual(compiled_func(inp), outer_func(inp))
|
|
self.assertEqual(compiled_func(inp).item(), 16)
|
|
self.assertTrue("___check_obj_id(L['m']._forward_hooks" in failure_reason)
|
|
|
|
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
|
|
def test_hooks_skip_guards(self):
|
|
class TestModule(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return 2 * x + 1
|
|
|
|
m = TestModule()
|
|
|
|
def forward_hook(
|
|
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
|
) -> torch.Tensor:
|
|
return 2 * output + 1
|
|
|
|
handle = m.register_forward_hook(forward_hook)
|
|
|
|
def outer_func(tensor):
|
|
x = tensor * 2 + 1
|
|
y = m(x)
|
|
return y
|
|
|
|
inp = torch.tensor(1.0, requires_grad=True)
|
|
|
|
failure_reason = None
|
|
|
|
def guard_fail_fn(failure):
|
|
nonlocal failure_reason
|
|
failure_reason = failure[0]
|
|
|
|
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
|
compiled_func = torch._dynamo.optimize(
|
|
guard_fail_fn=guard_fail_fn,
|
|
backend=cc,
|
|
)(outer_func)
|
|
|
|
m = TestModule()
|
|
handle = m.register_forward_hook(forward_hook)
|
|
failure_reason = None
|
|
self.assertEqual(compiled_func(inp), outer_func(inp))
|
|
self.assertEqual(compiled_func(inp).item(), 15)
|
|
self.assertEqual(cc.frame_count, 1)
|
|
self.assertEqual(cc.op_count, 6)
|
|
|
|
# if we remove the hook, dynamo shouldn't notice
|
|
handle.remove()
|
|
self.assertNotEqual(compiled_func(inp), outer_func(inp))
|
|
self.assertEqual(compiled_func(inp).item(), 15)
|
|
self.assertEqual(cc.frame_count, 1)
|
|
|
|
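    # Runs `model` compiled (aot_eager backend) and eagerly, recording forward-hook inputs
    # per named module, and checks that both runs fire the same set of hooks.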
def _forward_hook_test_helper(self, model):
|
|
forward_handles = {}
|
|
compiled_activations = dict()
|
|
eager_activations = dict()
|
|
activations = None
|
|
|
|
def save_activations(name, mod, inp, out):
|
|
activations[name] = inp
|
|
|
|
for name, module in model.named_modules():
|
|
forward_handles[name] = module.register_forward_hook(
|
|
partial(save_activations, name)
|
|
)
|
|
|
|
compiled_model = torch.compile(model, backend="aot_eager")
|
|
|
|
activations = compiled_activations
|
|
for i in range(2):
|
|
# second iteration is key, hooks would have fired during aot trace
|
|
# on first iter
|
|
compiled_activations.clear()
|
|
x = torch.randn((20, 10))
|
|
pred = compiled_model(x)
|
|
loss = pred.sum()
|
|
loss.backward()
|
|
|
|
activations = eager_activations
|
|
for i in range(2):
|
|
# second iteration is key, hooks would have fired during aot trace
|
|
# on first iter
|
|
eager_activations.clear()
|
|
x = torch.randn((20, 10))
|
|
pred = model(x)
|
|
loss = pred.sum()
|
|
loss.backward()
|
|
|
|
print(f"Recorded Layers: {compiled_activations.keys()}\n\n")
|
|
print(f"Expected Layers: {eager_activations.keys()}")
|
|
|
|
self.assertTrue(compiled_activations.keys() == eager_activations.keys())
|
|
self.assertTrue(activations.keys() == forward_handles.keys())
|
|
|
|
def test_hooks_allowed_modules(self):
|
|
# this test shouldn't care whether hook guards are enabled or not
|
|
class ToyModel(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.net = torch.nn.Sequential(
|
|
*[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
|
|
+ [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.net(x)
|
|
|
|
model = ToyModel()
|
|
self._forward_hook_test_helper(model)
|
|
|
|
def test_hooks_allowed_modules_compiles(self):
|
|
class ToyModel(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.net = torch.nn.Sequential(
|
|
*[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
|
|
+ [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.net(x)
|
|
|
|
model = ToyModel()
|
|
activations = []
|
|
|
|
def save_activations(mod, inp, out):
|
|
activations.append(inp)
|
|
|
|
for name, module in model.named_modules():
|
|
module.register_forward_hook(save_activations)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
model = torch._dynamo.optimize(cnt, nopython=True)(model)
|
|
for i in range(2):
|
|
# second iteration is key, hooks would have fired during aot trace
|
|
# on first iter
|
|
activations.clear()
|
|
x = torch.randn((20, 10))
|
|
pred = model(x)
|
|
loss = pred.sum()
|
|
loss.backward()
|
|
self.assertEqual(len(activations), 6)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
    def test_hooks_allowed_modules_compiles_self_contained(self):
        class ToyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10000), torch.nn.ReLU()]
                    + [torch.nn.Linear(10000, 5), torch.nn.ReLU()]
                )

            def forward(self, x):
                return self.net(x) * self.net(x)

        model = ToyModel()
        forward_handles = {}

        def output_modifying_hook(mod, inp, out):
            return 2 * out + 1

        for name, module in model.named_modules():
            forward_handles[name] = module.register_forward_hook(output_modifying_hook)

        cnt = torch._dynamo.testing.CompileCounter()

        x = torch.randn((20, 10))
        pred_eager = model(x)
        loss_eager = pred_eager.sum()
        eager_loss_bwd = loss_eager.backward()

        model = torch._dynamo.optimize(cnt, nopython=True)(model)
        pred = model(x)

        loss = pred.sum()
        loss_bwd = loss.backward()

        self.assertEqual(eager_loss_bwd, loss_bwd)
        self.assertEqual(cnt.frame_count, 2)

        # Ndim change, recompile
        pred = model(torch.randn([10, 10, 10]))
        self.assertEqual(cnt.frame_count, 4)

        # Stable
        pred = model(torch.randn([10, 10, 10]))
        self.assertEqual(cnt.frame_count, 4)

    def test_dunder_call_explicitly(self):
        # hooks should be triggered when `__call__` is invoked explicitly
        class ToyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10000)

            def forward(self, x):
                return self.linear.__call__(x)

        model = ToyModel()
        self._forward_hook_test_helper(model)

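    # Full backward hooks and backward pre-hooks are registered on every module
    # (none of which are "allowed" modules), the model is compiled with the
    # aot_eager backend, and the recorded keys verify that every hook fired.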
    def test_backward_hooks(self):
        # this test shouldn't care whether hook guards are enabled or not

        class CustomLinear(torch.nn.Module):
            # not an 'allowed module', so should not graph-break
            def __init__(self, a, b):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(a, b))

            def forward(self, x):
                return torch.mm(x, self.weight)

        class ToyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net = torch.nn.Sequential(
                    *[CustomLinear(10, 10)]
                    + [CustomLinear(10, 10000)]
                    + [CustomLinear(10000, 5)]
                )

            def forward(self, x):
                return self.net(x)

        model = ToyModel()
        backward_hook_handles = {}
        pre_backward_hook_handles = {}

        grad_sizes = {}

        def backward_hook(name, mod, grad_inp, grad_out):
            grad_sizes[name] = (
                (gi.shape for gi in grad_inp),
                (go.shape for go in grad_out),
            )
            return None

        pre_grad_sizes = {}

        def backward_pre_hook(name, mod, grad_out):
            pre_grad_sizes[name] = (go.shape for go in grad_out)
            return None

        for name, module in model.named_modules():
            backward_hook_handles[name] = module.register_full_backward_hook(
                partial(backward_hook, name)
            )

            pre_backward_hook_handles[name] = module.register_full_backward_pre_hook(
                partial(backward_pre_hook, name)
            )

        model = torch.compile(model, backend="aot_eager")

        for i in range(2):
            # second iteration is key, hooks would have fired during aot trace
            # on first iter
            x = torch.randn((20, 10))
            pred = model(x)
            loss = pred.sum()
            loss.backward()

        self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
        self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())

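    # The ModuleDict tests below iterate the dict directly, via .keys(), and via
    # .values() inside forward; each variant should compile to a single frame.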
    def test_module_dict_iter_name(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation_name in self.activations:
                    x = self.activations[activation_name](x)
                return x

        cnt = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(cnt.frame_count, 1)

    def test_module_dict_iter_keys(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation_name in self.activations.keys():
                    x = self.activations[activation_name](x)
                return x

        cnt = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(cnt.frame_count, 1)

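    # Assigning a previously nonexistent attribute inside forward should be
    # reflected on the eager module object after the compiled call returns.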
    def test_assign_does_not_exist(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                self.text_encoding = x + 1
                return self.text_encoding

        mod = MyModule()
        out = torch.compile(mod, fullgraph=True)(torch.randn(10))
        assert mod.text_encoding is out

    def test_module_dict_iter_values(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.activations = torch.nn.ModuleDict(
                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
                )

            def forward(self, x):
                for activation in self.activations.values():
                    x = activation(x)
                return x

        cnt = torch._dynamo.testing.CompileCounter()
        # Eager
        eager_res = MyModule()(torch.ones(10, 10))

        # Compile
        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
        self.assertEqual(eager_res, optim_res)
        self.assertEqual(cnt.frame_count, 1)

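    # Mutating a submodule attribute (here the training flag) inside the traced
    # function should still yield the same result as eager execution.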
    def test_unspecialized_seq(self):
        models = torch.nn.Sequential(torch.nn.Linear(3, 3))

        def fn(x):
            models[0].training = False
            return models(x)

        opt_fn = torch._dynamo.optimize("eager")(fn)
        x = torch.randn(1, 3)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

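    # debug_compile below is a pass-through backend that counts how many traced
    # GraphModules end up with no buffers registered on them; the asserts track
    # that count across the calls.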
    def test_no_op_assignment(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.buffer = torch.rand([4])

            def forward(self, x):
                # should be a no-op, but causes dynamo to lose the static input
                x = x + 1
                self.buffer = self.buffer.to(x)
                return self.buffer + x

        compiles_without_buffers = 0

        def debug_compile(gm, *args, **kwargs):
            nonlocal compiles_without_buffers
            compiles_without_buffers += len(list(gm.buffers())) == 0
            return gm

        @torch.compile(backend=debug_compile)
        def foo(mod, x):
            return mod(x)

        mod = Mod()
        foo(mod, torch.rand([4]))
        self.assertEqual(compiles_without_buffers, 0)

        foo(mod, torch.rand([4], dtype=torch.half))
        self.assertEqual(compiles_without_buffers, 1)

        class Mod2(Mod):
            def __setattr__(self, name, value):
                return super().__setattr__(name, value)

        foo(Mod2(), torch.rand([4]))
        # causes two compilations, because custom __setattr__ is unimplemented in dynamo
        self.assertTrue(compiles_without_buffers >= 2)

    def test_unspec_non_inlinable_module(self):
        mod = UnspecNonInlinableModule()
        opt_fn = torch._dynamo.optimize("eager")(mod)
        x = torch.randn(100)
        actual = opt_fn(x)
        expected = mod(x)
        self.assertEqual(actual, expected)

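    # Repeated calls with c = 0 and c = 1 should settle at two compiled frames,
    # while switching the module to eval() must still force one more recompile.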
    def test_no_guard_on_torch_nn_modules(self):
        # https://github.com/pytorch/pytorch/issues/110048

        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return self.linear(x)

        mod = MockModule()

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def generate(x, c):
            return mod(x) + c

        for _ in range(0, 10):
            generate(torch.randn(10, 10), 0)
            generate(torch.randn(10, 10), 1)
        self.assertEqual(cnt.frame_count, 2)

        # Ensure that a modification to the user module causes a recompile
        mod.eval()
        generate(torch.randn(10, 10), 0)
        self.assertEqual(cnt.frame_count, 3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()