[hop free symbols] lift free symbols in example_value when create_graph_input (#138363)

There are 5 parts in this PR (they are hard to break into smaller pieces because they're highly coupled); a toy sketch of parts 1-3 follows the list:
1. **Whenever we call create_graph_input, we try to bind the symbols in the graph input.**
Since we've enforced the invariant that all create_graph_input calls must provide an example value, we can intercept at the create_graph_input calls (this PR only handles free symbols in tensors).
2. **We cache the bound_symbols** to avoid lifting the same symbol repeatedly.
3. For lifted symbols, we re-use **lifted_freevars**, i.e. the mapping from a proxy in the parent graph to the lifted placeholder in the current subgraph, which is how we already handle lifted tensors. In this way, all HOPs that support lifted tensors should be able to handle lifted symints automatically (at least in the dynamo part).
4. For **unbacked symbols** created during tracing, we also need to bind these symbols to their proxies. This supports the test cases where we want to lift unbacked symbols as inputs: we need the proxy of the unbacked symbol in the parent graph in order to properly create the args to the hop.
5. We update all the tests after free symbols are lifted in subgraphs, and also support lifted symbols in existing higher order ops.
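
Below is a minimal, self-contained toy sketch of parts 1-3. The class, helper, and placeholder names are illustrative stand-ins (symbols are modeled as plain strings), not Dynamo's actual `SubgraphTracer` implementation:
```python
# Toy model of the lifting logic above; not the real Dynamo code.

class TensorSketch:
    """Stand-in for a fake tensor whose sizes may be ints or symbol names like "s0"."""
    def __init__(self, *sizes):
        self.sizes = sizes


class TracerSketch:
    """Stand-in for a SubgraphTracer; only the lifting bookkeeping is modeled."""
    def __init__(self):
        self.placeholders = []      # graph inputs of this subgraph, in creation order
        self.lifted_freevars = {}   # (simplified) lifted value -> its placeholder here
        self.bound_symbols = {}     # symbol expr (e.g. "s0") -> its placeholder here

    def create_graph_input(self, name, example_value=None):
        ph = f"{name}_ph{len(self.placeholders)}"
        self.placeholders.append(ph)
        # Part 1: when a graph input is created, bind the free symbols that
        # appear in its example value (here: the symbolic sizes of a tensor).
        if isinstance(example_value, TensorSketch):
            for size in example_value.sizes:
                if isinstance(size, str):   # a free symbol such as "s0"
                    self.maybe_bind_symbol(size)
        return ph

    def maybe_bind_symbol(self, sym_expr):
        # Part 2: bound_symbols is a per-tracer cache so each symbol is lifted once.
        if sym_expr in self.bound_symbols:
            return self.bound_symbols[sym_expr]
        # Part 3: reuse the lifted_freevars machinery (in real Dynamo the key is
        # the symbol's proxy in the parent graph) to lift the symbol as an input.
        ph = self.create_graph_input(sym_expr)
        self.bound_symbols[sym_expr] = ph
        self.lifted_freevars[sym_expr] = ph
        return ph


tracer = TracerSketch()
tracer.create_graph_input("l_x_", TensorSketch("s0", "s1"))
# Toy ordering only; the real graphs in this description list the symbols first.
print(tracer.placeholders)  # ['l_x__ph0', 's0_ph1', 's1_ph2']
```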

**The interaction of nested tracers:**
The previous design for lifting tensor closures is: suppose we're in nested tracers; whenever we see a new proxy that's not created by the current tracer, we recursively look for the proxy in the parent tracers until we find the tracer that created this proxy (either as a placeholder or as some intermediate result). More detail is in Note [Nested SubgraphTracer and free_variable handling].

Given the above design, the plan for lifting the free symbols is: whenever we lift a free tensor to be an input of the current subgraph, we look at the symbols in it and bind them at the same time.

For example, suppose we have the following function:
```python
def f(x):  # x has shape [s1, s2]
  def true_f():
    def true_f_inner():
      return x.sin()
    # inner cond(pred, true_f_inner, ...) call elided
  # outer cond(pred, true_f, ...) call elided
```
What happens, in time order:

1. We create subtracer 1 and start to speculate the outer cond's true_f.
2. We create another subtracer 2 and start to speculate the inner cond's true_f_inner.
3. Dynamo realizes the tensor input x by calling wrap_tensor at the top level to create graph input x (tracer 0); we bind the symbols s1, s2 after the placeholder for x is created. So the graph now looks like:
```python
def gm(s1, s2, x):
```
4. When seeing TensorVariable.call_method of x, tracer 2 wants to create a call_function(sin, proxy_of_x), but it finds that proxy_of_x is not created by the current tracer. So it recursively looks up its parent tracer 1, finds that tracer 1 also doesn't track this proxy_of_x, and then reaches the root tracer 0, which is the creator of proxy_of_x and tracks it as a placeholder. Then tracer 1 calls create_graph_input to lift the closure to its input ph1 and adds the (proxy_of_x: ph1) key-value pair to tracer 1's **lifted_freevars**.
Now the graph looks like:
```python
def gm(s1, s2, x):
  def true_gm(x):
```
5. Since there are free symbols inside this new tensor input, tracer 1 also binds the symbols (maybe_bind_symbol), which calls create_graph_input for s1 and s2. Now the graph looks like:
```python
def gm(s1, s2, x):
  def true_gm(s1, s2, x):
```
6. Then it goes back to tracer 2, which calls create_graph_input for x and gets ph2; tracer 2's **lifted_freevars** records (ph1: ph2). Tracer 2 also binds the symbols in this new tensor input. Now the graph looks like:
```python
def gm(s1, s2, x):
  def true_gm(s1, s2, x):
    def true_gm_inner(s1, s2, x):
```
7. Finally, the sin call_function node is created by tracer 2.
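
To tie the walkthrough together, here is an illustrative view (not actual Dynamo code; the proxy names are made up) of how each subgraph's **lifted_freevars** becomes the extra arguments of the HOP call emitted in the parent graph, which is why HOPs that already handle lifted tensors pick up lifted symints for free:
```python
# Illustrative only: the keys of lifted_freevars are proxies in the parent graph,
# and they are appended to the HOP call's args so that they feed the subgraph's
# lifted placeholders (the values).
lifted_freevars = {
    "gm.s1": "true_gm.s1",
    "gm.s2": "true_gm.s2",
    "gm.x": "true_gm.x",
}
explicit_args = ()  # args explicitly passed to the HOP; none in this example
hop_args = explicit_args + tuple(lifted_freevars.keys())
print(hop_args)  # ('gm.s1', 'gm.s2', 'gm.x')  ~ cond(pred, true_gm, false_gm, (s1, s2, x))
```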

**This PR also handles the following cases:**
- What if we lift two tensors that share a symbol, e.g. x1 with shape [s1, s2] and x2 with shape [s2, s3]? Each subtracer maintains bound_symbols as a cache that maps a symbol's expr to its proxy in the current tracer. So when we see x1, we track s1 and s2 as inputs, binding s1 to ph1 and s2 to ph2. When we then try to bind the symbols of x2, s2 is already tracked, so no new graph input is created (as sketched below).
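A tiny standalone illustration of that de-duplication (plain dicts standing in for the tracer state; names are illustrative):
```python
# Illustrative only: per-tracer de-duplication of shared symbols.
bound_symbols = {}   # symbol expr -> placeholder proxy in this subgraph
placeholders = []    # lifted graph inputs, in lifting order

def bind_symbols_of(tensor_sizes):
    for sym in tensor_sizes:
        if sym not in bound_symbols:     # cache miss => lift a new graph input
            ph = f"ph{len(placeholders)}"
            placeholders.append(ph)
            bound_symbols[sym] = ph

bind_symbols_of(["s1", "s2"])  # lifts s1 -> ph0, s2 -> ph1
bind_symbols_of(["s2", "s3"])  # s2 is already bound; only s3 -> ph2 is lifted
assert placeholders == ["ph0", "ph1", "ph2"]
assert bound_symbols == {"s1": "ph0", "s2": "ph1", "s3": "ph2"}
```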
- What if a subgraph closes over a symint? e.g.
```python
def f(x):
  def true_f():
    c = x.size(0)
    def true_fn_inner():
      return c
```
When we speculate true_fn_inner, we find that proxy_of_c is not tracked by tracer 2, so it recursively looks up its parent. At this point, x and its symbols have been lifted as inputs of true_f (as a result of lifting x while tracing true_f in tracer 1). Specifically, the graph looks like:
```python
def gm(s1, s2, x):
  def true_gm(s1, s2, x):
    def true_gm_inner():
```
So tracer 2 finds that s1 has already been tracked as a placeholder in tracer 1; the recursion stops there and tracer 2 calls create_graph_input on s1. The graph now looks like:
```python
def gm(s1, s2, x):
  def true_gm(s1, s2, x):
    def true_gm_inner(s1):
      return s1
```

- What if a subgraph closes over an unbacked symint? e.g.
```python
def f(x):
  def true_f():
    c = x.item()
    def true_f_inner():
      return c
```
When x.item() is called, proxy_of_c and its symnode variable are created for tracer 1, and we also call track_unbacked_symbols to record this relationship. So when tracer 2 finds that proxy_of_c is not created by the current tracer, it recursively looks up its parent tracer and finds that the expression u0 has been tracked as a result of track_unbacked_symbols in tracer 1. So it stops the recursion and calls create_graph_input for u0 in tracer 2. The graph looks like:
```python
def f(x):
  def true_f(s1, s2, x):
    c = x.item()
    def true_gm_inner(u0):
      return u0
    cond(pred, true_gm_inner, false_gm_inner, (c,))
```

- What if a subgraph closes over a tensor with an unbacked symint in its shape?
```python
def f(x):
  def true_f():
    c = x.item()
    r = torch.randn((c,))
    def true_f_inner():
      return r + 1
```
This is the same as the case of closing over tensors with backed shapes: we first lift r, then bind the u0 in it, which recursively calls bind_symint on u0 in the parent tracer and finds that u0 is already tracked there as a result of the .item() call (see the sketch below).
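
An illustrative sketch (not the real SubgraphTracer code; names are stand-ins) of the recursive bind described in these cases: each tracer first makes sure its parent tracks the symbol, then lifts it as its own graph input on the way back down.
```python
# Illustrative only: recursive symbol binding across nested tracers.

class TracerNode:
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.bound_symbols = {}   # symbol expr -> placeholder in this subgraph
        self.inputs = []          # lifted graph inputs, in lifting order

    def bind_symint(self, sym_expr):
        if sym_expr in self.bound_symbols:   # already a placeholder/proxy here
            return self.bound_symbols[sym_expr]
        if self.parent is not None:
            # Ensure the parent tracks it first; this may recurse further up, or
            # stop early where the symbol is already tracked (e.g. recorded by
            # track_unbacked_symbols after an .item() call).
            self.parent.bind_symint(sym_expr)
        ph = f"{self.name}.{sym_expr}"
        self.inputs.append(sym_expr)
        self.bound_symbols[sym_expr] = ph
        return ph


root = TracerNode("gm")
root.bound_symbols["s1"] = "gm.s1"   # s1 is already an input of the root graph
true_gm = TracerNode("true_gm", parent=root)
true_gm_inner = TracerNode("true_gm_inner", parent=true_gm)

true_gm_inner.bind_symint("s1")      # lifts s1 into true_gm, then true_gm_inner
assert true_gm.inputs == ["s1"] and true_gm_inner.inputs == ["s1"]
```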

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138363
Approved by: https://github.com/zou3519


@ -16,6 +16,7 @@ import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
check_dynamic_shape_capture,
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
@ -37,11 +38,6 @@ from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
def check_dynamic_shape_capture():
# This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
return not config.assume_static_by_default
def count_ops(gm, args, freq, op):
actual = [node.target for node in gm.graph.nodes].count(op)
assert actual == freq, f"expected={freq}, actual={actual}"
@ -213,7 +209,8 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
return wrap(lambda x: torch.sin(x), x)
x = torch.randn(3)
self._test_wrap_simple(f, default_args_generator((x,)), 2)
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
def test_enum_arg(self):
class SomeEnum(enum.Enum):
@ -229,7 +226,8 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
return wrap(g, x, val)
x = torch.randn(3)
self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), 2)
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), arg_count)
def test_return_captured_var(self):
freevar = torch.randn(3)
@ -244,7 +242,10 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# Since, `x` is unused, we don't lift it to
# be the input.
self._test_wrap_simple(fn, default_args_generator((x,)), 2)
# when testing with dynamic shape, symbols are lifted as input
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
def test_return_captured_vars(self):
freevar1 = torch.randn(3)
@ -260,7 +261,9 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# Since, `x` is unused, we don't lift it to
# be the input.
self._test_wrap_simple(fn, default_args_generator((x,)), 3, 4)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
def test_return_captured_var_used_multiple_times(self):
freevar = torch.randn(3)
@ -273,14 +276,18 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
return wrap(test, x)
x = torch.randn(3)
self._test_wrap_simple(fn, default_args_generator((x,)), 3, 3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
def test_capture_untracked_global(self):
def f(x):
return wrap(lambda x: x + global_var, x)
x = torch.randn(3)
self._test_wrap_simple(f, default_args_generator((x,)), 3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
def test_symint_input(self):
def f(x):
@ -386,13 +393,13 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, s0); wrap_body_0 = l_x_ = s0 = None
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"):
view: "f32[s0]" = l_x_.view(size); l_x_ = size = None
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
add: "f32[s0]" = view + 0.5; view = None
return (add,)
""",
@ -418,7 +425,8 @@ class GraphModule(torch.nn.Module):
y2 = t[0] + 0.2
yield (x2, y2, (x2, y2))
self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), arg_count)
def test_wrap_pytree_args_not_const_symint_tensor(self):
class MyClass:
@ -488,7 +496,9 @@ class GraphModule(torch.nn.Module):
def g(x):
return wrap(lambda x: x + y, x)
self._test_wrap_simple(g, default_args_generator((x,)), 3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
return g(x)
f(x, y)
@ -500,7 +510,9 @@ class GraphModule(torch.nn.Module):
def f(x, y):
return wrap(lambda x: x + y, x)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_capture_tracked_nested(self):
x = torch.randn(3, 3)
@ -509,7 +521,9 @@ class GraphModule(torch.nn.Module):
def f(x, y):
return wrap(lambda x: wrap(lambda x: x + y, x), x)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_inlined_functions(self):
def g(x, y):
@ -520,7 +534,9 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3, 3)
y = torch.randn(3, 3)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_same_freevar_twice(self):
free = torch.randn(3)
@ -537,7 +553,518 @@ class GraphModule(torch.nn.Module):
# Since, `x` is unused, we don't lift it to
# be the input.
self._test_wrap_simple(f, default_args_generator((x,)), 2, 3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(f, default_args_generator((x,)), arg_count, 3)
@torch._dynamo.config.patch(
capture_scalar_outputs=True,
)
def test_unbacked_symbol_closure(self):
def f(x):
c = x.sum().item()
def g(x):
def k(x):
return x + c
return wrap(k, x)
return wrap(g, x)
x = torch.randn(3)
arg_count = ifdynstaticdefault(3, 4)
out_graph = self._test_wrap_simple(
f, default_args_generator((x,)), arg_count, 4, return_graph=True
)
if check_dynamic_shape_capture():
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum()
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
add: "f32[s0]" = l_x_ + item; l_x_ = item = None
return (add,)
""",
)
else:
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3]"):
l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum()
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, item); wrap_body_1 = l_x_ = item = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, item); wrap_body_0 = l_x_ = item = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
add: "f32[3]" = l_x_ + item; l_x_ = item = None
return (add,)
""",
)
@torch._dynamo.config.patch(
capture_dynamic_output_shape_ops=True,
)
def test_tensor_with_unbacked_shape_closure(self):
def f(x):
c = x.nonzero()
def g(x):
def k(x):
return x.sin(), c.sin()
return wrap(k, x)
return wrap(g, x)
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
arg_count,
expected_op_count,
return_graph=True,
)
if check_dynamic_shape_capture():
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
l_x_ = L_x_
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None
getitem: "f32[s0]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None
child: "f32[s0]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s0]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)
else:
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3]"):
l_x_ = L_x_
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le: "Sym(u0 <= 3)" = sym_size_int_1 <= 3
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int_1, c); wrap_body_1 = l_x_ = sym_size_int_1 = c = None
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_1(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[3]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)
@torch._dynamo.config.patch(
capture_dynamic_output_shape_ops=True,
)
def test_tensor_to_list_closure(self):
def f(x):
li = x.tolist()
def g(x):
def k(x):
return li[0] + x
return wrap(k, x)
return wrap(g, x)
x = torch.tensor([1, 2, 3], dtype=torch.int16)
arg_count = ifdynstaticdefault(3, 3)
out_graph = self._test_wrap_simple(f, ((x,),), arg_count, 4, return_graph=True)
# tolist will specialize on input shapes, so dynamic and static tests
# have the same graph
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "i16[3]"):
l_x_ = L_x_
getitem = l_x_[0]
item: "Sym(u0)" = getitem.item(); getitem = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, item, l_x_); wrap_body_1 = item = l_x_ = None
getitem_3: "i16[3]" = wrap[0]; wrap = None
return (getitem_3,)
class wrap_body_1(torch.nn.Module):
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, item, l_x_); wrap_body_0 = item = l_x_ = None
getitem: "i16[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
add: "i16[3]" = item + l_x_; item = l_x_ = None
return (add,)
""",
)
@torch._dynamo.config.patch(
capture_dynamic_output_shape_ops=True,
)
def test_tensor_and_unbacked_symbol_closure(self):
def f(x):
c = x.nonzero()
sz = c.size(0)
def g(x):
def k(x):
return x.sin() + sz, c.sin()
return wrap(k, x)
return wrap(g, x)
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
arg_count,
expected_op_count,
return_graph=True,
)
# Note that u0 is accessed from sz and the shape of c
# We cache via the symbol u0 and de-duplicate them.
if not check_dynamic_shape_capture():
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3]"):
l_x_ = L_x_
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le: "Sym(u0 <= 3)" = sym_size_int <= 3
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int, c); wrap_body_1 = l_x_ = sym_size_int = c = None
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_1(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
child: "f32[3]" = sin + size; sin = size = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
""",
)
@torch._dynamo.config.patch(
capture_dynamic_output_shape_ops=True,
)
def test_concat_unbacked_shape_tensor(self):
def f(x, y):
c = x.nonzero()
d = y.nonzero()
cat = torch.cat((c, d))
def g(x):
def k(x):
return cat.sum() + x
return wrap(k, x)
return wrap(g, x)
x = torch.randn(3)
y = torch.randn(3)
arg_count = ifdynstaticdefault(5, 6)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
expected_op_count = ifdynstaticdefault(17, 13)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
arg_count,
expected_op_count,
return_graph=True,
)
if not check_dynamic_shape_capture():
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3]", L_y_: "f32[3]"):
l_x_ = L_x_
l_y_ = L_y_
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le: "Sym(u0 <= 3)" = sym_size_int_2 <= 3
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None
sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None
ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None
le_1: "Sym(u1 <= 3)" = sym_size_int_3 <= 3
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u1 <= 3 on node 'le_1'"); le_1 = _assert_scalar_default_3 = None
cat: "i64[u0 + u1, 1]" = torch.cat((c, d)); c = d = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_2, sym_size_int_3, cat, l_x_); wrap_body_1 = sym_size_int_2 = sym_size_int_3 = cat = l_x_ = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, u1, cat, l_x_); wrap_body_0 = u0 = u1 = cat = l_x_ = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
sum_1: "i64[]" = cat.sum(); cat = None
add: "f32[3]" = sum_1 + l_x_; sum_1 = l_x_ = None
return (add,)
""",
)
@torch._dynamo.config.patch(
assume_static_by_default=False,
dynamic_shapes=True,
)
def test_lift_tensors_with_shared_symbols(self):
def f(x, y):
def g(x):
def k(x):
return x @ y
return wrap(k, x)
return wrap(g, x)
x = torch.randn(2, 3)
y = torch.randn(3, 4)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
6,
2,
return_graph=True,
)
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
l_x_ = L_x_
l_y_ = L_y_
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (matmul,)
""",
)
@torch._dynamo.config.patch(
assume_static_by_default=False,
dynamic_shapes=True,
capture_dynamic_output_shape_ops=True,
)
def test_lift_tensors_with_compound_expressions(self):
def f(x, y):
x = x.view(-1, 2)
c = y.nonzero()
d = torch.concat((x, c))
def g(x):
def k(x):
return d.sum() + x
return wrap(k, x)
return wrap(g, x)
x = torch.randn(2, 3)
y = torch.randn(3, 4)
f(x, y)
if not check_dynamic_shape_capture():
out_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
6,
9,
return_graph=True,
)
self.assertExpectedInline(
out_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
l_x_ = L_x_
l_y_ = L_y_
x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = l_x_.view(-1, 2); l_x_ = None
c: "i64[u0, 2]" = l_y_.nonzero(); l_y_ = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = torch.concat((x, c)); c = None
wrap_body_1 = self.wrap_body_1
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_1, s1, s0, d, x); wrap_body_1 = sym_size_int_1 = s1 = s0 = d = x = None
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_1(torch.nn.Module):
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, s1, s0, d, x); wrap_body_0 = u0 = s1 = s0 = d = x = None
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
sum_1: "f32[]" = d.sum(); d = None
add: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = sum_1 + x; sum_1 = x = None
return (add,)
""",
)
def test_register_subclass(self):
from torch._higher_order_ops.cond import cond_op
@ -1054,7 +1581,8 @@ class GraphModule(torch.nn.Module):
return wrap(f, x)
x = torch.randn(3, 3)
self._test_wrap_simple(g, default_args_generator((x,)), 2)
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
def test_wrap_kwarg(self):
def f(x, y):
@ -1062,7 +1590,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
y = torch.randn(3, 3)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_wrap_kwarg_int(self):
def f(x, y):
@ -1071,9 +1600,12 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
y = 8
self._test_wrap_simple(
f, default_args_generator((x, y)), ifdynstaticdefault(2, 3)
arg_count = (
ifdynstaticdefault(2, 3) + 1
if check_dynamic_shape_capture()
else ifdynstaticdefault(2, 3)
)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_wrap_all_kwarg(self):
def f(y, x):
@ -1082,7 +1614,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
y = torch.randn(3, 3)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_wrap_kwarg_only(self):
def f(x, y):
@ -1094,7 +1627,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
y = torch.randn(3, 3)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_wrap_kwarg_default(self):
def f(x, y):
@ -1106,7 +1640,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
y = torch.randn(3, 3)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_wrap_kwarg_default_if_branch(self):
def f(x, y):
@ -1121,7 +1656,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
y = torch.randn(3, 3)
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
def test_wrap_kwarg_recompile(self):
def f(x, y, z=None):
@ -1162,7 +1698,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
y = torch.randn(3, 3)
self._test_wrap_simple(f, default_args_generator((x, y, 8)), 2)
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(f, default_args_generator((x, y, 8)), arg_count)
def test_map_subgraph_name_is_valid(self):
backend = EagerAndRecordGraphs()
@ -1522,8 +2059,8 @@ def forward(self, child : torch.Tensor, const_unused : int):
and node.target == torch.ops.higher_order.cond
):
_, _, _, operands = node.args
# Each branch takes 3 inputs (buffer, x, z)
self.assertEqual(len(operands), 3)
# Since we compile with dynamic, each branch takes 4 inputs (buffer, x, z, s1)
self.assertEqual(len(operands), 4)
if node.op == "get_attr":
if str(node.target) in ("cond_true_0, cond_false_0"):
num_placeholders = len(
@ -1535,7 +2072,7 @@ def forward(self, child : torch.Tensor, const_unused : int):
if node.op == "placeholder"
]
)
self.assertEqual(num_placeholders, 3)
self.assertEqual(num_placeholders, 4)
def _check_cond_graph_and_extract(self, fn, args):
backend = EagerAndRecordGraphs()
@ -1826,10 +2363,11 @@ def forward(self):
yield [x], [x.sin()]
yield (x,), (x.sin(),)
arg_count = ifdynstaticdefault(3, 4)
actual_graph = self._test_wrap_simple(
f,
my_args_generator(),
3,
arg_count,
3,
return_graph=True,
)
@ -2000,7 +2538,10 @@ class GraphModule(torch.nn.Module):
return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)
x = torch.randn(3)
self._test_wrap_simple(f, default_args_generator((x,)), 2, expected_opcount=3)
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(
f, default_args_generator((x,)), arg_count, expected_opcount=3
)
def test_fallback_on_python_primitives_output(self):
counters.clear()
@ -2028,8 +2569,9 @@ class GraphModule(torch.nn.Module):
x = torch.randn(2, 3)
counters.clear()
arg_count = ifdynstaticdefault(2, 4)
graph = self._test_wrap_simple(
f, default_args_generator((x,)), 2, 4, return_graph=True
f, default_args_generator((x,)), arg_count, 4, return_graph=True
)
self.assertEqual(len(counters["graph_break"]), 0)
@ -2066,8 +2608,10 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
counters.clear()
arg_count = ifdynstaticdefault(2, 3)
graph = self._test_wrap_simple(
f, default_args_generator((x,)), 2, 2, return_graph=True
f, default_args_generator((x,)), arg_count, 2, return_graph=True
)
self.assertEqual(len(counters["graph_break"]), 0)
@ -2137,7 +2681,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3, 3)
y = torch.randn(3, 3)
self._test_wrap_simple(h, default_args_generator((x, y)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
def test_internal_nonlocal(self):
def f(x, y):
@ -2162,7 +2707,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3, 3)
y = torch.randn(3, 3)
self._test_wrap_simple(h, default_args_generator((x, y)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
def test_capture_numpy_number(self):
import numpy as np
@ -2174,7 +2720,8 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
# np.number are lifted to graph inputs
self._test_wrap_simple(f, default_args_generator((x,)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
def test_freevars_as_inputs_to_wrap(self):
y = torch.randn(3)
@ -2183,7 +2730,8 @@ class GraphModule(torch.nn.Module):
return wrap(lambda x, y: x + y, x, y)
x = torch.randn(3)
self._test_wrap_simple(f, default_args_generator((x,)), 3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
def test_lift_tensor_constant(self):
def f(x):
@ -2191,7 +2739,10 @@ class GraphModule(torch.nn.Module):
return wrap(lambda x: x + y, x)
x = torch.randn(3)
self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=3)
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(
f, default_args_generator((x,)), arg_count, expected_opcount=3
)
def test_nested_wrap(self):
class MockModule(torch.nn.Module):
@ -2211,14 +2762,18 @@ class GraphModule(torch.nn.Module):
def fn(x):
return wrap(gn, x)
self._test_wrap_simple(fn, default_args_generator((torch.randn(10, 10),)), 4)
arg_count = ifdynstaticdefault(4, 5)
self._test_wrap_simple(
fn, default_args_generator((torch.randn(10, 10),)), arg_count
)
def test_fn_with_kwargs_in_torch_ops(self):
def fn(x):
return wrap(lambda z: torch.cos(input=z), x)
x = torch.randn(3)
self._test_wrap_simple(fn, default_args_generator((x,)), 2)
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
def test_hooks(self):
class ToyModel(torch.nn.Module):


@ -1941,10 +1941,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_1, primals_3); primals_1 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_2, primals_3); primals_2 = None
return (mul, mul_3, primals_6, primals_7, primals_7, primals_3, primals_6, primals_7)
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
""", # noqa: B950
)
@ -1952,10 +1952,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_3: "Sym(s0)", primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_3); tangents_1 = None
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_3); tangents_2 = primals_3 = None
return (mul_8, mul_9, primals_6, primals_7, primals_7, None, None)
def forward(self, primals_1: "Sym(s0)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -1974,13 +1974,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_4, primals_3]); clone = None
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_4, primals_3]); clone_1 = primals_3 = None
return (view, view_1, primals_4, primals_6, primals_6, primals_6, primals_7)
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
""", # noqa: B950
)
@ -1988,10 +1988,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2057,13 +2057,13 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_3, primals_4]); clone = None
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_3, primals_4]); clone_1 = primals_3 = primals_4 = None
return (view, view_1, primals_6, primals_7, primals_7, primals_6, primals_7)
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
""", # noqa: B950
)
@ -2071,10 +2071,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2093,14 +2093,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s0*s1)" = primals_3 * primals_4; primals_3 = primals_4 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (view, view_1, mul_6, primals_6, primals_7)
return (view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950
)
@ -2108,10 +2108,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2130,14 +2130,14 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s0*s1)" = primals_3 * primals_4; primals_3 = primals_4 = None
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6])
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (clone, view, view_1, mul_6, primals_6, primals_7)
return (clone, view, view_1, mul_6, primals_5, primals_7)
""", # noqa: B950
)
@ -2145,10 +2145,10 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
""", # noqa: B950
)
@ -2226,9 +2226,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[1].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3, s0]", primals_2: "f32[3, s0]", primals_3: "Sym(s0)", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
@ -2255,7 +2255,7 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (view_2, view_3, primals_5, primals_5, None)
return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950
)
@ -2282,9 +2282,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3, s0]", primals_2: "f32[3, s0]", primals_3: "Sym(s0)", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
@ -2300,7 +2300,7 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (view_2, view_3, primals_5, primals_5, None)
return (None, view_2, view_3, primals_5, primals_5)
""", # noqa: B950
)
@ -2466,11 +2466,11 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[s0, s1]", primals_2: "i64[s2 + 1]", primals_3: "f32[s6, 0]", primals_4: "f32[s7, 0]", primals_5: "Sym(s2)", primals_6: "Sym(s1)", primals_7: "Sym(s1)", primals_8: "Sym(s1)", primals_9: "Sym(s2)", primals_10: "Sym(s3)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_9); clone = None
return (mul, primals_2, primals_3, primals_4, primals_9, primals_8, primals_8, primals_8, primals_9)
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
""", # noqa: B950
)
@ -2478,9 +2478,9 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_8: "Sym(s1)", primals_9: "Sym(s2)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_9); tangents_1 = None
return (mul_1, tangents_2, tangents_3, tangents_4, primals_9, primals_8, primals_8, None, None, None)
def forward(self, primals_1: "Sym(s2)", primals_8: "Sym(s2)", primals_10: "Sym(s1)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950
)
@ -2499,12 +2499,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[s0, s1]", primals_2: "i64[s2 + 1]", primals_3: "f32[s6, 0]", primals_4: "f32[s7, 0]", primals_5: "Sym(s2)", primals_6: "Sym(s1)", primals_7: "Sym(s1)", primals_8: "Sym(s1)", primals_9: "Sym(s2)", primals_10: "Sym(s3)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
cat: "f64[s0, 2*s1]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s1)" = primals_8 + primals_8
return (cat, primals_2, primals_3, primals_4, primals_9, add_2, add_2, primals_8, primals_9, add_2)
add_2: "Sym(2*s1)" = primals_10 + primals_10
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
""", # noqa: B950
)
@ -2512,12 +2512,12 @@ class GraphModule(torch.nn.Module):
normalize_gm(bw[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_8: "Sym(s1)", primals_9: "Sym(s2)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_8)
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_8, add_2); tangents_1 = add_2 = None
def forward(self, primals_8: "Sym(s2)", primals_10: "Sym(s1)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
add_4: "f64[s0, s1]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
return (add_4, tangents_2, tangents_3, tangents_4, primals_9, primals_8, primals_8, None, None, None)
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
""", # noqa: B950
)
@ -2545,7 +2545,7 @@ class GraphModule(torch.nn.Module):
normalize_gm(fw[0].print_readable(print_output=False)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f64[9, s2]", arg1_1: "i64[s3 + 1]", arg2_1: "f32[s7, 0]", arg3_1: "f32[s8, 0]", arg4_1: "Sym(s3)", arg5_1: "Sym(s2)", arg6_1: "Sym(s2)", arg7_1: "Sym(s2)", arg8_1: "Sym(s3)", arg9_1: "Sym(s4)"):
def forward(self, arg0_1: "Sym(s3)", arg1_1: "Sym(s4)", arg2_1: "Sym(s2)", arg3_1: "f64[9, s2]", arg4_1: "i64[s3 + 1]", arg5_1: "f32[s7, 0]", arg6_1: "f32[s8, 0]", arg7_1: "Sym(s3)", arg8_1: "Sym(s2)", arg9_1: "Sym(s2)"):
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
@ -2559,7 +2559,7 @@ class <lambda>(torch.nn.Module):
zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False)
zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False)
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg0_1], 1); cat = arg0_1 = None
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
sin: "f64[9, s2 + 5]" = torch.ops.aten.sin.default(cat_2)
mul: "f64[9, s2 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
@ -2722,7 +2722,7 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
norm_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_nt_: "f64[3, s1, 5]", s1: "Sym(s1)"):
def forward(self, s1: "Sym(s1)", L_nt_: "f64[3, s1, 5]"):
l_nt_ = L_nt_
add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None


@ -368,15 +368,17 @@ class TestControlFlow(TestCase):
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, sym_size_int)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, sym_size_int)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = sym_size_int = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
return (getitem_1,)""", # noqa: B950
)
@ -409,15 +411,17 @@ def forward(self, pred_1, x_1):
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, sym_size_int)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, sym_size_int)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = sym_size_int = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
return (getitem_1,)""", # noqa: B950
)
@ -518,16 +522,20 @@ def forward(self, pred_1, x_1):
gm.code.strip(),
"""\
def forward(self, pred_1, x_1, y_1, z_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1, sym_size_int, sym_size_int_1)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1, sym_size_int, sym_size_int_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = sym_size_int = sym_size_int_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
getitem_2 = cond_1[1]; getitem_2 = None
getitem_3 = cond_1[2]; getitem_3 = None
getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
return (getitem_1,)""", # noqa: B950
)
@ -572,12 +580,13 @@ def forward(self, pred_1, x_1, y_1, z_1):
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_param_constant0 = self._param_constant0
_param_constant1 = self._param_constant1
_tensor_constant0 = self._tensor_constant0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, sym_size_int, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
@ -585,11 +594,12 @@ def forward(self, pred_1, x_1):
_param_constant0_1 = self._param_constant0
_param_constant1_1 = self._param_constant1
_tensor_constant0_1 = self._tensor_constant0
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = None
getitem_1 = cond_1[0]; getitem_1 = None
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
getitem_4 = cond_1[3]; getitem_4 = None
getitem_5 = cond_1[4]; cond_1 = getitem_5 = None
return (getitem_2,)""", # noqa: B950
)
@ -692,24 +702,30 @@ def forward(self, x_1):
gm.code.strip(),
"""\
def forward(self, pred_1, a_1, b_1, c_1):
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
sym_size_int_2 = torch.ops.aten.sym_size.int(c_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, c_1)); true_graph_0 = false_graph_0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
getitem_3 = cond_1[2]; getitem_3 = None
getitem_4 = cond_1[3]; getitem_4 = None
getitem_5 = cond_1[4]; getitem_5 = None
getitem_6 = cond_1[5]; cond_1 = getitem_6 = None
return (getitem_1, getitem_2)""", # noqa: B950
)
# Forward
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)""",
)
@ -717,11 +733,11 @@ def forward(self, arg0_1, arg1_1, arg2_1):
self.assertExpectedInline(
gm.true_graph_1.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = add = None
clone = torch.ops.aten.clone.default(arg0_1)
clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
return [clone, clone_1, None]""",
return [clone, clone_1, None, None, None, None]""",
)
def test_cond_autograd_pytree_input(self):
@ -1050,15 +1066,17 @@ def forward(self, pred_1, x_1):
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, sym_size_int)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, sym_size_int)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = sym_size_int = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
return (getitem_1,)""", # noqa: B950
)
@ -1623,7 +1641,7 @@ def forward(self, pred_1, x_1):
self.assertEqual(result, expected_result)
# TODO: Does not work because of the usage of vmap witin associative_scan
# The parameterization is commented out for the moment and the test is marked with expected fail
# The parameterization is commented out for the moment and the test is marked with expected fail (T206899919)
# Fails with: AssertionError: scan is not an OpOverload
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@ -2370,10 +2388,14 @@ def forward(self, fct_1, init_1, xs_1):
select = torch.ops.aten.select.int(xs_1, 0, 0)
add = torch.ops.aten.add.Tensor(init_1, select); add = None
add_1 = torch.ops.aten.add.Tensor(init_1, select); select = add_1 = None
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
clone = torch.ops.aten.clone.default(init_1); clone = None
select_copy = torch.ops.aten.select_copy.int(xs_1, 0, 0); select_copy = None
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2)
scan_combine_graph_0 = self.scan_combine_graph_0
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True, []); scan_combine_graph_0 = init_1 = xs_1 = None
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True, [sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4]); scan_combine_graph_0 = init_1 = xs_1 = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
getitem = scan[0]
getitem_1 = scan[1]; scan = None
return (getitem, getitem_1)""", # noqa: B950
@ -2612,6 +2634,8 @@ class AssociativeScanTests(TestCase):
num_dims = [random.randint(2, 5) for _ in range(4)]
for num_dim in num_dims:
# To avoid triggering automatic dynamic shape
torch._dynamo.reset()
shapes = [random.randint(1, 9) for _ in range(num_dim)]
rnd_scan_dim = random.randint(0, num_dim - 1)
x = torch.randn(*shapes, device=device)
@ -3692,9 +3716,13 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
sym_size_int = torch.ops.aten.sym_size.int(arg2_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg3_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 1)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
@ -3705,10 +3733,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
self.assertExpectedInline(
outer_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (arg7_1, arg7_1, arg7_1, arg7_1)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = arg7_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
@ -3723,10 +3751,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
self.assertExpectedInline(
outer_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (arg7_1, arg7_1, arg7_1, arg7_1)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = arg7_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
@ -3741,7 +3769,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
self.assertExpectedInline(
inner_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
add = torch.ops.aten.add.Tensor(arg2_1, 3.14); arg2_1 = None
@ -3752,7 +3780,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
self.assertExpectedInline(
inner_cond.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
gt = torch.ops.aten.gt.Scalar(arg1_1, 0); arg1_1 = None
return gt
""",
@ -3854,23 +3882,27 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
def forward(self, a_1, b_1):
sum_1 = torch.ops.aten.sum.default(a_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)""",
)
self.assertExpectedInline(
gm.false_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (mul,)""",
)
@ -3916,6 +3948,8 @@ def forward(self, arg0_1, arg1_1):
if isinstance(val, tuple):
for v in val:
yield v.fake_mode.shape_env
elif isinstance(val, torch.SymInt):
yield val.node.shape_env
else:
yield val.fake_mode.shape_env
@ -4983,10 +5017,11 @@ def forward(self, arg0_1):
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
eq = sym_size_int == 4
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -5015,11 +5050,12 @@ def forward(self, x_1):
nonzero = torch.ops.aten.nonzero.default(x_1)
sym_size_int = torch.ops.aten.sym_size.int(nonzero, 0); nonzero = None
gt = sym_size_int > 3; sym_size_int = None
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1]); gt = true_graph_0 = false_graph_0 = x_1 = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1, sym_size_int_1]); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""",
return getitem""", # noqa: B950
)
def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
@ -5092,19 +5128,20 @@ def forward(self, x_1):
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
eq = sym_size_int == 4
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)""",
)
@ -5330,10 +5367,11 @@ def forward(self, arg0_1, arg1_1):
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
eq = sym_size_int == 4
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -5341,7 +5379,7 @@ def forward(self, x_1):
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
def forward(self, arg0_1, arg1_1, arg2_1):
cos = torch.ops.aten.cos.default(arg0_1)
sub = torch.ops.aten.sub.Tensor(arg0_1, cos); arg0_1 = cos = None
return (sub,)""",
@ -5350,7 +5388,7 @@ def forward(self, arg0_1):
self.assertExpectedInline(
gm.false_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
def forward(self, arg0_1, arg1_1, arg2_1):
sin = torch.ops.aten.sin.default(arg0_1)
add = torch.ops.aten.add.Tensor(arg0_1, sin); arg0_1 = sin = None
return (add,)""",
@ -5689,7 +5727,7 @@ def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L
tensor = torch.tensor([True])
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, [l_a_, l_b_, l_self_num]); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = None
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, [l_a_, l_b_, l_self_num, s0]); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s0 = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)

View File

@ -774,6 +774,16 @@ def skip_if_halide(fn):
return wrapper
def skip_if_dynamic(fn):
@functools.wraps(fn)
def wrapper(self):
if ifdynstaticdefault(True, False) or torch._dynamo.config.dynamic_shapes:
raise unittest.SkipTest("associtaive_scan doesn's support lifted SymInts.")
return fn(self)
return wrapper
def is_halide_backend(device):
if getattr(device, "type", device) == "cpu":
return config.cpu_backend == "halide"
@ -2038,6 +2048,7 @@ class CommonTemplate:
@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
@skip_if_halide # scan ops
@skip_if_dynamic # TODO: support lifted symints when dynamic
def test_custom_scan_op(self):
if self.device != "cuda":
raise unittest.SkipTest("associative_scan only supported on GPU")
@ -2063,6 +2074,7 @@ class CommonTemplate:
self.assertEqual(expect, actual)
@skip_if_halide # scan ops
@skip_if_dynamic # TODO: support lifted symints when dynamic
def test_custom_scan_op_compiled(self):
if self.device != "cuda":
raise unittest.SkipTest("associative_scan only supported on GPU")
@ -2090,6 +2102,7 @@ class CommonTemplate:
@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
@skip_if_halide # scan ops
@skip_if_dynamic # TODO: support lifted symints when dynamic
def test_custom_scan_op_multi_input(self):
if self.device != "cuda":
raise unittest.SkipTest("associative_scan only supported on GPU")
@ -2114,6 +2127,7 @@ class CommonTemplate:
@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
@skip_if_halide # scan ops
@skip_if_dynamic # TODO: support lifted symints when dynamic
def test_custom_scan_would_split(self):
if self.device != "cuda":
raise unittest.SkipTest("associative_scan only supported on GPU")

View File

@ -255,7 +255,7 @@ class OutputGraph:
torch_function_mode_stack,
):
super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)]
self.tracers = [SubgraphTracer(self, is_export=export)]
# Map from graph input's `Source` to its `VariableTracker` to
# de-duplicate graph inputs by source and reuse the tracker
self.input_source_to_var: Dict[Source, VariableTracker] = {}
@ -577,7 +577,10 @@ class OutputGraph:
prior_tracer
if prior_tracer
else SubgraphTracer(
self, parent=self.current_tracer, source_target=source_target
self,
parent=self.current_tracer,
source_target=source_target,
is_export=self.current_tracer.is_export,
)
)
self.tracers.append(tracer)
@ -657,71 +660,6 @@ class OutputGraph:
def current_tx(self):
return self.root_tx if not self._current_tx else self._current_tx[-1]
def add_symbol_bindings(self, arg: GraphArg):
# Insert implicit size vars as necessary. With dynamic shapes, we
# maintain the invariant that every sizevar gets a direct SymInt input
# into the graph. This means downstream graph transforms can assume
# every size variable is explicitly bound and accessible, instead of
# having to pull it out implicitly from tensors.
if self.export:
return
assert arg.fake_tensor is not None
def bind_symint(s: torch.SymInt, prop):
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
return
s0 = s.node.expr
if s0 in self.bound_symbols:
return
log.debug("bind_symint %s %s", s, prop.name())
# TODO: don't readd symint if we already have it in graph
# (this is harmless because we do remove the unused ones later)
proxy = self.root_tracer.create_graph_input(
str(s0),
type(s),
s,
before=True,
source=prop,
)
self.root_tracer.bound_symbols[s0] = proxy
assert isinstance(s, torch.SymInt)
proxy.node.meta["grapharg"] = GraphArg(
prop,
s,
pass_arg_as_tensor=False,
fake_tensor=None,
is_tensor=False,
)
def handle_tensor(t, src):
for i, s in enumerate(t.size()):
bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i))
if t.layout is torch.strided:
for i, s in enumerate(t.stride()):
bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i))
bind_symint(
t.storage_offset(),
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
)
elif t.layout is torch.sparse_coo:
handle_tensor(t._indices(), src)
handle_tensor(t._values(), src)
elif t.layout in {torch.sparse_csr, torch.sparse_bsr}:
handle_tensor(t.crow_indices(), src)
handle_tensor(t.col_indices(), src)
elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
handle_tensor(t.ccol_indices(), src)
handle_tensor(t.row_indices(), src)
if is_traceable_wrapper_subclass(t):
attrs, ctx = t.__tensor_flatten__()
for attr in attrs:
inner_t = getattr(t, attr)
handle_tensor(inner_t, AttrSource(src, attr))
handle_tensor(arg.fake_tensor, arg.source)
def count_calls(self):
return count_calls(self.graph)
@ -1834,6 +1772,17 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
_compile_id_counter = itertools.count()
class LazyProxy:
def __init__(self, tracer, fn, *args, **kwargs):
self.tracer = tracer
self.fn = fn
self.args = args
self.kwargs = kwargs
def __call__(self):
return self.fn(*self.args, **self.kwargs)
class SubgraphTracer(fx.Tracer):
"""
Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
@ -1842,19 +1791,13 @@ class SubgraphTracer(fx.Tracer):
compiling and executing the graph.
"""
def __init__(
self, output_graph, parent=None, export_root=False, source_target=None
):
def __init__(self, output_graph, parent=None, is_export=False, source_target=None):
super().__init__()
self.output_graph = weakref.proxy(output_graph)
self.graph = torch.fx.Graph()
# The export is only ever set for the ROOT tracer. It controls
# whether or not certain inputs are allowed to be added or not.
# Look at call sites of create_graph_input to see how it is used.
if export_root:
assert parent is None
self.export_root = export_root
# See note [Export inputs must be explicitly passed in]
self.is_export = is_export
# Map from graph input name to its placeholder proxy object, where the
# map's keys give all current placeholder node names and can be used to
# create unique node names
@ -1879,8 +1822,11 @@ class SubgraphTracer(fx.Tracer):
# Dicts maintain the order of args for the HigherOrderOperator call.
self.lifted_freevars = {}
# map symbols to their bound proxy placeholders.
self.bound_symbols: Dict[sympy.Symbol, torch.fx.Proxy] = {}
# map basic symbols (backed and unbacked) to their bound proxies.
# There are only two cases where bound_symbols will be recorded:
# 1. when we create_graph_input for a backed SymInt that's a basic symbol
# 2. when we track_unbacked_symbols for intermediate results that contain unbacked symints.
self.bound_symbols: Dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
self.prev_inst = None
# True if this tracer is currently tracing into torch.utils.checkpoint
@ -1892,6 +1838,8 @@ class SubgraphTracer(fx.Tracer):
# backward recomputation of the checkpoint region doesn't affect its correctness.
self.allow_side_effects_under_checkpoint = False
self.debug_level: int = parent.debug_level + 1 if parent is not None else 0
self._cur_code = None
self._orig_gm_meta = None
self._orig_gm_lineno_map = None
@ -2139,15 +2087,19 @@ class SubgraphTracer(fx.Tracer):
self, name, type_expr, example_value, before=False, source=None
):
log.debug(
"create_graph_input %s %s",
"create_graph_input %s %s %s at debug_level %s before=%s",
name,
source.name() if source is not None else "(none)",
example_value,
self.debug_level,
before,
)
if source is None:
assert (
self.parent is not None
), "you are required to provide a source for inputs on the root tracer"
), f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer"
# Note [Export inputs must be explicitly passed in]
# In eager, we are generally OK with adding graph inputs whenever we
# want, because we take care of writing the bytecode that knows how
# to source all the inputs.
@ -2156,7 +2108,7 @@ class SubgraphTracer(fx.Tracer):
# object which only depends on the inputs you explicitly passed to it.
# So we are a bit more strict about what sources can become inputs
# in export
if self.export_root:
if self.is_export and self.parent is None:
if not is_from_local_source(source, allow_cell_or_freevar=False):
self.output_graph.source_to_user_stacks.setdefault(source, []).append(
TracingContext.extract_stack()
@ -2188,6 +2140,44 @@ class SubgraphTracer(fx.Tracer):
self.input_name_to_proxy[k] = v
else:
self.input_name_to_proxy[name] = proxy
# NOTE: [Auto lift basic free symbols when create_graph_input]
# Whenever we call create_graph_input, we try to also lift the basic symbols in example values
# as graph inputs.
# This applies to both the top-level graph and subgraphs in higher order ops.
# It has several cases:
# 1. When create_graph_input is called for a tensor that has symbolic shapes,
# we look for basic symbols in its sizes and strides and check whether each symbol is bound
# in the current graph (i.e. bound_symbols); if it's not bound, we'll create a placeholder
# for it, then recursively check its parent and create placeholders as needed.
# Every tracer maintains a mapping (i.e. lifted_freevars)
# that maps from the parent proxy to the proxy in the current tracer for the symbol.
# 2. When create_graph_input is called for a tensor with unbacked symbolic shapes,
# note that backed symbols all come from the inputs' symbolic shapes, but unbacked symbols
# can be created while tracing. So we use track_unbacked_symbols to intercept
# at wrap_fx_proxy and bind the unbacked symbols immediately after they're
# created.
# 3. Subgraphs also lift the basic symbols in compound exprs of tensor shapes.
# For example, if an input to a subgraph has size [s1 + s2 // 8], we'll look for
# the free symbols in the sizes and lift them as inputs, similar to case 1 (see _lift_symbols_in_symint).
# 4. When create_graph_input is called for a SymInt, if the symint is a basic symbol, we'll track it
# in bound_symbols so that we don't lift the same basic symbol twice. When the symint is a
# compound expr, we'll just create the proxy for the compound expr but not lift its basic symbols.
# Also see NOTE: [Export inputs must be explicitly passed in]
is_strict_export = self.is_export
is_non_strict_export = torch.compiler.is_compiling()
if (
not is_strict_export
and not is_non_strict_export
and isinstance(example_value, torch.Tensor)
):
self._lift_basic_symbols(example_value, source)
# Bind the symbol to the ph if example_value is a SymInt whose expr is a basic symbol.
if isinstance(example_value, torch.SymInt) and isinstance(
example_value.node.expr, sympy.Symbol
):
self.bound_symbols[example_value.node.expr] = proxy
return proxy
# See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
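A minimal way to observe this auto-lifting end to end (an illustrative sketch, not one of the tests in this PR; `mark_dynamic` makes `x.size(0)` a backed symbol `s0`, which should then be lifted into the cond subgraphs and passed to the HOP as an extra operand, mirroring the updated expected graphs above):
```python
import torch
import torch._dynamo

def f(pred, x):
    def true_fn(t):
        return t.sin()

    def false_fn(t):
        return t.cos()

    return torch.cond(pred, true_fn, false_fn, (x,))

x = torch.randn(4, 5)
torch._dynamo.mark_dynamic(x, 0)  # x.size(0) becomes a backed symbol s0
pred = torch.tensor(True)
# After this change, the captured graph is expected to pass s0 to the cond HOP
# alongside x (and the subgraphs gain a matching SymInt placeholder).
out = torch.compile(f, backend="eager", fullgraph=True)(pred, x)
```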
@ -2197,6 +2187,21 @@ class SubgraphTracer(fx.Tracer):
assert (
self.parent is not None
), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
example_value = proxy.node.meta["example_value"]
# To avoid lifting the same symbol twice, we check whether the basic symbol has already been tracked.
# For example, the basic symbols may have already been lifted for the current subgraph when
# we automatically lift basic symbols in the sizes/strides of a tensor t.
# Suppose the parent graph calls sz = t.size()[0]; this creates
# a proxy in the parent and the subgraph accesses sz via closure. sz's proxy is not tracked
# in the current sub-tracer, so we could otherwise lift the same symbol twice.
if (
isinstance(example_value, torch.SymInt)
and example_value.node.expr in self.bound_symbols
):
return self.bound_symbols[example_value.node.expr]
# Proxies are associated with VariableTrackers.
# It is possible that we've already lifted the Proxy to be an input.
# If that is the case, just return the already lifted Proxy.
@ -2228,6 +2233,263 @@ class SubgraphTracer(fx.Tracer):
return arg
return self.lift_tracked_freevar_to_input(arg)
# See NOTE: [Auto lift basic free symbols when create_graph_input] for the overall design
# You MUST call this API every time you create a proxy in wrap_fx_proxy for a call
# that produces unbacked symints or tensors with unbacked symint shapes.
# This function is used to track the unbacked symints together with the proxies created during
# dynamo tracing, so that a subgraph knows how to bind a symbol input to its parent's proxy.
# LazyProxys are created for unbacked tensor shapes so that we don't create proxies
# for symbols that are never used.
def track_unbacked_symbols(
self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy]
):
# When binding the symbols in an example_value, we bind the symbols
# to the proxy's associated Tracer instead of the current tracer.
# This is because:
# 1. We may be calling wrap_tensors during speculate_subgraph because
# the variables are lazily realized. The proxies are top-level phs but
# the current tracer is a subtracer.
# 2. For autograd.Function, we trace the backward graph with a new tracer
# whose parent is the forward tracer, but we're using all the proxies created
# in forward tracer to trace the backward.
# For example, forward calls save_for_backward for an input tensor t.
# Backward calls t.tolist(). In this case, all the proxies that backward tracer
# sees are from parent tracer (i.e. the forward tracer). (e.g. t[0].item())
# See test_validate_outputs_unbacked for repro on 2.
tracer = e_proxy.tracer
assert isinstance(tracer, SubgraphTracer)
def need_bind(s) -> bool:
from torch.fx.experimental.symbolic_shapes import is_symbolic
return (
is_symbolic(s)
and isinstance(s.node.expr, sympy.Symbol)
and s.node.shape_env.is_unbacked_symint(s.node.expr)
and s.node.expr not in self.bound_symbols
)
def _proxy_with_example_value(example_value, *args, **kwargs):
proxy = tracer.create_proxy(*args, **kwargs)
set_example_value(proxy.node, example_value)
return proxy
if isinstance(example_value, torch.Tensor):
for i, s in enumerate(example_value.size()):
if need_bind(s):
log.debug(
"_track_unbacked_symbols %s for %s.size()[%s] at debug_level %s",
s,
e_proxy,
i,
tracer.debug_level,
)
lazy_proxy = LazyProxy(
tracer,
_proxy_with_example_value,
s,
"call_function",
torch.ops.aten.sym_size.int,
(e_proxy, i),
{},
type_expr=type(s),
)
self.track_unbacked_symbols(s, lazy_proxy)
if example_value.layout is torch.strided:
for i, s in enumerate(example_value.stride()):
if need_bind(s):
log.debug(
"_track_unbacked_symbols %s for %s.stride()[%s] at debug_level %s",
s,
e_proxy,
i,
tracer.debug_level,
)
lazy_proxy = LazyProxy(
tracer,
_proxy_with_example_value,
s,
"call_function",
torch.ops.aten.sym_stride.int,
(e_proxy, i),
{},
type_expr=type(s),
)
self.track_unbacked_symbols(s, lazy_proxy)
elif example_value.layout is torch.sparse_coo:
self.track_unbacked_symbols(example_value._indices(), e_proxy)
self.track_unbacked_symbols(example_value._values(), e_proxy)
elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
self.track_unbacked_symbols(example_value.crow_indices(), e_proxy)
self.track_unbacked_symbols(example_value.col_indices(), e_proxy)
elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
self.track_unbacked_symbols(example_value.ccol_indices(), e_proxy)
self.track_unbacked_symbols(example_value.row_indices(), e_proxy)
if is_traceable_wrapper_subclass(example_value):
attrs, ctx = example_value.__tensor_flatten__()
for attr in attrs:
inner_t = getattr(example_value, attr)
self.track_unbacked_symbols(inner_t, getattr(e_proxy, attr))
elif isinstance(example_value, torch.SymInt):
# Only bind unbacked symbols. backed symbols are lifted as inputs.
if need_bind(example_value):
expr = example_value.node.expr
tracer.bound_symbols[expr] = e_proxy
# See Note [Auto lift basic free symbols when create_graph_input]
def _lift_basic_symbols(
self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source]
):
# The before arg is for inserting symints in the sizes/strides of a tensor
# before the tensor. This ordering ensures that when we look at the tensor's
# symbols, they're already lifted/tracked. E.g. this assumption is used
# in insert_deferred_runtime_asserts.
def _lift_symbols_in_symint(
s: Union[int, torch.SymInt],
source: Optional[Source],
before: bool = False,
) -> None:
if not is_symbolic(s):
return
assert isinstance(s, torch.SymInt)
self_to_be_bound = self.lookup_unbound_symbols(s)
if len(self_to_be_bound) == 0:
return
# For subgraph
if self.parent is not None:
# Recursively lift symbols in symint until top-level.
self.parent._lift_basic_symbols(s, source)
for s0 in self_to_be_bound:
parent_proxy = self.parent.bound_symbols[s0]
example_val = parent_proxy.node.meta["example_value"]
assert isinstance(example_val, torch.SymInt)
ph = self.create_graph_input(
str(s0),
type(example_val),
example_val,
before=before,
source=source,
)
log.debug(
"_lift_symbols_in_symint %s from %s at debug_level %s",
s0,
source.name() if source is not None else "subgraph inputs",
self.debug_level,
)
self.lifted_freevars[parent_proxy] = ph
# For root_tracer:
else:
assert len(self_to_be_bound) == 1, (
f"For root tracer, we only expect to bind basic symbols (compound symbols "
f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}"
)
assert source is not None, (
f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, "
"this could be because it's not tracked with lazy_bind_unbacked_symbols. "
f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer."
)
s0 = next(iter(self_to_be_bound))
ph = self.create_graph_input(
str(s0),
type(s),
s,
before=before,
source=source,
)
log.debug(
"_lift_symbols_in_symint %s from %s at debug_level %s",
s,
source.name() if source is not None else "subgraph inputs",
self.debug_level,
)
ph.node.meta["grapharg"] = GraphArg(
source,
s,
pass_arg_as_tensor=False,
fake_tensor=None,
is_tensor=False,
)
if isinstance(example_value, torch.Tensor):
for i, s in enumerate(example_value.size()):
_lift_symbols_in_symint(
s,
(
TensorPropertySource(src, TensorProperty.SIZE, i)
if src is not None
else None
),
before=True,
)
if example_value.layout is torch.strided:
for i, s in enumerate(example_value.stride()):
_lift_symbols_in_symint(
s,
(
TensorPropertySource(src, TensorProperty.STRIDE, i)
if src is not None
else None
),
before=True,
)
_lift_symbols_in_symint(
example_value.storage_offset(),
(
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET)
if src is not None
else None
),
before=True,
)
elif example_value.layout is torch.sparse_coo:
self._lift_basic_symbols(example_value._indices(), src)
self._lift_basic_symbols(example_value._values(), src)
elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
self._lift_basic_symbols(example_value.crow_indices(), src)
self._lift_basic_symbols(example_value.col_indices(), src)
elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
self._lift_basic_symbols(example_value.ccol_indices(), src)
self._lift_basic_symbols(example_value.row_indices(), src)
if is_traceable_wrapper_subclass(example_value):
attrs, ctx = example_value.__tensor_flatten__()
for attr in attrs:
inner_t = getattr(example_value, attr)
self._lift_basic_symbols(
inner_t, AttrSource(src, attr) if src is not None else None
)
elif isinstance(example_value, torch.SymInt):
_lift_symbols_in_symint(
example_value,
src,
)
# Look up, in the current tracer, the bound proxy for each basic symbol in the expression of s,
# and return the symbols that are not yet bound.
# See Note [Auto lift basic free symbols when create_graph_input]
def lookup_unbound_symbols(self, s: torch.SymInt) -> List[sympy.Symbol]:
free_symbols = s.node.expr.free_symbols
if len(free_symbols) == 0:
return []
to_be_bound = []
for s0 in free_symbols:
if s0 not in self.bound_symbols:
to_be_bound.append(s0)
continue
proxy = self.bound_symbols[s0]
if isinstance(proxy, LazyProxy):
proxy = proxy()
self.bound_symbols[s0] = proxy
assert (
isinstance(proxy, torch.fx.Proxy) and proxy.tracer is self
), f"The proxy of symbol {s0} doesn't belong to current tracer."
# Sort the symbols so that we can have a deterministic lifting order
return sorted(to_be_bound, key=lambda s: s.name)
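To picture what `lookup_unbound_symbols` computes for a compound size expression, here is a plain-sympy sketch (names and the `bound_symbols` contents are made up; the real code works on `SymInt.node.expr` and fx proxies):
```python
import sympy

s1, s2 = sympy.symbols("s1 s2")
expr = s1 + s2 // 8            # e.g. a subgraph input of size [s1 + s2 // 8]
bound_symbols = {s1: "ph_s1"}  # pretend s1 already has a placeholder proxy

# Which basic symbols still need to be lifted, in a deterministic order?
to_be_bound = sorted(
    (sym for sym in expr.free_symbols if sym not in bound_symbols),
    key=lambda sym: sym.name,
)
print(to_be_bound)  # [s2] -> only s2 gets a new placeholder
```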
# NOTE: [HigherOrderOperator tracing design]
# Ignoring HigherOrderOperators for a moment,

View File

@ -382,6 +382,11 @@ def rand_strided(
_T = TypeVar("_T")
def check_dynamic_shape_capture() -> bool:
# This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
return not config.assume_static_by_default
def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
@functools.wraps(fn)
def _fn(*args: Any, **kwargs: Any) -> _T:

View File

@ -1702,7 +1702,6 @@ class VariableBuilder:
grapharg = GraphArg(source, value, False, fake_tensor_value)
tensor_proxy.node.meta["grapharg"] = grapharg
self.tx.output.add_symbol_bindings(grapharg)
return tensor_variable
def wrap_numpy_ndarray(self, value):
@ -2281,6 +2280,11 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
# tensor, the stored example value will update too!)
example_value = _clone_input(example_value, tx.fake_mode)
set_example_value(proxy.node, example_value)
# We bind the unbacked symints in the sizes/strides of tensors lazily,
# so that subgraphs can access the unbacked symbol's proxy in the parent graph
# when lifting unbacked symbols of input tensors to subgraph inputs.
# We do it lazily because the tensor may not be used in any subgraph.
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
specialized_props = target_cls.specialize(example_value)
# TODO: not sure about this fake mode test
if (
@ -2368,6 +2372,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
elif example_value is None or proxy.node.target is torch.manual_seed:
return ConstantVariable.create(None, **options)
elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
set_example_value(proxy.node, example_value)
return SymNodeVariable(proxy, example_value, **options)
elif (

View File

@ -8,7 +8,7 @@ import itertools
import logging
import types
import warnings
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
import torch._C
import torch.fx
@ -539,6 +539,70 @@ def speculate_subgraph(
graph.lint()
lifted_freevars = subtracer.lifted_freevars
# NOTE: [HigherOrderOperator subgraph input ordering]
# The input ordering of the higher order ops is determined by the order in which
# the placeholders are created.
# Manually created inputs are created in validate_args_and_maybe_create_graph_inputs before
# speculating the subgraph.
# During subgraph speculation, we may lift closed-over tensors and free symbols as inputs;
# their ordering is determined by when they are lifted: earlier lifted ones precede later
# lifted ones.
#
# Suppose the placeholders are
# O1, O2, X1, O3, O4, X2, X3, O5 where Xs are lifted phs
# The following code re-orders the placeholders to
# O1, O2, O3, O4, O5, X1, X2, X3
def move_lifted_freevars_phs_to_end(
graph: torch.fx.Graph, lifted_freevars: Tuple[torch.fx.Node]
):
lifted_ph_set = {
child_p.node for child_p in lifted_freevars.values()
}
prev_phs = [n for n in graph.nodes if n.op == "placeholder"]
# No need to reorder when graph doesn't have args or doesn't
# have lifted freevars or all inputs are lifted freevars.
if (
len(prev_phs) == 0
or len(lifted_ph_set) == 0
or len(prev_phs) == len(lifted_ph_set)
):
return
# Step 1: find first X1
for x1 in prev_phs:
if x1 in lifted_ph_set:
break
assert x1 is not None and x1.op == "placeholder"
# Step 2: starting from the X1, skip Xs and prepend Os before X1.
cand_x = x1.next
while cand_x is not None and cand_x.op == "placeholder":
if cand_x in lifted_ph_set:
cand_x = cand_x.next
else:
nxt = cand_x.next
cand_x._remove_from_list()
x1.prepend(cand_x)
cand_x = nxt
# Step 3: assert that all placeholders are in the same order
# as in lifted_freevars.
after_phs = [
node for node in graph.nodes if node.op == "placeholder"
][-len(lifted_freevars) :]
assert len(after_phs) == len(lifted_freevars)
for child_proxy, ph in zip(lifted_freevars.values(), after_phs):
assert (
child_proxy.node is ph
), "The order of placeholders is different from the order of lifted_freevars"
graph.lint()
if len(lifted_freevars) > 0:
move_lifted_freevars_phs_to_end(graph, lifted_freevars)
return (
(output, treespec),
graph,
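The reordering performed by `move_lifted_freevars_phs_to_end` above can be illustrated standalone with the public `fx.Node.prepend` API (the helper uses the private `_remove_from_list`; names here are made up):
```python
import torch.fx as fx

g = fx.Graph()
o1 = g.placeholder("o1")
x1 = g.placeholder("x1")   # imagine this placeholder was lifted mid-speculation
o2 = g.placeholder("o2")   # a manually-created input that arrived later
g.output((o1, x1, o2))

# Move the manually-created placeholder before the lifted one, so lifted
# inputs end up at the tail (O..., X...) as described in the note above.
x1.prepend(o2)
print([n.name for n in g.nodes if n.op == "placeholder"])  # ['o1', 'o2', 'x1']
g.lint()
```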
@ -716,7 +780,7 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"{operands.python_type()}",
)
operands_seq = operands.unpack_var_sequence(tx)
if not only_consist_of(operands, (TensorVariable,)):
if not only_consist_of(operands, (TensorVariable, ConstantVariable)):
unimplemented(
"Expect operands to be a tuple of pytrees that only consists of tensor leaves."
)

View File

@ -28,6 +28,7 @@ from torch._higher_order_ops.utils import (
saved_tensors_and_symints,
unique_graph_id,
UnsupportedAliasMutationException,
validate_subgraph_args_types,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
@ -58,6 +59,7 @@ class CondOp(HigherOrderOperator):
super().__init__("cond")
def __call__(self, pred, true_fn, false_fn, operands):
validate_subgraph_args_types(operands)
return super().__call__(pred, true_fn, false_fn, operands)
@ -246,9 +248,6 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(
operands, (list, tuple)
), f"Cond operands must be a list or tuple of tensors and SymInts {operands}"
assert all(
isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands
), f"Cond operands must be a list of tensors and SymInts {operands}"
true_graph = reenter_make_fx(true_fn)(*operands)
false_graph = reenter_make_fx(false_fn)(*operands)
@ -366,6 +365,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
assert all(
isinstance(o, (torch.Tensor, int)) for o in operands
), f"Dense implementation operands must be a list of tensors and ints {operands}"
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
if pred:

View File

@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
UnsupportedAliasMutationException,
validate_subgraph_args_types,
)
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
@ -85,11 +86,7 @@ class FlexAttentionHOP(HigherOrderOperator):
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
if not all(
isinstance(buf, (torch.Tensor, torch.SymInt, int))
for buf in score_mod_other_buffers + mask_mod_other_buffers
):
raise RuntimeError("Other buffers must be tensors.")
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
return super().__call__(
query,
key,
@ -129,11 +126,7 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
if not all(
isinstance(buf, torch.Tensor)
for buf in score_mod_other_buffers + mask_mod_other_buffers
):
raise RuntimeError("Other buffers must be tensors.")
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
return super().__call__(
query,
key,
@ -415,10 +408,6 @@ def flex_attention_functionalize(
assert isinstance(block_mask_unwrapped, tuple)
assert isinstance(score_mod_other_buffers_unwrapped, tuple)
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
assert all(
isinstance(item, (torch.Tensor, torch.SymInt, int))
for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
)
example_vals = (
[torch.zeros((), dtype=query.dtype)]
@ -531,12 +520,8 @@ def create_fw_bw_graph(
unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
assert all(
isinstance(t, (FakeTensor, torch.SymInt, int))
for t in unwrapped_score_mod_indexes
)
assert all(
isinstance(t, (FakeTensor, torch.SymInt, int))
for t in unwrapped_other_buffers
isinstance(t, (FakeTensor, int, torch.SymInt))
for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
)
example_flat_out = pytree.tree_map(
@ -595,7 +580,9 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
*score_mod_other_buffers: Tuple[Any, ...],
) -> Tuple[torch.Tensor, torch.Tensor]:
any_buffer_requires_grad = any(
buffer.requires_grad for buffer in mask_mod_other_buffers
buffer.requires_grad
for buffer in mask_mod_other_buffers
if isinstance(buffer, torch.Tensor)
)
assert (
not any_buffer_requires_grad
@ -777,9 +764,16 @@ def sdpa_dense_backward(
actual_grad_query = torch.empty_like(query)
actual_grad_key = torch.empty_like(key)
actual_grad_value = torch.empty_like(value)
def _maybe_new_buffer(
buffer: Union[torch.Tensor, torch.SymInt, int]
) -> Optional[Union[torch.Tensor, torch.SymInt, int]]:
if isinstance(buffer, torch.Tensor):
return torch.empty_like(buffer) if buffer.requires_grad else None
return buffer
actual_grad_score_mod_captured = [
torch.empty_like(buffer) if buffer.requires_grad else None
for buffer in score_mod_other_buffers
_maybe_new_buffer(buffer) for buffer in score_mod_other_buffers
]
Bq, Bkv = query.size(0), key.size(0)
@ -883,7 +877,7 @@ def sdpa_dense_backward(
actual_grad_key.copy_(grad_key)
actual_grad_value.copy_(grad_value)
score_mod_other_buffer_grads = [
actual_grad.copy_(grad) if actual_grad is not None else actual_grad
actual_grad.copy_(grad) if isinstance(actual_grad, torch.Tensor) else None
for actual_grad, grad in zip(
actual_grad_score_mod_captured, grad_score_mod_captured
)
@ -1072,10 +1066,6 @@ def flex_attention_backward_functionalize(
assert isinstance(block_mask_unwrapped, tuple)
assert isinstance(score_mod_other_buffers_unwrapped, tuple)
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
assert all(
isinstance(item, torch.Tensor)
for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
)
with ctx.redispatch_to_next() as m:
functional_fw_graph = ctx.functionalize(fw_graph)

View File

@ -16,6 +16,7 @@ from torch._higher_order_ops.utils import (
reenter_make_fx,
unique_graph_id,
UnsupportedAliasMutationException,
validate_subgraph_args_types,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
@ -230,6 +231,7 @@ class ScanOp(HigherOrderOperator):
def __call__(self, combine_fn, init, xs, dim, reverse, additional_inputs):
assert isinstance(additional_inputs, list), "additional_inputs must be a list."
validate_subgraph_args_types(additional_inputs)
return super().__call__(combine_fn, init, xs, dim, reverse, additional_inputs)
@ -335,10 +337,15 @@ def trace_scan(
reverse: bool,
additional_inputs: List[torch.Tensor],
):
from torch._dynamo.utils import clone_input
with disable_proxy_modes_tracing():
sample_inits = [x_init.clone() for x_init in init]
sample_inits = [clone_input(x_init) for x_init in init]
sample_inputs = [first_slice_copy(x, dim) for x in xs]
sample_additional_inputs = [x.clone() for x in additional_inputs]
sample_additional_inputs = [
clone_input(x) if isinstance(x, torch.Tensor) else x
for x in additional_inputs
]
combine_graph = reenter_make_fx(combine_fn)(
*sample_inits, *sample_inputs, *sample_additional_inputs
)
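The `clone_input` change above matters because `additional_inputs` may now contain lifted ints/SymInts, which have no `.clone()`; only tensor entries get cloned. A tiny sketch of the same guard (values are made up):
```python
import torch
from torch._dynamo.utils import clone_input

additional_inputs = [torch.randn(2, 3), 5]  # tensors mixed with lifted ints/SymInts
samples = [
    clone_input(x) if isinstance(x, torch.Tensor) else x
    for x in additional_inputs
]
print([type(s).__name__ for s in samples])  # ['Tensor', 'int']
```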

View File

@ -2,7 +2,7 @@
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, List
from typing import Any, Callable, List, Tuple, Union
import torch
import torch.fx.traceback as fx_traceback
@ -481,3 +481,19 @@ def get_dummy_aot_autograd_config():
# Slices off the first element of a given dimension
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
return torch.select_copy(t, dim, 0)
# Note [lifted arg types in hop]
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
# This has implications for the types of lifted args for different dispatch keys:
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.SymInt
# lifted args because it's on the path of torch.compile(dynamic=True).
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
# hops may receive int inputs from the shape of outer tensor inputs.
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
def validate_subgraph_args_types(lifted_args: Union[Tuple[Any], List[Any]]):
allowed_types = (torch.Tensor, int, torch.SymInt)
assert all(
isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"

View File

@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
autograd_not_implemented,
reenter_make_fx,
UnsupportedAliasMutationException,
validate_subgraph_args_types,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
@ -42,6 +43,7 @@ class WhileLoopOp(HigherOrderOperator):
raise RuntimeError(
f"additional_inputs must be a tuple, got {type(additional_inputs)}"
)
if not all(
isinstance(t, (torch.Tensor, int, float, bool)) for t in carried_inputs
):
@ -50,13 +52,7 @@ class WhileLoopOp(HigherOrderOperator):
f"{carried_inputs}"
)
if not all(
isinstance(t, (torch.Tensor, int, float, bool)) for t in additional_inputs
):
raise RuntimeError(
"additional_inputs must be a tuple of tensors, ints, floats, or bools, got "
f"{additional_inputs}"
)
validate_subgraph_args_types(additional_inputs)
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)

View File

@ -76,7 +76,14 @@ def create_placeholder(
def maybe_realize(args: List[Optional[IRNode]]):
"""Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
return tree_map(lambda x: realize_inputs(x) if x is not None else None, args)
return tree_map(
lambda x: (
realize_inputs(x)
if x is not None and not isinstance(x, sympy.Symbol)
else x
),
args,
)
def get_float32_precision():