Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add functionality for installing free variables (#151134)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151134 Approved by: https://github.com/anijain2305 ghstack dependencies: #152036
committed by PyTorch MergeBot
parent 402d19c0bd
commit 03970dfd4c
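
For context, a minimal sketch of how the new flag is meant to be exercised (the function, tensor names, and shapes below are illustrative and not part of this commit): with `install_free_tensors` and `inline_inbuilt_nn_modules` enabled, tensors that are free variables of the compiled function (globals, nonlocals, or attributes of an nn module) are installed as graph attributes instead of being lifted as extra graph inputs, which is what the new tests below assert by counting placeholder inputs.

import torch
import torch._dynamo

x = torch.randn(5, 5)  # a "free" global tensor referenced by the compiled fn

def fn(a: torch.Tensor) -> torch.Tensor:
    return a + x

# With both flags enabled, Dynamo installs `x` on the graph rather than
# lifting it as an additional placeholder input (see the VariableBuilder hunk
# near the end of this diff).
with torch._dynamo.config.patch(
    install_free_tensors=True, inline_inbuilt_nn_modules=True
):
    compiled = torch.compile(fn)
    out = compiled(torch.randn(5, 5))
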
82  test/dynamo/test_inline_and_install.py  Normal file
@@ -0,0 +1,82 @@
# Owner(s): ["module: dynamo"]

import unittest

from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches


try:
    from . import test_export
except ImportError:
    import test_export


test_classes = {}


def make_dynamic_cls(cls):
    suffix = "_inline_and_install"

    cls_prefix = "InlineAndInstall"

    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "install_free_tensors", True),
        (config, "inline_inbuilt_nn_modules", True),
        xfail_prop="_expected_failure_inline_and_install",
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    test_export.ExportTests,
]
for test in tests:
    make_dynamic_cls(test)
del test

# After installing and inlining are turned on, these tests won't throw
# errors in export (which is expected for the test to pass).
# Therefore, these unittests are expected to fail, and we need to update the
# semantics
unittest.expectedFailure(
    InlineAndInstallExportTests.test_invalid_input_global_inline_and_install  # noqa: F821
)
unittest.expectedFailure(
    InlineAndInstallExportTests.test_invalid_input_global_multiple_access_inline_and_install  # noqa: F821
)
unittest.expectedFailure(
    InlineAndInstallExportTests.test_invalid_input_nonlocal_inline_and_install  # noqa: F821
)


# These tests do string comparison on the graphs, and since buffers are now inlined, they
# are named differently, resulting in failure
unittest.expectedFailure(
    InlineAndInstallExportTests.test_param_buffer_safe_from_mutation_simple_inline_and_install  # noqa: F821
)


# This particular test is marked as expecting failure, since dynamo was creating a second param,
# and this was causing a failure in the sum; however, with these changes, that test is fixed
# and will now pass, so we need to mark that it is no longer expected to fail
def expectedSuccess(test_item):
    test_item.__unittest_expecting_failure__ = False
    return test_item


expectedSuccess(
    InlineAndInstallExportTests.test_sum_param_inline_and_install  # noqa: F821
)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -1,5 +1,7 @@
# Owner(s): ["module: dynamo"]
from typing import Callable
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union

import torch
import torch._dynamo
@@ -25,9 +27,11 @@ def get_num_input_nodes(graph: GraphModule) -> int:
    """
    placeholder_cnt = 0
    for node in graph.graph.nodes:
        if node.op == "placeholder" and isinstance(
        # Missing in some export tests so check manually
        placeholder_is_tensor = "example_value" in node.meta and isinstance(
            node.meta["example_value"], torch.Tensor
        ):
        )
        if node.op == "placeholder" and placeholder_is_tensor:
            placeholder_cnt += 1
    return placeholder_cnt

@@ -46,28 +50,84 @@ class SimpleLinearModule(torch.nn.Module):
        return self.fwd(x)


class ResBlock(torch.nn.Module):
    """
    Basic resnet building block - used for testing structure
    more typical of real models (i.e. sequential, activations,
    and batchnorm)
    """

    def __init__(self, in_: int, out_: int):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_, out_, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_),
            torch.nn.ReLU(),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(out_, out_, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_),
        )
        self.activation = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        out = self.conv1(x)
        out = self.conv2(out)
        out += skip
        out = self.activation(out)
        return out


class InstallParamsAsGraphAttrTests(torch._dynamo.test_case.TestCase):
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=False)
    def check_num_inputs_and_equality(
        self, original_fn, example_input, expected_num_inputs: int
    def check_num_inputs_and_equality_no_install(
        self,
        fn_to_compile: Union[torch.nn.Module, Callable],
        expected_num_inline_inputs: int,
        example_inputs: Sequence[Any],
    ) -> None:
        """Compiles the original fn, then:
        * Checks that the number of inputs in the graph is expected_num_inputs
        * Checks that the compiled fn and original fn are equal
        """
        opt_fn, graphs = compile_and_extract_graph(original_fn, example_input)
        # inlined ex
        opt_fn, graphs = compile_and_extract_graph(fn_to_compile, *example_inputs)
        self.assertEqual(len(graphs), 1, msg="Expected 1 graph (no breaks)")
        actual_num_inputs = get_num_input_nodes(graphs[0])
        self.assertEqual(actual_num_inputs, expected_num_inputs)
        self.assertEqual(opt_fn(example_input), original_fn(example_input))
        self.assertEqual(actual_num_inputs, expected_num_inline_inputs)
        self.assertEqual(opt_fn(*example_inputs), fn_to_compile(*example_inputs))

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=True)
    def check_num_inputs_and_equality_install(
        self,
        fn_to_compile: Union[torch.nn.Module, Callable],
        expected_num_installed_inputs: int,
        example_inputs: Sequence[Any],
    ) -> None:
        """Compiles the original fn, then:
        * Checks the number of inputs when installed is consistent with original_fn
        * Checks that the compiled fn when installed and original fn are equal
        """
        opt_installed_fn, graphs = compile_and_extract_graph(
            fn_to_compile, *example_inputs
        )
        self.assertEqual(len(graphs), 1, msg="Expected 1 graph (no breaks)")
        actual_num_inputs = get_num_input_nodes(graphs[0])
        self.assertEqual(actual_num_inputs, expected_num_installed_inputs)
        self.assertEqual(
            opt_installed_fn(*example_inputs), fn_to_compile(*example_inputs)
        )

    # ==================== Test Params and Buffer from NN Module ====================
    def test_optimizing_linear(self) -> None:
        net = SimpleLinearModule()
        input1 = torch.randn((1, 5))
        # Expected: 1 + 1 * 2 = 3
        self.check_num_inputs_and_equality(net, input1, 3)
        self.check_num_inputs_and_equality_no_install(net, 3, (input1,))
        self.check_num_inputs_and_equality_install(net, 1, (input1,))

    def test_breadth_linear(self) -> None:
        class BreadthModel(torch.nn.Module):
@@ -79,7 +139,7 @@ class InstallParamsAsGraphAttrTests(torch._dynamo.test_case.TestCase):
                self.fwd4 = torch.nn.Linear(1, 1)
                self.fwd5 = torch.nn.Linear(1, 1)

            def forward(self, x) -> torch.Tensor:
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return (
                    self.fwd(x)
                    + self.fwd2(x)
@@ -91,16 +151,17 @@ class InstallParamsAsGraphAttrTests(torch._dynamo.test_case.TestCase):
        net = BreadthModel()
        input1 = torch.randn((1, 1))
        # Expected: 1 + 5 * 2 = 11
        self.check_num_inputs_and_equality(net, input1, 11)
        self.check_num_inputs_and_equality_no_install(net, 11, (input1,))
        self.check_num_inputs_and_equality_install(net, 1, (input1,))

    def test_nested_linear(self) -> None:
        class NestedModel(torch.nn.Module):
            def __init__(self, inner_module) -> None:
            def __init__(self, inner_module: torch.nn.Module) -> None:
                super().__init__()
                self.fwd = torch.nn.Linear(1, 1)
                self.inner_module = inner_module

            def forward(self, x) -> torch.Tensor:
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.fwd(self.inner_module(x))

        # Nest 5x
@@ -109,28 +170,142 @@ class InstallParamsAsGraphAttrTests(torch._dynamo.test_case.TestCase):
        for _ in range(kDepth):
            net = NestedModel(net)
        input1 = torch.randn((1, 5))
        self.check_num_inputs_and_equality(net, input1, 1 + 2 * (kDepth + 1))
        self.check_num_inputs_and_equality_no_install(
            net, 1 + 2 * (kDepth + 1), (input1,)
        )
        self.check_num_inputs_and_equality_install(net, 1, (input1,))

    # TODO[@lucaskabela]: Test nontrivial models such as resnet, ffn, or transformer
    def test_simple_batchnorm(self) -> None:
        net = torch.nn.BatchNorm2d(3)
        tensor = torch.randn((1, 3, 3, 3))
        # BatchNorm2d has 2 params, and 3 buffers
        self.check_num_inputs_and_equality_no_install(net, 6, (tensor,))
        self.check_num_inputs_and_equality_install(net, 1, (tensor,))

    def test_nets_as_input(self) -> None:
        """
        Tests when the nn.Module is an input to the fn we are optimizing

        In this case, we should treat it as a regular input, which means we
        can lift parameters/buffers, but should not install them
        """
        # Test nn model as input
        net = SimpleLinearModule()
        net2 = SimpleLinearModule()
        x = torch.randn(1, 5)

        def test_fn(x: torch.Tensor, net: torch.nn.Module) -> torch.Tensor:
            return net(x)

        # When the nn module is an input, we don't install the params
        self.check_num_inputs_and_equality_no_install(test_fn, 3, (x, net))
        self.check_num_inputs_and_equality_install(test_fn, 1, (x, net))

        def test_fn2(
            x: torch.Tensor, net: torch.nn.Module, net2: torch.nn.Module
        ) -> torch.Tensor:
            return net(x) + net2(x)

        self.check_num_inputs_and_equality_no_install(test_fn2, 5, (x, net, net2))
        self.check_num_inputs_and_equality_install(test_fn2, 1, (x, net, net2))

        def test_fn3(x: torch.Tensor, net: torch.nn.Module) -> torch.Tensor:
            return net(x) + net2(x)

        # In the case of local scope (net2 here), we can install
        self.check_num_inputs_and_equality_no_install(test_fn3, 5, (x, net))
        self.check_num_inputs_and_equality_install(test_fn3, 1, (x, net))

        def test_fn_list(x: torch.Tensor, nets: list[torch.nn.Module]):
            return sum([net(x) for net in nets])

        self.check_num_inputs_and_equality_no_install(test_fn_list, 5, (x, [net, net2]))
        self.check_num_inputs_and_equality_install(test_fn_list, 1, (x, [net, net2]))

    def test_resnet_structure(self) -> None:
        net = ResBlock(3, 3)
        tensor = torch.randn(1, 3, 3, 3)
        # Conv2d has 2 params, BatchNorm2d has 3 buffers + 2 params, and ReLU has 0 params
        # So expected = 2 + 5 + 5 + 2 = 14 + 1 for input
        self.check_num_inputs_and_equality_no_install(net, 15, (tensor,))
        self.check_num_inputs_and_equality_install(net, 1, (tensor,))

    def test_transformer(self) -> None:
        # needs eval mode - must disable dropout
        transformer = torch.nn.Transformer(d_model=32).eval()
        src = torch.rand(10, 32, 32)
        tgt = torch.rand(20, 32, 32)

        self.check_num_inputs_and_equality_no_install(transformer, 186, (src, tgt))
        self.check_num_inputs_and_equality_install(transformer, 2, (src, tgt))

    # ==================== Test Parameters and Buffers as input ====================
    def test_optimizing_params_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(1, 5))
        net = SimpleLinearModule()

        def test_fn(x):
        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_num_inputs_and_equality(test_fn, param, 3)
        self.check_num_inputs_and_equality_no_install(test_fn, 3, (param,))
        self.check_num_inputs_and_equality_install(test_fn, 1, (param,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, param: torch.nn.Parameter) -> torch.Tensor:
            return net(x) + param

        # net gets installed, param does not here
        self.check_num_inputs_and_equality_no_install(test_fn2, 4, (x, param))
        self.check_num_inputs_and_equality_install(test_fn2, 2, (x, param))

        global global_param
        global_param = torch.nn.Parameter(torch.randn(1, 5))

        def test_fn3(x: torch.Tensor) -> torch.Tensor:
            return net(x) + global_param

        # net gets installed, and the global param does too
        self.check_num_inputs_and_equality_no_install(test_fn3, 4, (x,))
        self.check_num_inputs_and_equality_install(test_fn3, 1, (x,))

        def test_fn4(
            x: torch.Tensor, list_params: list[torch.nn.Parameter]
        ) -> torch.Tensor:
            return net(x) + sum(list_params)

        # list_params should not be installed
        self.check_num_inputs_and_equality_no_install(test_fn4, 4, (x, [param, param]))
        self.check_num_inputs_and_equality_install(test_fn4, 2, (x, [param, param]))

    def test_optimizing_buffer_in_input(self) -> None:
        buf = torch.nn.Buffer(data=torch.ones((1, 5)))
        net = SimpleLinearModule()

        def test_fn(x) -> torch.Tensor:
        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_num_inputs_and_equality(test_fn, buf, 3)
        self.check_num_inputs_and_equality_no_install(test_fn, 3, (buf,))
        self.check_num_inputs_and_equality_install(test_fn, 1, (buf,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, buf: torch.nn.Buffer):
            return net(x) + buf

        # net gets installed, buf does not here
        self.check_num_inputs_and_equality_no_install(test_fn2, 4, (x, buf))
        self.check_num_inputs_and_equality_install(test_fn2, 2, (x, buf))

        global global_buf
        global_buf = torch.nn.Buffer(torch.randn(1, 5))

        def test_fn3(x: torch.Tensor) -> torch.Tensor:
            return net(x) + global_buf

        # net gets installed, and the global buf does too
        self.check_num_inputs_and_equality_no_install(test_fn3, 4, (x,))
        self.check_num_inputs_and_equality_install(test_fn3, 1, (x,))

    def test_optimizing_buffer_and_param_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(5, 1))
@@ -140,7 +315,244 @@ class InstallParamsAsGraphAttrTests(torch._dynamo.test_case.TestCase):
        def test_linear(x: torch.Tensor) -> torch.Tensor:
            return param * x + buf

        self.check_num_inputs_and_equality(test_linear, x, 3)
        self.check_num_inputs_and_equality_no_install(test_linear, 3, (x,))
        self.check_num_inputs_and_equality_install(test_linear, 1, (x,))

        def test_linear_explicit(
            x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
        ) -> torch.Tensor:
            return a * x + b

        # Now, param and buf are inputs, so they should not be inlined
        self.check_num_inputs_and_equality_no_install(
            test_linear_explicit, 3, (x, param, buf)
        )
        self.check_num_inputs_and_equality_install(
            test_linear_explicit, 3, (x, param, buf)
        )


class InstallParamsWhenExport(torch._dynamo.test_case.TestCase):
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=True)
    def check_export_matches_expectation(
        self,
        fn_to_export: Callable,
        expected_num_exported_inputs: int,
        example_inputs: Sequence[Any],
    ) -> None:
        """Exports the original fn, then:
        * Checks that the number of inputs in the exported graph is expected_num_exported_inputs
        * Checks that the exported fn and original fn are equal
        """
        exported_fn = torch._dynamo.export(fn_to_export)
        out_graph = exported_fn(*example_inputs)[0]
        actual_num_inputs = get_num_input_nodes(out_graph)
        self.assertEqual(actual_num_inputs, expected_num_exported_inputs)
        self.assertEqual(out_graph(*example_inputs), fn_to_export(*example_inputs))

    def test_simple_linear(self) -> None:
        net = SimpleLinearModule()
        input1 = torch.randn((1, 5))
        self.check_export_matches_expectation(net, 1, (input1,))

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(test_fn, 1, (input1,))

        # Check multiple inputs
        def test_fn_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return net(x) + net(y)

        input2 = torch.randn((1, 5))
        self.check_export_matches_expectation(test_fn_2, 2, (input1, input2))

    def test_simple_batchnorm(self) -> None:
        net = torch.nn.BatchNorm2d(3)
        tensor = torch.randn((1, 3, 3, 3))
        self.check_export_matches_expectation(net, 1, (tensor,))

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(test_fn, 1, (tensor,))

    def test_resnet_structure(self) -> None:
        net = ResBlock(3, 3)
        tensor = torch.randn(1, 3, 3, 3)
        self.check_export_matches_expectation(net, 1, (tensor,))

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(test_fn, 1, (tensor,))

    def test_transformer(self) -> None:
        transformer = torch.nn.Transformer(d_model=32).eval()
        src = torch.rand(10, 32, 32)
        tgt = torch.rand(20, 32, 32)

        self.check_export_matches_expectation(transformer, 2, (src, tgt))

        def test_fn(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
            return transformer(src, tgt)

        self.check_export_matches_expectation(test_fn, 2, (src, tgt))

    def test_optimizing_params_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(1, 5))
        net = SimpleLinearModule()

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(net, 1, (param,))
        self.check_export_matches_expectation(test_fn, 1, (param,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, param: torch.nn.Parameter) -> torch.Tensor:
            return net(x) + param

        # net gets installed, param does not here
        self.check_export_matches_expectation(test_fn2, 2, (x, param))

        def test_fn3(
            x: torch.Tensor, list_params: list[torch.nn.Parameter]
        ) -> torch.Tensor:
            return net(x) + sum(list_params)

        # list_params should not be installed or inlined here
        self.check_export_matches_expectation(test_fn3, 2, (x, [param, param]))

    def test_optimizing_buffer_in_input(self) -> None:
        buf = torch.nn.Buffer(data=torch.ones((1, 5)))
        net = SimpleLinearModule()

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(net, 1, (buf,))
        self.check_export_matches_expectation(test_fn, 1, (buf,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, buf: torch.nn.Buffer) -> torch.Tensor:
            return net(x) + buf

        # net gets installed, buf does not here
        self.check_export_matches_expectation(test_fn2, 2, (x, buf))

    def test_optimizing_buffer_and_param_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(5, 1))
        buf = torch.nn.Buffer(data=torch.ones((1, 1)))
        x = torch.randn(1, 5)

        def test_linear_explicit(
            x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
        ) -> torch.Tensor:
            return a * x + b

        # Now, param and buf are inputs, so they should not be inlined
        self.check_export_matches_expectation(test_linear_explicit, 3, (x, param, buf))

    def test_global_tensor_export(self) -> None:
        global x
        x = torch.randn((5, 5))

        def fn(a: torch.Tensor) -> torch.Tensor:
            return a + x

        inp = torch.randn(5, 5)
        self.check_export_matches_expectation(fn, 1, (inp,))

    def test_nonlocal_closure(self) -> None:
        x = torch.randn((5, 5))

        def fn(a: torch.Tensor) -> torch.Tensor:
            return a + x

        inp = torch.randn((5, 5))
        self.check_export_matches_expectation(fn, 1, (inp,))

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=True)
    def test_modify_net_state(self) -> None:
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.a = None

            def forward(self, x):
                if self.a is None:
                    self.a = torch.ones_like(x)
                return self.linear(x) + self.a

        mod = Mod()
        inp = torch.randn(5, 5)
        # NOTE: since this fn modifies the original class,
        # we need to get the reference value before tracing
        res = mod(inp)
        mod.a = None
        ep = torch._dynamo.export(mod)
        graph, _ = ep(inp)
        self.assertEqual(graph(inp), res)

    def test_list_of_tensor(self) -> None:
        def fn(x: list[torch.Tensor]):
            return x[0] + x[1]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        self.check_export_matches_expectation(fn, 2, (inp,))

    def test_nested_list_of_tensor(self) -> None:
        def fn(x: list[Union[list[torch.Tensor], torch.Tensor]]):
            return x[0][0] + x[1]  # type: ignore[index]

        inp = [[torch.tensor([1.3, 3.77, 0.1])], torch.tensor([8.7, 6.23, 9.9])]
        self.check_export_matches_expectation(fn, 2, (inp,))

    def test_dict_of_tensor(self) -> None:
        inp_dict = {"temp": torch.tensor(12)}

        def fn(inp: dict[str, torch.Tensor]) -> torch.Tensor:
            return inp_dict["temp"] + 5

        self.check_export_matches_expectation(fn, 1, (inp_dict,))

    # TODO[lucaskabela]: register the flatten/unflatten function so we can evaluate this test
    @unittest.expectedFailure
    def test_user_defined_object(self) -> None:
        class UserDefinedTestClass:
            def __init__(self, x, y) -> None:
                self.x = x
                self.y = y

        x = torch.randn((3, 3))
        y = torch.randn((3, 3))

        def fn(obj: UserDefinedTestClass, inp: torch.Tensor) -> torch.Tensor:
            return obj.x + obj.y + inp

        z = torch.randn((3, 1))

        self.check_export_matches_expectation(fn, 2, (UserDefinedTestClass(x, y), z))

    def test_tensors_as_nn_attr(self) -> None:
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones((5, 5))
                self.b = torch.ones((5, 5))

            def forward(self, x):
                return self.a + self.b + x

        mod = Mod()
        inp = torch.randn(5, 5)
        self.check_export_matches_expectation(mod, 1, (inp,))


if __name__ == "__main__":
@@ -205,6 +205,7 @@ class Inp3:

NON_STRICT_SUFFIX = "_nonstrict"
STRICT_SUFFIX = "_strict"
INLINE_AND_INSTALL_STRICT_SUFFIX = "_inline_and_install_strict"
RETRACEABILITY_STRICT_SUFFIX = "_retraceability_strict"
RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_nonstrict"
SERDES_SUFFIX = "serdes"
@@ -235,6 +236,10 @@ def is_legacy_test(test_name):
    )


def is_inline_and_install_strict_test(test_name: str) -> bool:
    return test_name.endswith(INLINE_AND_INSTALL_STRICT_SUFFIX)


def is_retracebility_test(test_name):
    return test_name.endswith(RETRACEABILITY_STRICT_SUFFIX) or test_name.endswith(
        RETRACEABILITY_NON_STRICT_SUFFIX
@@ -7334,9 +7339,23 @@ def forward(self, b_a_buffer, x):
        )

        else:
            self.assertExpectedInline(
                ep.graph_module.code.strip(),
                """\
        if is_inline_and_install_strict_test(self._testMethodName):
            self.assertExpectedInline(
                ep.graph_module.code.strip(),
                """\
def forward(self, b____modules__a____buffers__buffer, x):
    sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
    gt = sym_size_int_1 > 4; sym_size_int_1 = None
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, b____modules__a____buffers__buffer)); gt = true_graph_0 = false_graph_0 = x = b____modules__a____buffers__buffer = None
    getitem = cond[0]; cond = None
    return (getitem,)""",
            )
        else:
            self.assertExpectedInline(
                ep.graph_module.code.strip(),
                """\
def forward(self, b_a_buffer, x):
    sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
    gt = sym_size_int_1 > 4; sym_size_int_1 = None
@@ -7345,7 +7364,7 @@ def forward(self, b_a_buffer, x):
    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, b_a_buffer)); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
    getitem = cond[0]; cond = None
    return (getitem,)""",
            )
        )
        self.assertTrue(
            torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))
        )
@@ -7897,9 +7916,27 @@ def forward(self, p_lin_weight, p_lin_bias, x):
            decomp_table={torch.ops.aten.linear.default: _decompose_linear_custom}
        )

        self.assertExpectedInline(
            str(ep_decompose_linear.graph_module.code).strip(),
            """\
        if is_inline_and_install_strict_test(self._testMethodName):
            self.assertExpectedInline(
                str(ep_decompose_linear.graph_module.code).strip(),
                """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_bias, c_linear_weight, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
    permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
    matmul = torch.ops.aten.matmul.default(conv2d, permute); conv2d = permute = None
    mul = torch.ops.aten.mul.Tensor(c_linear_bias, 2); c_linear_bias = None
    add = torch.ops.aten.add.Tensor(matmul, mul); matmul = mul = None
    cos = torch.ops.aten.cos.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
    add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add_1,)""",
            )

        else:
            self.assertExpectedInline(
                str(ep_decompose_linear.graph_module.code).strip(),
                """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
@@ -7911,7 +7948,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
    sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
    add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add_1,)""",
        )
            )

    def test_export_decomps_dynamic(self):
        class M(torch.nn.Module):
@@ -11819,16 +11856,29 @@ def forward(self, x):
            torch.randn(4),
        )
        ep = export(Foo(), inputs)
        expected_names = [  # user inputs should be prioritized, unprefixed
            ("p_param_1", InputKind.PARAMETER),
            ("b_alpha_1", InputKind.BUFFER),
            ("b_beta_1", InputKind.BUFFER),
            ("c_gamma_1", InputKind.CONSTANT_TENSOR),
            ("p_param", InputKind.USER_INPUT),
            ("b_alpha", InputKind.USER_INPUT),
            ("b_beta", InputKind.USER_INPUT),
            ("c_gamma", InputKind.USER_INPUT),
        ]
        if is_inline_and_install_strict_test(self._testMethodName):
            # when installed, names are prefixed
            expected_names = [  # user inputs should be prioritized, unprefixed
                ("p____parameters__param", InputKind.PARAMETER),
                ("b____buffers__alpha", InputKind.BUFFER),
                ("b____buffers__beta", InputKind.BUFFER),
                ("c_gamma_1", InputKind.CONSTANT_TENSOR),
                ("p_param", InputKind.USER_INPUT),
                ("b_alpha", InputKind.USER_INPUT),
                ("b_beta", InputKind.USER_INPUT),
                ("c_gamma", InputKind.USER_INPUT),
            ]
        else:
            expected_names = [  # user inputs should be prioritized, unprefixed
                ("p_param_1", InputKind.PARAMETER),
                ("b_alpha_1", InputKind.BUFFER),
                ("b_beta_1", InputKind.BUFFER),
                ("c_gamma_1", InputKind.CONSTANT_TENSOR),
                ("p_param", InputKind.USER_INPUT),
                ("b_alpha", InputKind.USER_INPUT),
                ("b_beta", InputKind.USER_INPUT),
                ("c_gamma", InputKind.USER_INPUT),
            ]
        real_names = [
            (spec.arg.name, spec.kind) for spec in ep.graph_signature.input_specs
        ]
@@ -12789,8 +12839,14 @@ graph():
            list(nn_module_stack.values())[-1][0]
            for nn_module_stack in nn_module_stacks
        ]
        self.assertEqual(filtered_nn_module_stack[0], "sub_net.0")
        self.assertEqual(filtered_nn_module_stack[1], "sub_net.2")

        if is_inline_and_install_strict_test(self._testMethodName):
            # when inlined and installed, the layers have the same ID, so they reference the same layer
            self.assertEqual(filtered_nn_module_stack[0], "sub_net.0")
            self.assertEqual(filtered_nn_module_stack[1], "sub_net.0")
        else:
            self.assertEqual(filtered_nn_module_stack[0], "sub_net.0")
            self.assertEqual(filtered_nn_module_stack[1], "sub_net.2")

    def test_slice_nn_module_stack(self):
        class N(torch.nn.Module):
@@ -12823,8 +12879,16 @@ graph():
            list(nn_module_stack.values())[-1][0]
            for nn_module_stack in nn_module_stacks
        ]
        self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).2")
        self.assertEqual(filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0")
        if is_inline_and_install_strict_test(self._testMethodName):
            self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.2")
            self.assertEqual(filtered_nn_module_stack[1], "mod_list_1.2")
        else:
            self.assertEqual(
                filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).2"
            )
            self.assertEqual(
                filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
            )

    def test_split_const_gm_with_lifted_constants(self):
        class Model(torch.nn.Module):
90  test/export/test_export_with_inline_and_install.py  Normal file
@@ -0,0 +1,90 @@
# Owner(s): ["oncall: export"]


import unittest

from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches


try:
    from . import test_export, testing
except ImportError:
    import test_export  # @manual=fbcode//caffe2/test:test_export-library
    import testing  # @manual=fbcode//caffe2/test:test_export-library

from torch.export import export


test_classes = {}


def mocked_strict_export(*args, **kwargs):
    # If the user already specified strict, don't override it
    if "strict" in kwargs:
        return export(*args, **kwargs)
    return export(*args, **kwargs, strict=True)


def make_dynamic_cls(cls):
    # Some tests check whether the test name ends in a suffix; as a result,
    # `_strict` needs to be at the end of the string
    suffix = test_export.INLINE_AND_INSTALL_STRICT_SUFFIX

    cls_prefix = "InlineAndInstall"

    cls_a = testing.make_test_cls_with_mocked_export(
        cls,
        "StrictExport",
        suffix,
        mocked_strict_export,
        xfail_prop="_expected_failure_strict",
    )
    test_class = make_test_cls_with_patches(
        cls_a,
        cls_prefix,
        "",
        (config, "install_free_tensors", True),
        (config, "inline_inbuilt_nn_modules", True),
        xfail_prop="_expected_failure_inline_and_install",
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__
    return test_class


tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]
for test in tests:
    make_dynamic_cls(test)
del test


# NOTE: For this test, we have a failure that occurs because the buffers (for BatchNorm2d) are installed, and not
# graph inputs. Therefore, they are not in the `program.graph_signature.inputs_to_buffers`
# and so are not found by the unit test when counting the buffers
unittest.expectedFailure(
    InlineAndInstallStrictExportTestExport.test_buffer_util_inline_and_install_strict  # noqa: F821
)

# NOTE: For this test, when we call `LOAD_ATTR`, we fail to realize the LazyVariableTracker.
# This is because the variable is popped off the stack and pushed into a TupleVariable (then a ConstDictVariable).
# So, in the first case (not nested return), the LazyVariable is realized at the RETURN_VALUE call;
# for the second case (nested return), the LazyVariable is not realized until we begin COMPILING_GRAPH.
# As a result, we don't install the variable, so we crash when we expect the variable to be installed later.
# Potential fix: We can force the lazy variable tracker to realize; just need to see how this is done for the
# non-nested case
unittest.expectedFailure(
    InlineAndInstallStrictExportTestExport.test_constant_output_inline_and_install_strict  # noqa: F821
)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -117,14 +117,10 @@ def is_dynamic_nn_module(obj: Any, is_export: bool) -> bool:
        return True
    if hasattr(obj, "torchdynamo_force_dynamic"):
        return obj.torchdynamo_force_dynamic
    # For export, we will have to fix
    # 1) Input signature problem because params are lifted as inputs
    # 2) nn module stack info changes
    # 3) adjust failing tests
    if (
        isinstance(obj, torch.nn.Module)
        and config.inline_inbuilt_nn_modules
        and not is_export
        and (not is_export or config.install_free_tensors)
    ):
        return True

@@ -875,12 +875,30 @@ def is_from_global_source(source: Source):
    return isinstance(source, GlobalSource)


def is_from_nonlocal_source(source: Source):
    if isinstance(source, ChainedSource):
        return is_from_nonlocal_source(source.base)
    return (
        isinstance(source, LocalSource)
        and source.is_derefed_cell_contents
        and not source.is_input
    )


def is_from_source(source: Source, target: Source):
    if isinstance(source, ChainedSource):
        return is_from_source(source.base, target)
    return source == target


def is_from_unspecialized_nn_module_source(source: Source):
    if isinstance(source, UnspecializedNNModuleSource):
        return True
    if isinstance(source, ChainedSource):
        return is_from_unspecialized_nn_module_source(source.base)
    return False


def is_from_unspecialized_param_buffer_source(source: Source):
    if isinstance(source, UnspecializedParamBufferSource):
        return True
@@ -100,7 +100,10 @@ from ..source import (
    GetItemSource,
    GradSource,
    is_constant_source,
    is_from_global_source,
    is_from_nonlocal_source,
    is_from_optimizer_source,
    is_from_unspecialized_nn_module_source,
    ListGetItemSource,
    LocalSource,
    NumpyTensorSource,
@@ -552,9 +555,9 @@ class VariableBuilder:
                source_key = k

            source_value = GetItemSource(self.get_source(), source_key)
            value = LazyVariableTracker.create(v, source_value)
            res_value = LazyVariableTracker.create(v, source_value)

            return key, value
            return key, res_value

        items = dict(build_key_value(k, v) for k, v in value.items())

@@ -698,17 +701,17 @@ class VariableBuilder:
        # We need all the keys to be hashable. We do this within the
        # _HashableTracker class in dicts.py
        def build_key_value(i, k, v):
            base = self.get_source()
            if all_const:
                key = ConstantVariable.create(k)
                source_key = k
            else:
                source_key = ConstDictKeySource(self.get_source(), i)
                source_key = ConstDictKeySource(base, i)
                key = LazyVariableTracker.create(k, source_key)
            source_value = DictGetItemSource(base, source_key)
            res_value = LazyVariableTracker.create(v, source_value)

            source_value = DictGetItemSource(self.get_source(), source_key)
            value = LazyVariableTracker.create(v, source_value)

            return key, value
            return key, res_value

        # Ensure that we call dict.keys and not value.keys (which can call
        # overridden keys method). In the C++ guards, we relied on
@@ -1339,13 +1342,14 @@ class VariableBuilder:
        # We need all the keys to be hashable. We do this within the
        # _HashableTracker class in dicts.py
        def build_key_value(i, k, v):
            source_key = ConstDictKeySource(self.get_source(), i)
            base = self.get_source()
            source_key = ConstDictKeySource(base, i)
            key = LazyVariableTracker.create(k, source_key)

            source_value = DictGetItemSource(self.get_source(), source_key)
            value = LazyVariableTracker.create(v, source_value)
            source_value = DictGetItemSource(base, source_key)
            res_value = LazyVariableTracker.create(v, source_value)

            return key, value
            return key, res_value

        # Ensure that we call dict.keys and not value.keys (which can call
        # overridden keys method). In the C++ guards, we relied on
@@ -1769,15 +1773,26 @@ class VariableBuilder:
            self.mark_static_input(value, guard=is_parameter_freezing())
            is_static_input = True

        # Install any tensors which are "free" variables; that is:
        # 1. Globals
        # 2. NonLocals
        # 3. tensors that are attributes of nn module
        should_install_free_tensor = config.install_free_tensors and (
            is_from_global_source(source)
            or is_from_nonlocal_source(source)
            or is_from_unspecialized_nn_module_source(source)
        )

        make_graph_attribute = is_static_input and (
            not config.inline_inbuilt_nn_modules
            or is_parameter_freezing()
            or torch._dynamo.config.prepare_freezing
        )

        if (
            source.guard_source().is_specialized_nn_module() or make_graph_attribute
        ) and not source.guard_source().is_fsdp_module():
        if should_install_free_tensor or (
            (source.guard_source().is_specialized_nn_module() or make_graph_attribute)
            and not source.guard_source().is_fsdp_module()
        ):
            self.assert_not_wrapped_by_this_graph(value)
            return self.tx.output.register_attr_or_module(
                value, self.name, source=source
@@ -2387,9 +2402,8 @@ def _dataclasses_fields_lambda(obj):
    for field in dataclasses.fields(value):
        source = None
        if obj.source:
            source = DictGetItemSource(
                AttrSource(obj.source, "__dataclass_fields__"), field.name
            )
            base_src = AttrSource(obj.source, "__dataclass_fields__")
            source = DictGetItemSource(base_src, field.name)
        items.append(UserDefinedObjectVariable(field, source=source))
    return TupleVariable(items)

@@ -1261,7 +1261,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
            and isinstance(self, variables.UnspecializedNNModuleVariable)
            # export has some awkwardness around specialized and unspecialized modules. Skip wrapping source for export
            # usecase for now.
            and not tx.output.export
            and (not tx.output.export or torch._dynamo.config.install_free_tensors)
        ):
            # Recalculate source for params/buffers
            if name in ("_buffers", "_parameters"):