Files
pytorch/test/export/random_dag.py
Xuehai Pan dcc3cf7066 [BE] fix ruff rule E226: add missing whitespace around operator in f-strings (#144415)
The fixes are generated by:

```bash
ruff check --fix --preview --unsafe-fixes --select=E226 .
lintrunner -a --take "RUFF,PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144415
Approved by: https://github.com/huydhn, https://github.com/Skylion007
2025-01-08 21:55:00 +00:00

273 lines
8.4 KiB
Python

# Owner(s): ["oncall: export"]
def random_dag(n: int):
    """
    Build a random DAG over the nodes 0, 1, ..., n-1.

    Every pair (i, j) with i < j is kept as an edge i -> j on a coin flip
    drawn from the module-level ``random`` generator. Edges therefore always
    point from a lower-numbered node to a higher-numbered one, which keeps the
    graph acyclic.

    Returns a dict mapping every node i (including nodes without successors)
    to the ascending list of its successors j.
    """
    import random

    def successors(src):
        # Coin flips are drawn in ascending dst order, for ascending src
        # (the dict comprehension below visits src in order).
        return [dst for dst in range(src + 1, n) if random.choice([True, False])]

    return {src: successors(src) for src in range(n)}
class Block:
    """
    Util to generate a block of Python-formatted code.

    Lines are accumulated in insertion order. ``str(block)`` (which falls back
    to ``__repr__`` because no ``__str__`` is defined) returns the
    concatenated source text.
    """

    # String prepended to every line of a nested block. It is made of spaces,
    # not a tab; nesting k levels deep repeats it k times.
    _INDENT = " "

    def __init__(self):
        # Each entry is one line of source already terminated by "\n".
        self._code = []

    def __repr__(self):
        return "".join(self._code)

    def new_line(self, line: str):
        """
        Add a new line of code. The line is automatically suffixed
        with a newline character.
        """
        self._code.append(line + "\n")

    def new_block(self, block: "Block"):
        """
        Add a new block of code nested one level deeper than this block.

        The lines currently in ``block`` are copied into this block, each
        prefixed with ``Block._INDENT`` (spaces, not a tab). Lines added to
        ``block`` afterwards are not reflected here.
        """
        self._code.extend(self._INDENT + line for line in block._code)
class TestGenerator:
"""
Abstract base class for generating test code.
Users should subclass this class and implement the test_name() and
test_body() methods. The test_name() method should return a string
that uniquely identifies the test. The test_body() method should
yield blocks of code.
"""
def __init__(self):
self._count = 0
def _generate_test_name(self):
self._count += 1
return f"{self.test_name()}_{self._count}"
def generate_test(self):
test_name = self._generate_test_name()
code = Block()
code.new_line(f"def {test_name}():")
for block in self.test_body():
code.new_block(block)
code.new_line(f"{test_name}()")
return str(code)
def test_name(self):
raise NotImplementedError
def test_body(self):
raise NotImplementedError
class NNModuleGenerator:
    """
    Abstract base class for emitting the source of a torch.nn.Module subclass.

    Subclasses override gen_init_body(i) and gen_forward_body(i). Each returns
    a Block holding the body of ``__init__`` or ``forward`` for the module
    numbered i. gen_nn_module(i) wraps both into ``class N{i}``.
    """

    def gen_init_body(self, i: int):
        """Return the Block for the body of N{i}.__init__; must be overridden."""
        raise NotImplementedError

    def gen_forward_body(self, i: int):
        """Return the Block for the body of N{i}.forward; must be overridden."""
        raise NotImplementedError

    def gen_nn_module(self, i: int):
        """Return a Block defining ``class N{i}(torch.nn.Module)``."""
        methods = Block()
        # Emit __init__ first, then forward, each followed by its nested body.
        for signature, make_body in (
            ("def __init__(self):", self.gen_init_body),
            ("def forward(self, x):", self.gen_forward_body),
        ):
            methods.new_line(signature)
            methods.new_block(make_body(i))

        class_def = Block()
        class_def.new_line(f"class N{i}(torch.nn.Module):")
        class_def.new_block(methods)
        return class_def
class Unflatten(TestGenerator):
    """
    Generates test that unflattens a model with several nn.Modules that call
    each other. The modules are generated by calling the nn_module_generator()
    method.

    Module N{i} stores N{i+1} as attribute ``n{i+1}``, forming the chain
    N0 -> N1 -> ... -> N{n-1}. N{i}.forward() calls randomly chosen
    descendants through that chain. The model is exported with the call
    signatures of every submodule preserved, and then unflattened. Both the
    exported module and the unflattened model are compared against the eager
    model.
    """

    def __init__(self, n: int):
        super().__init__()
        # Number of generated modules, N0 .. N{n-1}.
        self.n = n

    @staticmethod
    def _path(i: int, j: int) -> str:
        """
        Return the dotted attribute path from module N{i} to descendant N{j}.

        Because N{k} stores N{k+1} as ``n{k+1}``, the path is
        ``n{i+1}.n{i+2}.<...>.n{j}``. Callers pass i < j.
        """
        return ".".join(f"n{k}" for k in range(i + 1, j + 1))

    def nn_module_generator(self):
        """Return an NNModuleGenerator whose forward() bodies make random calls."""
        path = self._path

        class GenNNModule(NNModuleGenerator):
            def __init__(self, n: int):
                super().__init__()
                self.n = n
                # calls[i] lists the descendants j > i that N{i} calls.
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                code = Block()
                code.new_line("super().__init__()")
                if i < self.n - 1:
                    # Nest the next module of the chain as a submodule.
                    code.new_line(f"self.n{i + 1} = N{i + 1}()")
                return code

            def gen_forward_body(self, i: int):
                code = Block()
                for j in self.calls[i]:
                    code.new_line(f"x = self.{path(i, j)}(x + 1)")
                code.new_line("return x + 1")
                return code

        return GenNNModule(self.n)

    def test_name(self):
        return f"{self.__class__.__name__}_{self.n}"

    def test_body(self):
        """Yield the module class definitions followed by the test driver."""
        nn_module_generator = self.nn_module_generator()
        # Emit the module classes from N{n-1} down to N0.
        for i in range(self.n):
            yield nn_module_generator.gen_nn_module(self.n - 1 - i)

        # Render the FQNs of N1..N{n-1} as the items of a tuple literal,
        # e.g. "'n1','n1.n2'," which becomes ('n1','n1.n2',). The comma must
        # be OUTSIDE the quotes: "'n1,''n1.n2,'" would be implicit string
        # literal concatenation, i.e. the single str 'n1,n1.n2,' instead of
        # a tuple of FQNs. For n == 1 this is empty, giving the empty tuple ().
        fqns = "".join(f"'{self._path(0, j)}'," for j in range(1, self.n))

        def gen_main():
            code = Block()
            code.new_line("inp = (torch.ones(1),)")
            code.new_line("eager = N0()(*inp)")
            code.new_line(
                f"ep = torch.export.export(N0(), inp, strict=False, preserve_module_call_signature=({fqns}))"
            )
            code.new_line("epm = ep.module()")
            code.new_line("ufm = torch.export.unflatten(ep)")
            code.new_line("assert torch.allclose(epm(*inp), eager)")
            code.new_line("assert torch.allclose(ufm(*inp), eager)")
            return code

        yield gen_main()
class ConstantUnflatten(Unflatten):
    """
    Unflatten variant whose modules also read constants.

    Each generated module N{i} stores ``self.const = torch.ones(1)``. Its
    forward() first adds the ``const`` of randomly chosen descendants to x,
    then makes its randomly chosen calls. The test name and test body are
    inherited from Unflatten.
    """

    def nn_module_generator(self):
        def attr_path(src, dst):
            """Dotted attribute path from N{src} to descendant N{dst}."""
            step = f"n{src + 1}"
            if src + 1 == dst:
                return step
            return f"{step}.{attr_path(src + 1, dst)}"

        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # Two random DAGs, drawn in this fixed order: which
                # descendants' constants to read, and which descendants to call.
                self.accesses = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                init = Block()
                init.new_line("super().__init__()")
                init.new_line("self.const = torch.ones(1)")
                is_last = i >= self.n - 1
                if not is_last:
                    init.new_line(f"self.n{i + 1} = N{i + 1}()")
                return init

            def gen_forward_body(self, i: int):
                reads = [
                    f"x = x + self.{attr_path(i, dst)}.const"
                    for dst in self.accesses[i]
                ]
                calls = [
                    f"x = self.{attr_path(i, dst)}(x + 1)" for dst in self.calls[i]
                ]
                forward = Block()
                for stmt in reads + calls + ["return x + 1"]:
                    forward.new_line(stmt)
                return forward

        return GenNNModule(self.n)
class BufferUnflatten(Unflatten):
    """
    Unflatten variant whose modules also read and mutate buffers.

    Each generated module N{i} stores
    ``self.buf = torch.nn.Buffer(torch.ones(1))``. Its forward() runs three
    steps in order:

    1. Add the ``buf`` of randomly chosen descendants to x.
    2. Make its randomly chosen calls.
    3. Call ``.add_(1)`` on the ``buf`` of randomly chosen descendants.

    The test name and test body are inherited from Unflatten.
    """

    def nn_module_generator(self):
        def attr_path(src, dst):
            """Dotted attribute path from N{src} to descendant N{dst}."""
            step = f"n{src + 1}"
            if src + 1 == dst:
                return step
            return f"{step}.{attr_path(src + 1, dst)}"

        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # Three random DAGs, drawn in this fixed order: which
                # descendants' buffers to read, which to mutate, and which
                # descendants to call.
                self.accesses = random_dag(self.n)
                self.mutations = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                init = Block()
                init.new_line("super().__init__()")
                init.new_line("self.buf = torch.nn.Buffer(torch.ones(1))")
                has_child = i < self.n - 1
                if has_child:
                    init.new_line(f"self.n{i + 1} = N{i + 1}()")
                return init

            def gen_forward_body(self, i: int):
                stmts = [
                    f"x = x + self.{attr_path(i, dst)}.buf" for dst in self.accesses[i]
                ]
                stmts += [
                    f"x = self.{attr_path(i, dst)}(x + 1)" for dst in self.calls[i]
                ]
                stmts += [
                    f"self.{attr_path(i, dst)}.buf.add_(1)"
                    for dst in self.mutations[i]
                ]
                stmts.append("return x + 1")
                forward = Block()
                for stmt in stmts:
                    forward.new_line(stmt)
                return forward

        return GenNNModule(self.n)