Nested Graph Breaks

Summary:

  • Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below
  • A nested graph break is traced (and reported) {math}`\mathcal{O}(N)` times, where {math}`N` is the nesting depth

Recall that when torch.compile is applied to a function, any nested function calls are also traced. A nested graph break refers to any graph break that happens in a nested function call.

def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...

The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.

Recall that with fullgraph=False, graph breaks are handled by compiling the FX graph that has been determined so far, running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph. Resuming a function partway through is a fairly complicated technical feat, so resuming tracing is only supported for top-level functions.
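
For example, a graph break directly in a top-level compiled function is handled roughly as follows (an illustrative sketch in the same style as the expansions below; g, compiled_g_semantics, and resume_g_semantics are made-up names, not real APIs):

def g(x):
    x = x + 1
    torch._dynamo.graph_break()  # unsupported code
    return x + 2

# torch.compile(g)(x) behaves roughly like this:
def compiled_g_semantics(x):
    # first FX graph: everything traced before the graph break, compiled and run
    y = x + 1
    # the unsupported code runs in regular Python
    torch._dynamo.graph_break()
    # tracing resumes in a new top-level frame with a new FX graph
    return torch.compile(resume_g_semantics)(y)

def resume_g_semantics(x):
    # second FX graph: the remainder of g
    return x + 2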

Given this restriction, we can resume tracing after a nested graph break in the following way:

First, consider the example below, where torch.compile starts tracing from f and traces until the graph break in inner1 is encountered.

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    return x + 8

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    return x + 32

f(torch.randn(3))

Since we can only resume from top-level functions, we graph break on the inner2 call in f.

# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

inner2 is then automatically compiled as a top-level function. We trace all the way until the graph break in inner1 is encountered again.

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    return x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

Then we graph break on the inner1 call in inner2.

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

inner1 is then automatically compiled as a top-level function. This time the graph break occurs directly in inner1 rather than in a nested call, so it is handled normally.

# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))

inner1 is handled normally:

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

So the initial code is semantically equivalent to

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
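
Since the two versions are semantically equivalent, they should compute the same values. A quick check (a sketch that assumes the definitions of f, inner1, and inner2 from the first example, together with the expansion above, are all in scope):

import torch

x = torch.randn(3)
# f adds 1 + 2 + 4 + 8 + 16 + 32 = 63 in total, whichever version runs it
torch.testing.assert_close(torch.compile(f)(x), x + 63)
torch.testing.assert_close(compiled_f_semantics(x), x + 63)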

Note in particular that we traced 3 top-level functions (f, inner2, and inner1), and that we traced the same graph break 3 times. This explains why you may see the same graph break reported multiple times when using torch.compile.
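
One way to observe this is torch._dynamo.explain, which traces a call and summarizes the graphs and graph breaks Dynamo encountered (a sketch; the exact fields and counts reported by the returned ExplainOutput may differ between PyTorch versions). Running with the environment variable TORCH_LOGS="graph_breaks" similarly logs each graph break as it is traced, so a nested graph break may show up once per re-traced frame.

import torch

# assumes f, inner1, and inner2 from the example above are defined in this module
explanation = torch._dynamo.explain(f)(torch.randn(3))
print(explanation.graph_count)        # number of FX graphs Dynamo produced
print(explanation.graph_break_count)  # graph breaks encountered while tracing
print(explanation.break_reasons)      # reasons, including the explicit graph_break() calls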

In summary, nested graph breaks are handled by:

  • Tracing from the top-level function all the way to the nested graph break
  • Graph breaking in the top-level function at the call to the second-level function
  • Compiling the PyTorch ops traced so far and running the compiled graph
  • Calling the second-level function, which is automatically compiled as a top-level function
  • Resuming tracing of the top-level function after the second-level function call

Note that the runtime of handling this graph break is {math}`\mathcal{O}(NK)`, where {math}`N` is the nesting depth, and {math}`K` is the number of instructions from the top-level function to the graph break. We end up tracing {math}`\mathcal{O}(N^2)` frames, and we trace the same graph break {math}`\mathcal{O}(N)` times.
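
To see where the {math}`\mathcal{O}(N^2)` figure comes from: the top-level function is traced through all {math}`N` nested frames, the second-level function is then re-traced as a top-level function through the remaining {math}`N - 1` frames, and so on, for a total of {math}`N + (N - 1) + \cdots + 1 = N(N + 1)/2` traced frames. In the example above, {math}`N = 3`, matching the {math}`3 + 2 + 1 = 6` frames traced across the three compiled top-level functions.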