# Owner(s): ["oncall: export"]

import random


def random_dag(n: int, rng=None):
    """
    Util to generate a random DAG with n nodes.

    The nodes are numbered 0, 1, ..., n-1. The DAG is generated by randomly
    choosing a subset of edges from the complete graph on n nodes, such that
    for each edge (i, j) we have i < j. Each candidate edge is kept or dropped
    by a fair ``choice([True, False])``.

    Args:
        n: number of nodes.
        rng: optional source of randomness providing ``choice`` (e.g. a
            ``random.Random`` instance) for reproducible DAGs. Defaults to the
            global ``random`` module.

    Returns:
        A dict mapping every node i in ``range(n)`` to the list of its
        successors j (all with j > i), in increasing order.
    """
    if rng is None:
        rng = random
    edges = {}
    for i in range(n):
        # Only edges i -> j with i < j are considered, so no cycle can form.
        edges[i] = [j for j in range(i + 1, n) if rng.choice([True, False])]
    return edges


def _path(i: int, j: int) -> str:
    """
    Return the attribute path from generated module N{i} to its descendant N{j}.

    Each generated module N{k} stores its child N{k + 1} as attribute
    ``n{k + 1}``, so the path from N{i} to N{j} is ``"n{i+1}.n{i+2}...n{j}"``.

    Raises:
        ValueError: if j <= i (N{j} is not a descendant of N{i}).
    """
    if j <= i:
        raise ValueError(f"expected i < j, got i={i}, j={j}")
    return ".".join(f"n{k}" for k in range(i + 1, j + 1))


class Block:
    """
    Util to generate a block of Python-formatted code.
    """

    # Prefix added to every line of a block nested via new_block().
    INDENT = "    "

    def __init__(self):
        self._code = []

    def __repr__(self):
        return "".join(self._code)

    def new_line(self, line: str):
        """
        Add a new line of code.

        The line is automatically suffixed with a newline character.
        """
        self._code.append(line + "\n")

    def new_block(self, block: "Block"):
        """
        Add a new block of code.

        All lines in the new block are automatically prefixed by one level of
        indentation (``Block.INDENT``, four spaces).
        """
        self._code.extend(self.INDENT + line for line in block._code)


class TestGenerator:
    """
    Abstract base class for generating test code.

    Users should subclass this class and implement the test_name() and
    test_body() methods. The test_name() method should return a string that
    uniquely identifies the test. The test_body() method should yield blocks
    of code.
    """

    # Despite the ``Test`` prefix this is not a pytest test class; keep pytest
    # from trying to collect it when it is imported into a test module.
    __test__ = False

    def __init__(self):
        self._count = 0

    def _generate_test_name(self):
        # A per-instance counter keeps repeated generate_test() calls unique.
        self._count += 1
        return f"{self.test_name()}_{self._count}"

    def generate_test(self):
        """
        Return the source of a test function wrapping test_body(), followed by
        a call to that function.
        """
        test_name = self._generate_test_name()
        code = Block()
        code.new_line(f"def {test_name}():")
        for block in self.test_body():
            code.new_block(block)
        code.new_line(f"{test_name}()")
        return str(code)

    def test_name(self):
        raise NotImplementedError

    def test_body(self):
        raise NotImplementedError


class NNModuleGenerator:
    """
    Abstract base class for generating a nn.Module.

    Users should subclass this class and implement the gen_init_body() and
    gen_forward_body() methods. The gen_init_body() method should return a
    block of code that initializes the nn.Module. The gen_forward_body() method
    should return a block of code that defines the forward() of the nn.Module.
    """

    def gen_init_body(self, i: int):
        raise NotImplementedError

    def gen_forward_body(self, i: int):
        raise NotImplementedError

    def gen_nn_module(self, i: int):
        """Return a Block defining ``class N{i}(torch.nn.Module)``."""
        body = Block()
        body.new_line("def __init__(self):")
        body.new_block(self.gen_init_body(i))
        body.new_line("def forward(self, x):")
        body.new_block(self.gen_forward_body(i))

        code = Block()
        code.new_line(f"class N{i}(torch.nn.Module):")
        code.new_block(body)
        return code


