# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import EagerAndRecordGraphs
from torch.fx.graph_module import GraphModule


def compile_and_extract_graph(
    fn, *args, **kwargs
) -> tuple[Callable, list[torch.fx.GraphModule]]:
    backend = EagerAndRecordGraphs()
    result_fn = torch.compile(backend=backend, fullgraph=True)(fn)
    # Run fn once so the backend captures the graph
    _ = result_fn(*args, **kwargs)
    return result_fn, backend.graphs


def get_num_input_nodes(graph: GraphModule) -> int:
    """Returns the number of input nodes in the given GraphModule
    by counting the placeholder nodes whose example value is a tensor
    """
    placeholder_cnt = 0
    for node in graph.graph.nodes:
        # example_value meta is missing in some export tests, so check manually
        placeholder_is_tensor = "example_value" in node.meta and isinstance(
            node.meta["example_value"], torch.Tensor
        )
        if node.op == "placeholder" and placeholder_is_tensor:
            placeholder_cnt += 1
    return placeholder_cnt


class SimpleLinearModule(torch.nn.Module):
    """
    Simple linear model with one weight and one bias parameter
    for basic testing purposes
    """

    def __init__(self) -> None:
        super().__init__()
        self.fwd = torch.nn.Linear(5, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fwd(x)


class ResBlock(torch.nn.Module):
    """
    Basic ResNet building block - used for testing a structure
    more typical of real models (i.e. sequential containers, activations,
    and batchnorm)
    """

    def __init__(self, in_: int, out_: int):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_, out_, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_),
            torch.nn.ReLU(),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(out_, out_, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_),
        )
        self.activation = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        out = self.conv1(x)
        out = self.conv2(out)
        out += skip
        out = self.activation(out)
        return out


class InstallParamsAsGraphAttrTests(torch._dynamo.test_case.TestCase):
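    # The two helpers below compile the same callable under two configs: with
    # install_free_tensors=False, free parameters/buffers reached while tracing
    # are lifted as extra graph inputs (placeholders); with
    # install_free_tensors=True they are expected to be installed on the graph
    # as attributes instead, leaving only the actual call arguments as tensor
    # placeholders.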
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=False)
    def check_num_inputs_and_equality_no_install(
        self,
        fn_to_compile: Union[torch.nn.Module, Callable],
        expected_num_inline_inputs: int,
        example_inputs: Sequence[Any],
    ) -> None:
        """Compiles the original fn, then:
        * Checks that the number of inputs in the graph is expected_num_inline_inputs
        * Checks that the compiled fn and original fn produce equal results
        """
        # Inlined, but free tensors are not installed
        opt_fn, graphs = compile_and_extract_graph(fn_to_compile, *example_inputs)
        self.assertEqual(len(graphs), 1, msg="Expected 1 graph (no breaks)")
        actual_num_inputs = get_num_input_nodes(graphs[0])
        self.assertEqual(actual_num_inputs, expected_num_inline_inputs)
        self.assertEqual(opt_fn(*example_inputs), fn_to_compile(*example_inputs))

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=True)
    def check_num_inputs_and_equality_install(
        self,
        fn_to_compile: Union[torch.nn.Module, Callable],
        expected_num_installed_inputs: int,
        example_inputs: Sequence[Any],
    ) -> None:
        """Compiles the original fn, then:
        * Checks that the number of inputs with installation enabled is
          expected_num_installed_inputs
        * Checks that the compiled fn with installation enabled and the
          original fn produce equal results
        """
        opt_installed_fn, graphs = compile_and_extract_graph(
            fn_to_compile, *example_inputs
        )
        self.assertEqual(len(graphs), 1, msg="Expected 1 graph (no breaks)")
        actual_num_inputs = get_num_input_nodes(graphs[0])
        self.assertEqual(actual_num_inputs, expected_num_installed_inputs)
        self.assertEqual(
            opt_installed_fn(*example_inputs), fn_to_compile(*example_inputs)
        )

    # ==================== Test Params and Buffer from NN Module ====================
    def test_optimizing_linear(self) -> None:
        net = SimpleLinearModule()
        input1 = torch.randn((1, 5))
        # Expected: 1 input + 1 Linear * 2 params (weight, bias) = 3
        self.check_num_inputs_and_equality_no_install(net, 3, (input1,))
        self.check_num_inputs_and_equality_install(net, 1, (input1,))

    def test_breadth_linear(self) -> None:
        class BreadthModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fwd = torch.nn.Linear(1, 1)
                self.fwd2 = torch.nn.Linear(1, 1)
                self.fwd3 = torch.nn.Linear(1, 1)
                self.fwd4 = torch.nn.Linear(1, 1)
                self.fwd5 = torch.nn.Linear(1, 1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return (
                    self.fwd(x)
                    + self.fwd2(x)
                    + self.fwd3(x)
                    + self.fwd4(x)
                    + self.fwd5(x)
                )

        net = BreadthModel()
        input1 = torch.randn((1, 1))
        # Expected: 1 + 5 * 2 = 11
        self.check_num_inputs_and_equality_no_install(net, 11, (input1,))
        self.check_num_inputs_and_equality_install(net, 1, (input1,))

    def test_nested_linear(self) -> None:
        class NestedModel(torch.nn.Module):
            def __init__(self, inner_module: torch.nn.Module) -> None:
                super().__init__()
                self.fwd = torch.nn.Linear(1, 1)
                self.inner_module = inner_module

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.fwd(self.inner_module(x))

        # Wrap SimpleLinearModule kDepth times, for kDepth + 1 linear layers in
        # total, each contributing a weight and a bias
        kDepth = 4
        net = SimpleLinearModule()
        for _ in range(kDepth):
            net = NestedModel(net)
        input1 = torch.randn((1, 5))
        self.check_num_inputs_and_equality_no_install(
            net, 1 + 2 * (kDepth + 1), (input1,)
        )
        self.check_num_inputs_and_equality_install(net, 1, (input1,))

    def test_simple_batchnorm(self) -> None:
        net = torch.nn.BatchNorm2d(3)
        tensor = torch.randn((1, 3, 3, 3))
        # BatchNorm2d has 2 params and 3 buffers, plus 1 input = 6
        self.check_num_inputs_and_equality_no_install(net, 6, (tensor,))
        self.check_num_inputs_and_equality_install(net, 1, (tensor,))

    def test_nets_as_input(self) -> None:
        """
        Tests when the nn.Module is an input to the fn we are optimizing

        In this case, we should treat it as a regular input, which means we
        can lift its parameters/buffers, but should not install them
        """
        # Test nn.Module instances as inputs
        net = SimpleLinearModule()
        net2 = SimpleLinearModule()
        x = torch.randn(1, 5)

        def test_fn(x: torch.Tensor, net: torch.nn.Module) -> torch.Tensor:
            return net(x)

        # When the nn.Module is passed as an input, we don't install its params
        self.check_num_inputs_and_equality_no_install(test_fn, 3, (x, net))
        self.check_num_inputs_and_equality_install(test_fn, 1, (x, net))

        def test_fn2(
            x: torch.Tensor, net: torch.nn.Module, net2: torch.nn.Module
        ) -> torch.Tensor:
            return net(x) + net2(x)

        self.check_num_inputs_and_equality_no_install(test_fn2, 5, (x, net, net2))
        self.check_num_inputs_and_equality_install(test_fn2, 1, (x, net, net2))

        def test_fn3(x: torch.Tensor, net: torch.nn.Module) -> torch.Tensor:
            return net(x) + net2(x)

        # Modules captured from the enclosing scope (net2 here) can be installed
        self.check_num_inputs_and_equality_no_install(test_fn3, 5, (x, net))
        self.check_num_inputs_and_equality_install(test_fn3, 1, (x, net))

        def test_fn_list(x: torch.Tensor, nets: list[torch.nn.Module]):
            return sum([net(x) for net in nets])

        self.check_num_inputs_and_equality_no_install(test_fn_list, 5, (x, [net, net2]))
        self.check_num_inputs_and_equality_install(test_fn_list, 1, (x, [net, net2]))

    def test_resnet_structure(self) -> None:
        net = ResBlock(3, 3)
        tensor = torch.randn(1, 3, 3, 3)
        # Conv2d has 2 params, BatchNorm2d has 2 params + 3 buffers, and ReLU has none,
        # so expected = 2 + 5 + 2 + 5 = 14, plus 1 for the input = 15
        self.check_num_inputs_and_equality_no_install(net, 15, (tensor,))
        self.check_num_inputs_and_equality_install(net, 1, (tensor,))

    def test_transformer(self) -> None:
        # Needs eval mode to disable dropout
        transformer = torch.nn.Transformer(d_model=32).eval()
        src = torch.rand(10, 32, 32)
        tgt = torch.rand(20, 32, 32)

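        # Expected: 2 real inputs (src, tgt) + 184 lifted parameters of the
        # default 6-layer encoder/decoder Transformer = 186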
        self.check_num_inputs_and_equality_no_install(transformer, 186, (src, tgt))
        self.check_num_inputs_and_equality_install(transformer, 2, (src, tgt))

    # ==================== Test Parameters and Buffers as input ====================
    def test_optimizing_params_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(1, 5))
        net = SimpleLinearModule()

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_num_inputs_and_equality_no_install(test_fn, 3, (param,))
        self.check_num_inputs_and_equality_install(test_fn, 1, (param,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, param: torch.nn.Parameter) -> torch.Tensor:
            return net(x) + param

        # net's params get installed, but param (an input) does not
        self.check_num_inputs_and_equality_no_install(test_fn2, 4, (x, param))
        self.check_num_inputs_and_equality_install(test_fn2, 2, (x, param))

        global global_param
        global_param = torch.nn.Parameter(torch.randn(1, 5))

        def test_fn3(x: torch.Tensor) -> torch.Tensor:
            return net(x) + global_param

        # Both net's params and the global param get installed
        self.check_num_inputs_and_equality_no_install(test_fn3, 4, (x,))
        self.check_num_inputs_and_equality_install(test_fn3, 1, (x,))

        def test_fn4(
            x: torch.Tensor, list_params: list[torch.nn.Parameter]
        ) -> torch.Tensor:
            return net(x) + sum(list_params)

        # list_params should not be installed
        self.check_num_inputs_and_equality_no_install(test_fn4, 4, (x, [param, param]))
        self.check_num_inputs_and_equality_install(test_fn4, 2, (x, [param, param]))

    def test_optimizing_buffer_in_input(self) -> None:
        buf = torch.nn.Buffer(data=torch.ones((1, 5)))
        net = SimpleLinearModule()

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_num_inputs_and_equality_no_install(test_fn, 3, (buf,))
        self.check_num_inputs_and_equality_install(test_fn, 1, (buf,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, buf: torch.nn.Buffer):
            return net(x) + buf

        # net's params get installed, but buf (an input) does not
        self.check_num_inputs_and_equality_no_install(test_fn2, 4, (x, buf))
        self.check_num_inputs_and_equality_install(test_fn2, 2, (x, buf))

        global global_buf
        global_buf = torch.nn.Buffer(torch.randn(1, 5))

        def test_fn3(x: torch.Tensor) -> torch.Tensor:
            return net(x) + global_buf

        # Both net's params and the global buffer get installed
        self.check_num_inputs_and_equality_no_install(test_fn3, 4, (x,))
        self.check_num_inputs_and_equality_install(test_fn3, 1, (x,))

    def test_optimizing_buffer_and_param_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(5, 1))
        buf = torch.nn.Buffer(data=torch.ones((1, 1)))
        x = torch.randn(1, 5)

        def test_linear(x: torch.Tensor) -> torch.Tensor:
            return param * x + buf

        self.check_num_inputs_and_equality_no_install(test_linear, 3, (x,))
        self.check_num_inputs_and_equality_install(test_linear, 1, (x,))

        def test_linear_explicit(
            x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
        ) -> torch.Tensor:
            return a * x + b

        # Now param and buf are inputs, so they should be neither inlined nor installed
        self.check_num_inputs_and_equality_no_install(
            test_linear_explicit, 3, (x, param, buf)
        )
        self.check_num_inputs_and_equality_install(
            test_linear_explicit, 3, (x, param, buf)
        )


class InstallParamsWhenExport(torch._dynamo.test_case.TestCase):
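    # The export helper below patches install_free_tensors=True, so free
    # params/buffers are expected to be installed on the exported graph and
    # only the genuine example inputs are counted as tensor placeholders.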
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=True)
    def check_export_matches_expectation(
        self,
        fn_to_export: Callable,
        expected_num_exported_inputs: int,
        example_inputs: Sequence[Any],
    ) -> None:
        """Exports the original fn, then:
        * Checks that the number of inputs in the exported graph is
          expected_num_exported_inputs
        * Checks that the exported fn and original fn produce equal results
        """
        exported_fn = torch._dynamo.export(fn_to_export)
        out_graph = exported_fn(*example_inputs)[0]
        actual_num_inputs = get_num_input_nodes(out_graph)
        self.assertEqual(actual_num_inputs, expected_num_exported_inputs)
        self.assertEqual(out_graph(*example_inputs), fn_to_export(*example_inputs))

    def test_simple_linear(self) -> None:
        net = SimpleLinearModule()
        input1 = torch.randn((1, 5))
        self.check_export_matches_expectation(net, 1, (input1,))

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(test_fn, 1, (input1,))

        # Check multiple inputs
        def test_fn_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return net(x) + net(y)

        input2 = torch.randn((1, 5))
        self.check_export_matches_expectation(test_fn_2, 2, (input1, input2))

    def test_simple_batchnorm(self) -> None:
        net = torch.nn.BatchNorm2d(3)
        tensor = torch.randn((1, 3, 3, 3))
        self.check_export_matches_expectation(net, 1, (tensor,))

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(test_fn, 1, (tensor,))

    def test_resnet_structure(self) -> None:
        net = ResBlock(3, 3)
        tensor = torch.randn(1, 3, 3, 3)
        self.check_export_matches_expectation(net, 1, (tensor,))

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(test_fn, 1, (tensor,))

    def test_transformer(self) -> None:
        transformer = torch.nn.Transformer(d_model=32).eval()
        src = torch.rand(10, 32, 32)
        tgt = torch.rand(20, 32, 32)

        self.check_export_matches_expectation(transformer, 2, (src, tgt))

        def test_fn(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
            return transformer(src, tgt)

        self.check_export_matches_expectation(test_fn, 2, (src, tgt))

    def test_optimizing_params_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(1, 5))
        net = SimpleLinearModule()

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(net, 1, (param,))
        self.check_export_matches_expectation(test_fn, 1, (param,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, param: torch.nn.Parameter) -> torch.Tensor:
            return net(x) + param

        # net's params get installed, but param (an input) does not
        self.check_export_matches_expectation(test_fn2, 2, (x, param))

        def test_fn3(
            x: torch.Tensor, list_params: list[torch.nn.Parameter]
        ) -> torch.Tensor:
            return net(x) + sum(list_params)

        # list_params should not be installed or inlined here
        self.check_export_matches_expectation(test_fn3, 2, (x, [param, param]))

    def test_optimizing_buffer_in_input(self) -> None:
        buf = torch.nn.Buffer(data=torch.ones((1, 5)))
        net = SimpleLinearModule()

        def test_fn(x: torch.Tensor) -> torch.Tensor:
            return net(x)

        self.check_export_matches_expectation(net, 1, (buf,))
        self.check_export_matches_expectation(test_fn, 1, (buf,))

        x = torch.randn(1, 5)

        def test_fn2(x: torch.Tensor, buf: torch.nn.Buffer) -> torch.Tensor:
            return net(x) + buf

        # net's params get installed, but buf (an input) does not
        self.check_export_matches_expectation(test_fn2, 2, (x, buf))

    def test_optimizing_buffer_and_param_in_input(self) -> None:
        param = torch.nn.Parameter(torch.randn(5, 1))
        buf = torch.nn.Buffer(data=torch.ones((1, 1)))
        x = torch.randn(1, 5)

        def test_linear_explicit(
            x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
        ) -> torch.Tensor:
            return a * x + b

        # Now param and buf are inputs, so they should be neither inlined nor installed
        self.check_export_matches_expectation(test_linear_explicit, 3, (x, param, buf))

    def test_global_tensor_export(self) -> None:
        global x
        x = torch.randn((5, 5))

        def fn(a: torch.Tensor) -> torch.Tensor:
            return a + x

        inp = torch.randn(5, 5)
        self.check_export_matches_expectation(fn, 1, (inp,))

    def test_nonlocal_closure(self) -> None:
        x = torch.randn((5, 5))

        def fn(a: torch.Tensor) -> torch.Tensor:
            return a + x

        inp = torch.randn((5, 5))
        self.check_export_matches_expectation(fn, 1, (inp,))

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._dynamo.config.patch(install_free_tensors=True)
    def test_modify_net_state(self) -> None:
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.a = None

            def forward(self, x):
                if self.a is None:
                    self.a = torch.ones_like(x)
                return self.linear(x) + self.a

        mod = Mod()
        inp = torch.randn(5, 5)
        # NOTE: since forward mutates the module's state (self.a),
        # get the reference value before tracing, then reset it
        res = mod(inp)
        mod.a = None
        ep = torch._dynamo.export(mod)
        graph, _ = ep(inp)
        self.assertEqual(graph(inp), res)

    def test_list_of_tensor(self) -> None:
        def fn(x: list[torch.Tensor]):
            return x[0] + x[1]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        self.check_export_matches_expectation(fn, 2, (inp,))

    def test_nested_list_of_tensor(self) -> None:
        def fn(x: list[Union[list[torch.Tensor], torch.Tensor]]):
            return x[0][0] + x[1]  # type: ignore[index]

        inp = [[torch.tensor([1.3, 3.77, 0.1])], torch.tensor([8.7, 6.23, 9.9])]
        self.check_export_matches_expectation(fn, 2, (inp,))

    def test_dict_of_tensor(self) -> None:
        inp_dict = {"temp": torch.tensor(12)}

        def fn(inp: dict[str, torch.Tensor]) -> torch.Tensor:
            return inp["temp"] + 5

        self.check_export_matches_expectation(fn, 1, (inp_dict,))

    # TODO[lucaskabela]: register the flatten/unflatten function so we can evaluate this test
    @unittest.expectedFailure
    def test_user_defined_object(self) -> None:
        class UserDefinedTestClass:
            def __init__(self, x, y) -> None:
                self.x = x
                self.y = y

        x = torch.randn((3, 3))
        y = torch.randn((3, 3))

        def fn(obj: UserDefinedTestClass, inp: torch.Tensor) -> torch.Tensor:
            return obj.x + obj.y + inp

        z = torch.randn((3, 1))

        self.check_export_matches_expectation(fn, 2, (UserDefinedTestClass(x, y), z))

    def test_tensors_as_nn_attr(self) -> None:
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones((5, 5))
                self.b = torch.ones((5, 5))

            def forward(self, x):
                return self.a + self.b + x

        mod = Mod()
        inp = torch.randn(5, 5)
        self.check_export_matches_expectation(mod, 1, (inp,))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()