mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The fixes are generated by:
```bash
ruff check --fix --preview --unsafe-fixes --select=E226 .
lintrunner -a --take "RUFF,PYFMT" --all-files
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144415
Approved by: https://github.com/huydhn, https://github.com/Skylion007
273 lines
8.4 KiB
Python
273 lines
8.4 KiB
Python
# Owner(s): ["oncall: export"]
|
|
|
|
|
|
def random_dag(n: int, rng=None) -> dict[int, list[int]]:
    """
    Util to generate a random DAG with n nodes.

    The nodes are numbered 0, 1, ..., n-1. The DAG is generated by randomly
    choosing a subset of edges from the complete graph on n nodes, such that
    for each (i, j) we have i < j. Each candidate edge is kept with
    probability 1/2.

    Args:
        n: number of nodes. ``n <= 0`` yields an empty DAG.
        rng: optional ``random.Random`` instance used for the choices, e.g.
            ``random.Random(seed)`` for a reproducible DAG that does not
            disturb the global random state. Defaults to the module-level
            ``random`` generator, which matches the previous behavior.

    Returns:
        A dict mapping every node i in ``range(n)`` to the ascending list of
        nodes j > i that have an edge i -> j.
    """
    import random

    if rng is None:
        rng = random

    edges = {}
    for i in range(n):
        # Draws happen in the same (i, j) order as before, so a seeded global
        # generator still produces the same DAG.
        edges[i] = [j for j in range(i + 1, n) if rng.choice([True, False])]
    return edges
|
|
|
|
|
|
class Block:
    """
    Util to generate a block of Python-formatted code.

    Lines are accumulated in insertion order; ``str(block)`` (and
    ``repr(block)``) joins them into a single string in which every line is
    terminated by a newline.
    """

    # Prefix that new_block() prepends to every line of a nested block. A
    # single space is valid Python indentation as long as it is applied
    # consistently, which it is for all code assembled through Block.
    INDENT = " "

    def __init__(self):
        # Lines of code, each already terminated by "\n".
        self._code = []

    def __repr__(self):
        return "".join(self._code)

    def new_line(self, line: str):
        """
        Add a new line of code. The line is automatically suffixed
        with a newline character.
        """
        self._code.append(line + "\n")

    def new_block(self, block: "Block"):
        """
        Add a new block of code. Every line of ``block`` is copied into this
        block with one extra level of indentation (``Block.INDENT``, a single
        space). Nesting blocks repeatedly therefore adds one level per nesting.
        """
        self._code.extend(self.INDENT + line for line in block._code)
|
|
|
|
|
|
class TestGenerator:
|
|
"""
|
|
Abstract base class for generating test code.
|
|
|
|
Users should subclass this class and implement the test_name() and
|
|
test_body() methods. The test_name() method should return a string
|
|
that uniquely identifies the test. The test_body() method should
|
|
yield blocks of code.
|
|
"""
|
|
|
|
def __init__(self):
|
|
self._count = 0
|
|
|
|
def _generate_test_name(self):
|
|
self._count += 1
|
|
return f"{self.test_name()}_{self._count}"
|
|
|
|
def generate_test(self):
|
|
test_name = self._generate_test_name()
|
|
|
|
code = Block()
|
|
code.new_line(f"def {test_name}():")
|
|
for block in self.test_body():
|
|
code.new_block(block)
|
|
code.new_line(f"{test_name}()")
|
|
return str(code)
|
|
|
|
def test_name(self):
|
|
raise NotImplementedError
|
|
|
|
def test_body(self):
|
|
raise NotImplementedError
|
|
|
|
|
|
class NNModuleGenerator:
    """
    Abstract base class for generating a nn.Module.

    Subclasses implement two hooks, each taking the module index ``i``:
      * gen_init_body(i) returns the Block forming the body of ``__init__``.
      * gen_forward_body(i) returns the Block forming the body of ``forward``.

    gen_nn_module(i) combines them into the source of a class ``N{i}`` that
    derives from ``torch.nn.Module``.
    """

    def gen_init_body(self, i: int):
        raise NotImplementedError

    def gen_forward_body(self, i: int):
        raise NotImplementedError

    def gen_nn_module(self, i: int):
        # Class body: __init__ first, then forward(x).
        class_body = Block()
        class_body.new_line("def __init__(self):")
        class_body.new_block(self.gen_init_body(i))
        class_body.new_line("def forward(self, x):")
        class_body.new_block(self.gen_forward_body(i))

        module = Block()
        module.new_line(f"class N{i}(torch.nn.Module):")
        module.new_block(class_body)
        return module
|
|
|
|
|
|
class Unflatten(TestGenerator):
    """
    Generates test that unflattens a model with several nn.Modules that call
    each other. The modules are generated by calling the nn_module_generator()
    method.

    The model is exported and then unflattened. The unflattened model is then
    compared against the eager model.
    """

    def __init__(self, n: int):
        super().__init__()
        # Number of generated modules N0 ... N{n-1}; N{i} holds N{i+1} as the
        # attribute ``n{i+1}``.
        self.n = n

    @staticmethod
    def _path(i: int, j: int) -> str:
        """
        Return the attribute path from module N{i} down to its descendant
        N{j}, e.g. ``_path(0, 3) == "n1.n2.n3"``.

        Raises:
            ValueError: if ``j <= i``, since N{j} is then not a descendant.
        """
        if j <= i:
            raise ValueError(f"N{j} is not a descendant of N{i}")
        return ".".join(f"n{k}" for k in range(i + 1, j + 1))

    def nn_module_generator(self):
        # Captured by the nested class methods below through the closure.
        path = self._path

        class GenNNModule(NNModuleGenerator):
            def __init__(self, n: int):
                super().__init__()
                self.n = n
                # calls[i] lists the descendants N{j} (j > i) that N{i}
                # calls in its forward().
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                code = Block()
                code.new_line("super().__init__()")
                if i < self.n - 1:
                    code.new_line(f"self.n{i + 1} = N{i + 1}()")
                return code

            def gen_forward_body(self, i: int):
                code = Block()
                for j in self.calls[i]:
                    code.new_line(f"x = self.{path(i, j)}(x + 1)")
                code.new_line("return x + 1")
                return code

        return GenNNModule(self.n)

    def test_name(self):
        return f"{self.__class__.__name__}_{self.n}"

    def test_body(self):
        nn_module_generator = self.nn_module_generator()
        # Emit the module classes from N{n-1} down to N0.
        for i in reversed(range(self.n)):
            yield nn_module_generator.gen_nn_module(i)

        # Contents of a tuple literal of submodule FQNs, e.g. 'n1','n1.n2',
        # Each name is quoted separately and followed by a comma, so a single
        # entry still forms a tuple and n == 1 yields the empty tuple "()".
        # (Putting the comma inside the quotes would make Python concatenate
        # the adjacent literals into one string instead of a tuple.)
        fqns = "".join(f"'{self._path(0, j)}'," for j in range(1, self.n))

        def gen_main():
            code = Block()
            code.new_line("inp = (torch.ones(1),)")
            code.new_line("eager = N0()(*inp)")
            code.new_line(
                f"ep = torch.export.export(N0(), inp, strict=False, preserve_module_call_signature=({fqns}))"
            )
            code.new_line("epm = ep.module()")
            code.new_line("ufm = torch.export.unflatten(ep)")
            code.new_line("assert torch.allclose(epm(*inp), eager)")
            code.new_line("assert torch.allclose(ufm(*inp), eager)")
            return code

        yield gen_main()
|
|
|
|
|
|
class ConstantUnflatten(Unflatten):
    """
    Variant of Unflatten whose generated modules also read tensor constants.

    Each generated module N{i} assigns a plain tensor attribute
    ``self.const = torch.ones(1)``. In its forward(), N{i} first adds the
    ``const`` of some descendant modules to ``x`` and then calls some
    descendant modules. Both sets come from random_dag().
    """

    def nn_module_generator(self):
        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # Keep this order: it fixes which random draws feed which DAG.
                # accesses[i] / calls[i] list descendants N{j} (j > i) whose
                # const N{i} reads / which N{i} calls.
                self.accesses = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                lines = ["super().__init__()", "self.const = torch.ones(1)"]
                if i < self.n - 1:
                    lines.append(f"self.n{i + 1} = N{i + 1}()")
                code = Block()
                for line in lines:
                    code.new_line(line)
                return code

            def gen_forward_body(self, i: int):
                def path(src, dst):
                    # Attribute path from N{src} down to N{dst}, e.g. "n1.n2".
                    return (
                        f"n{dst}"
                        if src + 1 == dst
                        else f"n{src + 1}.{path(src + 1, dst)}"
                    )

                lines = [f"x = x + self.{path(i, j)}.const" for j in self.accesses[i]]
                lines += [f"x = self.{path(i, j)}(x + 1)" for j in self.calls[i]]
                lines.append("return x + 1")

                code = Block()
                for line in lines:
                    code.new_line(line)
                return code

        return GenNNModule(self.n)
|
|
|
|
|
|
class BufferUnflatten(Unflatten):
    """
    Variant of Unflatten whose generated modules also read and mutate buffers.

    Each generated module N{i} assigns
    ``self.buf = torch.nn.Buffer(torch.ones(1))``. Its forward() does the
    following, in order:
      1. adds the ``buf`` of some descendant modules to ``x``;
      2. calls some descendant modules;
      3. calls ``buf.add_(1)`` on the ``buf`` of some descendant modules.
    Each of the three sets comes from its own random_dag() call.
    """

    def nn_module_generator(self):
        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # Keep this order: it fixes which random draws feed which DAG.
                # Each dict maps N{i} to descendants N{j} (j > i).
                self.accesses = random_dag(self.n)
                self.mutations = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                lines = [
                    "super().__init__()",
                    "self.buf = torch.nn.Buffer(torch.ones(1))",
                ]
                if i < self.n - 1:
                    lines.append(f"self.n{i + 1} = N{i + 1}()")
                code = Block()
                for line in lines:
                    code.new_line(line)
                return code

            def gen_forward_body(self, i: int):
                def path(src, dst):
                    # Attribute path from N{src} down to N{dst}, e.g. "n1.n2".
                    return (
                        f"n{dst}"
                        if src + 1 == dst
                        else f"n{src + 1}.{path(src + 1, dst)}"
                    )

                lines = [f"x = x + self.{path(i, j)}.buf" for j in self.accesses[i]]
                lines += [f"x = self.{path(i, j)}(x + 1)" for j in self.calls[i]]
                lines += [f"self.{path(i, j)}.buf.add_(1)" for j in self.mutations[i]]
                lines.append("return x + 1")

                code = Block()
                for line in lines:
                    code.new_line(line)
                return code

        return GenNNModule(self.n)
|