class Unflatten(TestGenerator):
    """
    Generates test that unflattens a model with several nn.Modules that call
    each other.

    The modules are generated by calling the nn_module_generator() method. The
    model is exported and then unflattened. Both the exported module and the
    unflattened model are then compared against the eager model.
    """

    def __init__(self, n: int):
        # The generated test always instantiates N0, so at least one module
        # must be generated.
        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}")
        super().__init__()
        self.n = n

    def nn_module_generator(self):
        class GenNNModule(NNModuleGenerator):
            def __init__(self, n: int):
                super().__init__()
                self.n = n
                # calls[i] lists the descendants N{j} (j > i) that
                # N{i}.forward calls.
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                code = Block()
                code.new_line("super().__init__()")
                # Every module but the last owns the next one as a child,
                # forming the chain N0 -> N1 -> ... -> N{n-1}.
                if i < self.n - 1:
                    code.new_line(f"self.n{i + 1} = N{i + 1}()")
                return code

            def gen_forward_body(self, i: int):
                code = Block()
                for j in self.calls[i]:
                    code.new_line(f"x = self.{_path(i, j)}(x + 1)")
                code.new_line("return x + 1")
                return code

        return GenNNModule(self.n)

    def test_name(self):
        return f"{self.__class__.__name__}_{self.n}"

    def test_body(self):
        nn_module_generator = self.nn_module_generator()
        # Emit class definitions in the order N{n-1}, ..., N1, N0.
        for i in reversed(range(self.n)):
            yield nn_module_generator.gen_nn_module(i)

        # FQNs of every submodule of N0, rendered as tuple-literal elements,
        # e.g. "'n1','n1.n2',". The comma must sit OUTSIDE the quotes so the
        # result is a tuple of strings (a lone FQN still gets its trailing
        # comma, and n == 1 yields the empty tuple "()").
        fqns = "".join(f"'{_path(0, j)}'," for j in range(1, self.n))

        code = Block()
        code.new_line("inp = (torch.ones(1),)")
        code.new_line("eager = N0()(*inp)")
        code.new_line(
            f"ep = torch.export.export(N0(), inp, strict=False, preserve_module_call_signature=({fqns}))"
        )
        code.new_line("epm = ep.module()")
        code.new_line("ufm = torch.export.unflatten(ep)")
        code.new_line("assert torch.allclose(epm(*inp), eager)")
        code.new_line("assert torch.allclose(ufm(*inp), eager)")
        yield code


class ConstantUnflatten(Unflatten):
    """
    Generates test that unflattens a model with several nn.Modules that call
    each other and access constants.

    The modules are generated by calling the nn_module_generator() method.
    """

    def nn_module_generator(self):
        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # accesses[i] / calls[i]: descendants whose constant N{i}
                # reads / which N{i} calls.
                self.accesses = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                code = Block()
                code.new_line("super().__init__()")
                code.new_line("self.const = torch.ones(1)")
                if i < self.n - 1:
                    code.new_line(f"self.n{i + 1} = N{i + 1}()")
                return code

            def gen_forward_body(self, i: int):
                code = Block()
                for j in self.accesses[i]:
                    code.new_line(f"x = x + self.{_path(i, j)}.const")
                for j in self.calls[i]:
                    code.new_line(f"x = self.{_path(i, j)}(x + 1)")
                code.new_line("return x + 1")
                return code

        return GenNNModule(self.n)


class BufferUnflatten(Unflatten):
    """
    Generates test that unflattens a model with several nn.Modules that call
    each other and access and mutate buffers.

    The modules are generated by calling the nn_module_generator() method.
    """

    def nn_module_generator(self):
        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # accesses[i] / mutations[i] / calls[i]: descendants whose
                # buffer N{i} reads / increments in place / which N{i} calls.
                self.accesses = random_dag(self.n)
                self.mutations = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                code = Block()
                code.new_line("super().__init__()")
                code.new_line("self.buf = torch.nn.Buffer(torch.ones(1))")
                if i < self.n - 1:
                    code.new_line(f"self.n{i + 1} = N{i + 1}()")
                return code

            def gen_forward_body(self, i: int):
                code = Block()
                for j in self.accesses[i]:
                    code.new_line(f"x = x + self.{_path(i, j)}.buf")
                for j in self.calls[i]:
                    code.new_line(f"x = self.{_path(i, j)}(x + 1)")
                for j in self.mutations[i]:
                    code.new_line(f"self.{_path(i, j)}.buf.add_(1)")
                code.new_line("return x + 1")
                return code

        return GenNNModule(self.n)