mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[hop free symbols] lift free symbols in example_value when create_graph_input (#138363)
There are 4 parts (they are hard to further break into smaller ones cause they're highly coupled) in this PR: 1. **Whenever we call create_graph_input, we try to bind the symbols in the graph input.** We've enforced the invariant that all create_graph_inputs calls must provide an example value, we could intercept at the create_graph_input calls (This PR only handles free symbols in tensors). 2. **We cache the bound_symbols** to avoid lift the same symbol repeated. 3. For lifted symbols, we re-used **lifted_freevars** i.e. the mapping between symbol proxy in parent graph to the lifted phs in current subgraph, which we handle lifted tensors. In this way, all hops that supports lifted tensors should be able to handle lifted_symints automatically (at least in dynamo part). 4. For **unbacked symbols** created during tracing, we need to also bound these symbols to its proxy. This is to support the tests cases where we want to lift unbacked symbols as input. We need the proxy of the unbacked symbol in parent graph in order to properly create the args to the hop. 5. We change all the tests after free symbols are lifted in subgraphs. And also supports the lifted symbols in existing higher order ops. **The interaction of nested tracers:** The previous design for lifting tensor closures is that: suppose we're in nested tracers, whenever we see a new proxy that's not created by create tracer, we recursively look for the proxy in parent tracer until we find the tracer that creates this proxy (either a placeholder or some intermediate results). More detail is in Note [Nested SubgraphTracer and free_variable handling]. Given the above design, the plan for lifting the free symbols is: whenever we lift a free tensor to be the inputs of current subgraph, we'll look at the symbols in it and bind the symbols at the same time. For example, suppose we have the following function: ```python def f(x: [s1, s2]): def true_f(): def true_f_inner(): return x.sin() ``` what will happen in time order: 1. we create a subtracer 1 and start to speculate the outer cond's true_f 2. we create a another subtracer 2 and start to speculate the inner cond's true_f_inner. 3. dynamo realize the tensor input x by calling wrap_tensor in top-level to create graph input x (tracer 0), we bind the symbol s1, s2 after ph for x is created. So the graph now looks like: ```python def gm(s1, s2, x): ``` 4. when seeing TensorVariable.call_method of x, tracer2 wants to create a call_function(sin, proxy_of_x), but it finds that proxy_of_x is not created by current tracer. So it recursively look up its parent tracer1 and find parent tracer1 also doesn't track this proxy_of_x then it finds the root tracer0, who is the creator of it and tracks it as a ph. Then tracer 1 create_graph_input to lift the closure to its input ph1 and add (proxy_of_x: ph1) k-v in **lifted_freevars** of tracer 1. Now the graph looks like: ```python def gm(s1, s2, x): def true_gm(x): ``` 5. Since there are free symbols inside this new tensor input, tracer 1 also binds the symbols (maybe_bind_symbol), which calls create_graph_input for s1 and s2. Now the graph looks like ```python def gm(s1, s2, x): def true_gm(s1, s2, x): ``` 6. then it goes back to tracer 2, and call create_graph_input for x and get ph2, tracer 2's **lifted_freevars** records (ph1, ph2). and tracer 2 also binds the symbols in this new tensor input. Now the graph looks like: ```python def gm(s1, s2, x): def true_gm(s1, s2, x): def true_gm_inner(s1, s2, x): ``` 7. Finally the sin call_function node is created by tracer 2. **This PR also handles the following cases:** - What if we lift two tensors share the same symbol? e.g. x1 [s1, s2], x2 [s2, s3]? Each subtracer maintains bound_symbols as a cache that maps a symbol.expr to its proxy in current tracer. So when we see x1, we'll track s1 and s2 as inputs and bound s1 to ph1, s2 to ph2. So when we try to bind symbols of x2, s2 will already be tracked so no graph input is created. - what if a subgraph close over a symint? e.g. ```python def f(x): def true_f(): c = x.size(0) def true_fn_inner(): return c ``` When we speculate true_fn_inner, we find proxy_of_c is not tracked by tracer 2, so it recursively looks up its parent. At this point, x and its symbols have been lifted as input of true_f (as a result of lifting x during tracing true_f in tracer 1. Specifically the graph looks like: ```python def gm(s1, s2, x): def true_gm(s1, s2, x): def true_gm_inner(): ``` So tracer 2 is able to find that s1 have been tracked as ph in tracer 1 so it returns back to gm and call create_graph_input on s1. The graph now looks like: ```python def gm(s1, s2, x): def true_gm(s1, s2, x): def true_gm_inner(s1): return s1 ``` - What if subgraph close over an unbacked symint? e.g. ```python def f(x): def true_f(): c = x.item() def true_f_inner(): return c ``` When x.item() is called, proxy_of_c and its symnode variable is created for tracer 1, and we also call track_unbacked_symbols to record this relationship. So when tracer 2 finds proxy_of_c is not created by current tracer, it recursivelly looks up its parent tracer and finds that that expression u0 has been tracked as a result of track_unbacked_symbol in tracer 1. So it will stop the recursion and create_graph_input u0 in tracer 2. Graph looks like: ```python def f(x): def true_f(s1, s2, x): c = x.item() def true_gm_inner(u0): return u0 cond(pred, true_gm_inner, false_gm_inner, (c,)) ``` - what if subgraph close over a tensor with unbacked symint shape? ```python def f(x): def true_f(): c = x.item() r = torch.randn((c,)) def true_f_inner(): return r + 1 ``` This is the same as the case of closing over tensors with backed shapes. where we first lift r, then bind u0 in it, which recursively bind_symint of u0 in its parent and found u0 is tracked in parent tracer as a result of .item() call. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138363 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
3368f3ad41
commit
ab42967238
@ -16,6 +16,7 @@ import torch.utils._pytree as pytree
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._dynamo.testing import (
|
||||
check_dynamic_shape_capture,
|
||||
CompileCounter,
|
||||
CompileCounterWithBackend,
|
||||
EagerAndRecordGraphs,
|
||||
@ -37,11 +38,6 @@ from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_
|
||||
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
|
||||
|
||||
|
||||
def check_dynamic_shape_capture():
|
||||
# This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
|
||||
return not config.assume_static_by_default
|
||||
|
||||
|
||||
def count_ops(gm, args, freq, op):
|
||||
actual = [node.target for node in gm.graph.nodes].count(op)
|
||||
assert actual == freq, f"expected={freq}, actual={actual}"
|
||||
@ -213,7 +209,8 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
return wrap(lambda x: torch.sin(x), x)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), 2)
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
||||
|
||||
def test_enum_arg(self):
|
||||
class SomeEnum(enum.Enum):
|
||||
@ -229,7 +226,8 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
return wrap(g, x, val)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), 2)
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), arg_count)
|
||||
|
||||
def test_return_captured_var(self):
|
||||
freevar = torch.randn(3)
|
||||
@ -244,7 +242,10 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
# Since, `x` is unused, we don't lift it to
|
||||
# be the input.
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), 2)
|
||||
|
||||
# when testing with dynamic shape, symbols are lifted as input
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
|
||||
|
||||
def test_return_captured_vars(self):
|
||||
freevar1 = torch.randn(3)
|
||||
@ -260,7 +261,9 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
# Since, `x` is unused, we don't lift it to
|
||||
# be the input.
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), 3, 4)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
|
||||
|
||||
def test_return_captured_var_used_multiple_times(self):
|
||||
freevar = torch.randn(3)
|
||||
@ -273,14 +276,18 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
return wrap(test, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), 3, 3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
|
||||
|
||||
def test_capture_untracked_global(self):
|
||||
def f(x):
|
||||
return wrap(lambda x: x + global_var, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), 3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
||||
|
||||
def test_symint_input(self):
|
||||
def f(x):
|
||||
@ -386,13 +393,13 @@ class GraphModule(torch.nn.Module):
|
||||
l_x_ = L_x_
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, s0); wrap_body_0 = l_x_ = s0 = None
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
|
||||
getitem: "f32[s0]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"):
|
||||
view: "f32[s0]" = l_x_.view(size); l_x_ = size = None
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
|
||||
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
|
||||
add: "f32[s0]" = view + 0.5; view = None
|
||||
return (add,)
|
||||
""",
|
||||
@ -418,7 +425,8 @@ class GraphModule(torch.nn.Module):
|
||||
y2 = t[0] + 0.2
|
||||
yield (x2, y2, (x2, y2))
|
||||
|
||||
self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), arg_count)
|
||||
|
||||
def test_wrap_pytree_args_not_const_symint_tensor(self):
|
||||
class MyClass:
|
||||
@ -488,7 +496,9 @@ class GraphModule(torch.nn.Module):
|
||||
def g(x):
|
||||
return wrap(lambda x: x + y, x)
|
||||
|
||||
self._test_wrap_simple(g, default_args_generator((x,)), 3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
|
||||
return g(x)
|
||||
|
||||
f(x, y)
|
||||
@ -500,7 +510,9 @@ class GraphModule(torch.nn.Module):
|
||||
def f(x, y):
|
||||
return wrap(lambda x: x + y, x)
|
||||
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_capture_tracked_nested(self):
|
||||
x = torch.randn(3, 3)
|
||||
@ -509,7 +521,9 @@ class GraphModule(torch.nn.Module):
|
||||
def f(x, y):
|
||||
return wrap(lambda x: wrap(lambda x: x + y, x), x)
|
||||
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_inlined_functions(self):
|
||||
def g(x, y):
|
||||
@ -520,7 +534,9 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
x = torch.randn(3, 3)
|
||||
y = torch.randn(3, 3)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_same_freevar_twice(self):
|
||||
free = torch.randn(3)
|
||||
@ -537,7 +553,518 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
# Since, `x` is unused, we don't lift it to
|
||||
# be the input.
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), 2, 3)
|
||||
# when testing with dynamic shape, a symbol is lifted as input
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), arg_count, 3)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
capture_scalar_outputs=True,
|
||||
)
|
||||
def test_unbacked_symbol_closure(self):
|
||||
def f(x):
|
||||
c = x.sum().item()
|
||||
|
||||
def g(x):
|
||||
def k(x):
|
||||
return x + c
|
||||
|
||||
return wrap(k, x)
|
||||
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
out_graph = self._test_wrap_simple(
|
||||
f, default_args_generator((x,)), arg_count, 4, return_graph=True
|
||||
)
|
||||
|
||||
if check_dynamic_shape_capture():
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
sum_1: "f32[]" = l_x_.sum()
|
||||
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None
|
||||
getitem: "f32[s0]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
|
||||
getitem: "f32[s0]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
|
||||
add: "f32[s0]" = l_x_ + item; l_x_ = item = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
sum_1: "f32[]" = l_x_.sum()
|
||||
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, item); wrap_body_1 = l_x_ = item = None
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, item); wrap_body_0 = l_x_ = item = None
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
|
||||
add: "f32[3]" = l_x_ + item; l_x_ = item = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
)
|
||||
def test_tensor_with_unbacked_shape_closure(self):
|
||||
def f(x):
|
||||
c = x.nonzero()
|
||||
|
||||
def g(x):
|
||||
def k(x):
|
||||
return x.sin(), c.sin()
|
||||
|
||||
return wrap(k, x)
|
||||
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
arg_count = ifdynstaticdefault(4, 5)
|
||||
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
|
||||
expected_op_count = ifdynstaticdefault(10, 8)
|
||||
out_graph = self._test_wrap_simple(
|
||||
f,
|
||||
default_args_generator((x,)),
|
||||
arg_count,
|
||||
expected_op_count,
|
||||
return_graph=True,
|
||||
)
|
||||
|
||||
if check_dynamic_shape_capture():
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
c: "i64[u0, 1]" = l_x_.nonzero()
|
||||
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
||||
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
||||
|
||||
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None
|
||||
getitem: "f32[s0]" = wrap[0]
|
||||
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None
|
||||
child: "f32[s0]" = wrap[0]
|
||||
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (child, child_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
child: "f32[s0]" = l_x_.sin(); l_x_ = None
|
||||
child_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (child, child_1)
|
||||
""",
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
c: "i64[u0, 1]" = l_x_.nonzero()
|
||||
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
||||
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
||||
|
||||
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
le: "Sym(u0 <= 3)" = sym_size_int_1 <= 3
|
||||
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int_1, c); wrap_body_1 = l_x_ = sym_size_int_1 = c = None
|
||||
getitem: "f32[3]" = wrap[0]
|
||||
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
|
||||
child: "f32[3]" = wrap[0]
|
||||
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (child, child_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
child: "f32[3]" = l_x_.sin(); l_x_ = None
|
||||
child_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (child, child_1)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
)
|
||||
def test_tensor_to_list_closure(self):
|
||||
def f(x):
|
||||
li = x.tolist()
|
||||
|
||||
def g(x):
|
||||
def k(x):
|
||||
return li[0] + x
|
||||
|
||||
return wrap(k, x)
|
||||
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.tensor([1, 2, 3], dtype=torch.int16)
|
||||
arg_count = ifdynstaticdefault(3, 3)
|
||||
out_graph = self._test_wrap_simple(f, ((x,),), arg_count, 4, return_graph=True)
|
||||
|
||||
# tolist will specialize on input shapes, so dynamic and static tests
|
||||
# have the same graph
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "i16[3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
getitem = l_x_[0]
|
||||
item: "Sym(u0)" = getitem.item(); getitem = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, item, l_x_); wrap_body_1 = item = l_x_ = None
|
||||
getitem_3: "i16[3]" = wrap[0]; wrap = None
|
||||
return (getitem_3,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, item, l_x_); wrap_body_0 = item = l_x_ = None
|
||||
getitem: "i16[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
|
||||
add: "i16[3]" = item + l_x_; item = l_x_ = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
)
|
||||
def test_tensor_and_unbacked_symbol_closure(self):
|
||||
def f(x):
|
||||
c = x.nonzero()
|
||||
sz = c.size(0)
|
||||
|
||||
def g(x):
|
||||
def k(x):
|
||||
return x.sin() + sz, c.sin()
|
||||
|
||||
return wrap(k, x)
|
||||
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
arg_count = ifdynstaticdefault(4, 5)
|
||||
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
|
||||
expected_op_count = ifdynstaticdefault(10, 8)
|
||||
out_graph = self._test_wrap_simple(
|
||||
f,
|
||||
default_args_generator((x,)),
|
||||
arg_count,
|
||||
expected_op_count,
|
||||
return_graph=True,
|
||||
)
|
||||
|
||||
# Note that u0 is accessed from sz and the shape of c
|
||||
# We cached via the symbol u0 and de-duplicate them.
|
||||
if not check_dynamic_shape_capture():
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
c: "i64[u0, 1]" = l_x_.nonzero()
|
||||
|
||||
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
||||
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
|
||||
|
||||
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
le: "Sym(u0 <= 3)" = sym_size_int <= 3
|
||||
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int, c); wrap_body_1 = l_x_ = sym_size_int = c = None
|
||||
getitem: "f32[3]" = wrap[0]
|
||||
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
|
||||
child: "f32[3]" = wrap[0]
|
||||
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
||||
return (child, child_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
||||
sin: "f32[3]" = l_x_.sin(); l_x_ = None
|
||||
child: "f32[3]" = sin + size; sin = size = None
|
||||
child_1: "f32[u0, 1]" = c.sin(); c = None
|
||||
return (child, child_1)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
)
|
||||
def test_concat_unbacked_shape_tensor(self):
|
||||
def f(x, y):
|
||||
c = x.nonzero()
|
||||
d = y.nonzero()
|
||||
cat = torch.cat((c, d))
|
||||
|
||||
def g(x):
|
||||
def k(x):
|
||||
return cat.sum() + x
|
||||
|
||||
return wrap(k, x)
|
||||
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3)
|
||||
arg_count = ifdynstaticdefault(5, 6)
|
||||
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
|
||||
expected_op_count = ifdynstaticdefault(17, 13)
|
||||
out_graph = self._test_wrap_simple(
|
||||
f,
|
||||
default_args_generator((x, y)),
|
||||
arg_count,
|
||||
expected_op_count,
|
||||
return_graph=True,
|
||||
)
|
||||
|
||||
if not check_dynamic_shape_capture():
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[3]", L_y_: "f32[3]"):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
c: "i64[u0, 1]" = l_x_.nonzero()
|
||||
|
||||
sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
||||
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None
|
||||
|
||||
ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
le: "Sym(u0 <= 3)" = sym_size_int_2 <= 3
|
||||
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
||||
|
||||
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None
|
||||
|
||||
sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
|
||||
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None
|
||||
|
||||
ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
|
||||
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None
|
||||
le_1: "Sym(u1 <= 3)" = sym_size_int_3 <= 3
|
||||
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u1 <= 3 on node 'le_1'"); le_1 = _assert_scalar_default_3 = None
|
||||
|
||||
cat: "i64[u0 + u1, 1]" = torch.cat((c, d)); c = d = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_2, sym_size_int_3, cat, l_x_); wrap_body_1 = sym_size_int_2 = sym_size_int_3 = cat = l_x_ = None
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, u1, cat, l_x_); wrap_body_0 = u0 = u1 = cat = l_x_ = None
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
|
||||
sum_1: "i64[]" = cat.sum(); cat = None
|
||||
add: "f32[3]" = sum_1 + l_x_; sum_1 = l_x_ = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
assume_static_by_default=False,
|
||||
dynamic_shapes=True,
|
||||
)
|
||||
def test_lift_tensors_with_shared_symbols(self):
|
||||
def f(x, y):
|
||||
def g(x):
|
||||
def k(x):
|
||||
return x @ y
|
||||
|
||||
return wrap(k, x)
|
||||
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
y = torch.randn(3, 4)
|
||||
|
||||
out_graph = self._test_wrap_simple(
|
||||
f,
|
||||
default_args_generator((x, y)),
|
||||
6,
|
||||
2,
|
||||
return_graph=True,
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None
|
||||
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None
|
||||
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
|
||||
matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None
|
||||
return (matmul,)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(
|
||||
assume_static_by_default=False,
|
||||
dynamic_shapes=True,
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
)
|
||||
def test_lift_tensors_with_compound_expressions(self):
|
||||
def f(x, y):
|
||||
x = x.view(-1, 2)
|
||||
c = y.nonzero()
|
||||
d = torch.concat((x, c))
|
||||
|
||||
def g(x):
|
||||
def k(x):
|
||||
return d.sum() + x
|
||||
|
||||
return wrap(k, x)
|
||||
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
y = torch.randn(3, 4)
|
||||
|
||||
f(x, y)
|
||||
|
||||
if not check_dynamic_shape_capture():
|
||||
out_graph = self._test_wrap_simple(
|
||||
f,
|
||||
default_args_generator((x, y)),
|
||||
6,
|
||||
9,
|
||||
return_graph=True,
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
out_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = l_x_.view(-1, 2); l_x_ = None
|
||||
|
||||
c: "i64[u0, 2]" = l_y_.nonzero(); l_y_ = None
|
||||
|
||||
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
||||
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
||||
|
||||
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
|
||||
d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = torch.concat((x, c)); c = None
|
||||
|
||||
wrap_body_1 = self.wrap_body_1
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_1, s1, s0, d, x); wrap_body_1 = sym_size_int_1 = s1 = s0 = d = x = None
|
||||
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_1(torch.nn.Module):
|
||||
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, s1, s0, d, x); wrap_body_0 = u0 = s1 = s0 = d = x = None
|
||||
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
|
||||
sum_1: "f32[]" = d.sum(); d = None
|
||||
add: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = sum_1 + x; sum_1 = x = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_register_subclass(self):
|
||||
from torch._higher_order_ops.cond import cond_op
|
||||
@ -1054,7 +1581,8 @@ class GraphModule(torch.nn.Module):
|
||||
return wrap(f, x)
|
||||
|
||||
x = torch.randn(3, 3)
|
||||
self._test_wrap_simple(g, default_args_generator((x,)), 2)
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
|
||||
|
||||
def test_wrap_kwarg(self):
|
||||
def f(x, y):
|
||||
@ -1062,7 +1590,8 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3, 3)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_wrap_kwarg_int(self):
|
||||
def f(x, y):
|
||||
@ -1071,9 +1600,12 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(3)
|
||||
y = 8
|
||||
|
||||
self._test_wrap_simple(
|
||||
f, default_args_generator((x, y)), ifdynstaticdefault(2, 3)
|
||||
arg_count = (
|
||||
ifdynstaticdefault(2, 3) + 1
|
||||
if check_dynamic_shape_capture()
|
||||
else ifdynstaticdefault(2, 3)
|
||||
)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_wrap_all_kwarg(self):
|
||||
def f(y, x):
|
||||
@ -1082,7 +1614,8 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3, 3)
|
||||
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_wrap_kwarg_only(self):
|
||||
def f(x, y):
|
||||
@ -1094,7 +1627,8 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3, 3)
|
||||
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_wrap_kwarg_default(self):
|
||||
def f(x, y):
|
||||
@ -1106,7 +1640,8 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3, 3)
|
||||
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_wrap_kwarg_default_if_branch(self):
|
||||
def f(x, y):
|
||||
@ -1121,7 +1656,8 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3, 3)
|
||||
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_wrap_kwarg_recompile(self):
|
||||
def f(x, y, z=None):
|
||||
@ -1162,7 +1698,8 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3, 3)
|
||||
|
||||
self._test_wrap_simple(f, default_args_generator((x, y, 8)), 2)
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(f, default_args_generator((x, y, 8)), arg_count)
|
||||
|
||||
def test_map_subgraph_name_is_valid(self):
|
||||
backend = EagerAndRecordGraphs()
|
||||
@ -1522,8 +2059,8 @@ def forward(self, child : torch.Tensor, const_unused : int):
|
||||
and node.target == torch.ops.higher_order.cond
|
||||
):
|
||||
_, _, _, operands = node.args
|
||||
# Each branch takes 3 inputs (buffer, x, z)
|
||||
self.assertEqual(len(operands), 3)
|
||||
# Since we compile wit dynamic, each branch takes 4 inputs (buffer, x, z, s1)
|
||||
self.assertEqual(len(operands), 4)
|
||||
if node.op == "get_attr":
|
||||
if str(node.target) in ("cond_true_0, cond_false_0"):
|
||||
num_placeholders = len(
|
||||
@ -1535,7 +2072,7 @@ def forward(self, child : torch.Tensor, const_unused : int):
|
||||
if node.op == "placeholder"
|
||||
]
|
||||
)
|
||||
self.assertEqual(num_placeholders, 3)
|
||||
self.assertEqual(num_placeholders, 4)
|
||||
|
||||
def _check_cond_graph_and_extract(self, fn, args):
|
||||
backend = EagerAndRecordGraphs()
|
||||
@ -1826,10 +2363,11 @@ def forward(self):
|
||||
yield [x], [x.sin()]
|
||||
yield (x,), (x.sin(),)
|
||||
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
actual_graph = self._test_wrap_simple(
|
||||
f,
|
||||
my_args_generator(),
|
||||
3,
|
||||
arg_count,
|
||||
3,
|
||||
return_graph=True,
|
||||
)
|
||||
@ -2000,7 +2538,10 @@ class GraphModule(torch.nn.Module):
|
||||
return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), 2, expected_opcount=3)
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(
|
||||
f, default_args_generator((x,)), arg_count, expected_opcount=3
|
||||
)
|
||||
|
||||
def test_fallback_on_python_primitives_output(self):
|
||||
counters.clear()
|
||||
@ -2028,8 +2569,9 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(2, 3)
|
||||
|
||||
counters.clear()
|
||||
arg_count = ifdynstaticdefault(2, 4)
|
||||
graph = self._test_wrap_simple(
|
||||
f, default_args_generator((x,)), 2, 4, return_graph=True
|
||||
f, default_args_generator((x,)), arg_count, 4, return_graph=True
|
||||
)
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
|
||||
@ -2066,8 +2608,10 @@ class GraphModule(torch.nn.Module):
|
||||
x = torch.randn(3)
|
||||
|
||||
counters.clear()
|
||||
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
graph = self._test_wrap_simple(
|
||||
f, default_args_generator((x,)), 2, 2, return_graph=True
|
||||
f, default_args_generator((x,)), arg_count, 2, return_graph=True
|
||||
)
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
|
||||
@ -2137,7 +2681,8 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
x = torch.randn(3, 3)
|
||||
y = torch.randn(3, 3)
|
||||
self._test_wrap_simple(h, default_args_generator((x, y)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_internal_nonlocal(self):
|
||||
def f(x, y):
|
||||
@ -2162,7 +2707,8 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
x = torch.randn(3, 3)
|
||||
y = torch.randn(3, 3)
|
||||
self._test_wrap_simple(h, default_args_generator((x, y)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
|
||||
|
||||
def test_capture_numpy_number(self):
|
||||
import numpy as np
|
||||
@ -2174,7 +2720,8 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
x = torch.randn(3)
|
||||
# np.number are lifted to graph inputs
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
||||
|
||||
def test_freevars_as_inputs_to_wrap(self):
|
||||
y = torch.randn(3)
|
||||
@ -2183,7 +2730,8 @@ class GraphModule(torch.nn.Module):
|
||||
return wrap(lambda x, y: x + y, x, y)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), 3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
||||
|
||||
def test_lift_tensor_constant(self):
|
||||
def f(x):
|
||||
@ -2191,7 +2739,10 @@ class GraphModule(torch.nn.Module):
|
||||
return wrap(lambda x: x + y, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(f, default_args_generator((x,)), 3, expected_opcount=3)
|
||||
arg_count = ifdynstaticdefault(3, 4)
|
||||
self._test_wrap_simple(
|
||||
f, default_args_generator((x,)), arg_count, expected_opcount=3
|
||||
)
|
||||
|
||||
def test_nested_wrap(self):
|
||||
class MockModule(torch.nn.Module):
|
||||
@ -2211,14 +2762,18 @@ class GraphModule(torch.nn.Module):
|
||||
def fn(x):
|
||||
return wrap(gn, x)
|
||||
|
||||
self._test_wrap_simple(fn, default_args_generator((torch.randn(10, 10),)), 4)
|
||||
arg_count = ifdynstaticdefault(4, 5)
|
||||
self._test_wrap_simple(
|
||||
fn, default_args_generator((torch.randn(10, 10),)), arg_count
|
||||
)
|
||||
|
||||
def test_fn_with_kwargs_in_torch_ops(self):
|
||||
def fn(x):
|
||||
return wrap(lambda z: torch.cos(input=z), x)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), 2)
|
||||
arg_count = ifdynstaticdefault(2, 3)
|
||||
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
|
||||
|
||||
def test_hooks(self):
|
||||
class ToyModel(torch.nn.Module):
|
||||
|
@ -1941,10 +1941,10 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
|
||||
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_1, primals_3); primals_1 = None
|
||||
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_2, primals_3); primals_2 = None
|
||||
return (mul, mul_3, primals_6, primals_7, primals_7, primals_3, primals_6, primals_7)
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
mul: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
||||
mul_3: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
||||
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -1952,10 +1952,10 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_3: "Sym(s0)", primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
|
||||
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_3); tangents_1 = None
|
||||
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_3); tangents_2 = primals_3 = None
|
||||
return (mul_8, mul_9, primals_6, primals_7, primals_7, None, None)
|
||||
def forward(self, primals_1: "Sym(s0)", primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
|
||||
mul_8: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
|
||||
mul_9: "f32[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
|
||||
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -1974,13 +1974,13 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_4, primals_3]); clone = None
|
||||
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_4, primals_3]); clone_1 = primals_3 = None
|
||||
return (view, view_1, primals_4, primals_6, primals_6, primals_6, primals_7)
|
||||
view: "f32[s1, s0]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
|
||||
view_1: "f32[s1, s0]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
|
||||
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -1988,10 +1988,10 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
|
||||
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s1, s0]", tangents_2: "f32[s1, s0]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2057,13 +2057,13 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_3, primals_4]); clone = None
|
||||
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_3, primals_4]); clone_1 = primals_3 = primals_4 = None
|
||||
return (view, view_1, primals_6, primals_7, primals_7, primals_6, primals_7)
|
||||
view: "f32[s0, s1]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
|
||||
view_1: "f32[s0, s1]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
|
||||
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2071,10 +2071,10 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
|
||||
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0, s1]", tangents_2: "f32[s0, s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2093,14 +2093,14 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul_6: "Sym(s0*s1)" = primals_3 * primals_4; primals_3 = primals_4 = None
|
||||
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
|
||||
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
return (view, view_1, mul_6, primals_6, primals_7)
|
||||
return (view, view_1, mul_6, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2108,10 +2108,10 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
|
||||
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2130,14 +2130,14 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[s0, s1]", primals_2: "f32[s0, s1]", primals_3: "Sym(s0)", primals_4: "Sym(s1)", primals_5: "Sym(s1)", primals_6: "Sym(s0)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1]", primals_4: "f32[s0, s1]", primals_5: "Sym(s0)", primals_6: "Sym(s1)", primals_7: "Sym(s1)"):
|
||||
clone: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul_6: "Sym(s0*s1)" = primals_3 * primals_4; primals_3 = primals_4 = None
|
||||
mul_6: "Sym(s0*s1)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s0*s1]" = torch.ops.aten.view.default(clone, [mul_6])
|
||||
view_1: "f32[s0*s1]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
return (clone, view, view_1, mul_6, primals_6, primals_7)
|
||||
return (clone, view, view_1, mul_6, primals_5, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2145,10 +2145,10 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_6: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_6, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_6, primals_7]); tangents_2 = None
|
||||
return (view_2, view_3, primals_6, primals_7, primals_7, None, None)
|
||||
def forward(self, primals_5: "Sym(s0)", primals_7: "Sym(s1)", tangents_1: "f32[s0*s1]", tangents_2: "f32[s0*s1]"):
|
||||
view_2: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s0, s1]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2226,9 +2226,9 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[1].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[3, s0]", primals_2: "f32[3, s0]", primals_3: "Sym(s0)", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
|
||||
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
|
||||
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
|
||||
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
|
||||
@ -2255,7 +2255,7 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
|
||||
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
return (view_2, view_3, primals_5, primals_5, None)
|
||||
return (None, view_2, view_3, primals_5, primals_5)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2282,9 +2282,9 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[3, s0]", primals_2: "f32[3, s0]", primals_3: "Sym(s0)", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
|
||||
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
def forward(self, primals_1: "Sym(s0)", primals_2: "f32[3, s0]", primals_3: "f32[3, s0]", primals_4: "Sym(s0)", primals_5: "Sym(s0)"):
|
||||
clone: "f32[3, s0]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s0]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
|
||||
view: "f32[3*s0]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_numel_default: "Sym(3*s0)" = torch.ops.aten.sym_numel.default(clone)
|
||||
@ -2300,7 +2300,7 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s0)", tangents_1: "f32[3*s0]", tangents_2: "f32[3*s0]"):
|
||||
view_2: "f32[3, s0]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s0]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
return (view_2, view_3, primals_5, primals_5, None)
|
||||
return (None, view_2, view_3, primals_5, primals_5)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2466,11 +2466,11 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f64[s0, s1]", primals_2: "i64[s2 + 1]", primals_3: "f32[s6, 0]", primals_4: "f32[s7, 0]", primals_5: "Sym(s2)", primals_6: "Sym(s1)", primals_7: "Sym(s1)", primals_8: "Sym(s1)", primals_9: "Sym(s2)", primals_10: "Sym(s3)"):
|
||||
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
|
||||
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_9); clone = None
|
||||
return (mul, primals_2, primals_3, primals_4, primals_9, primals_8, primals_8, primals_8, primals_9)
|
||||
mul: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
|
||||
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2478,9 +2478,9 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_8: "Sym(s1)", primals_9: "Sym(s2)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
|
||||
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_9); tangents_1 = None
|
||||
return (mul_1, tangents_2, tangents_3, tangents_4, primals_9, primals_8, primals_8, None, None, None)
|
||||
def forward(self, primals_1: "Sym(s2)", primals_8: "Sym(s2)", primals_10: "Sym(s1)", tangents_1: "f64[s0, s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
|
||||
mul_1: "f64[s0, s1]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
|
||||
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2499,12 +2499,12 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f64[s0, s1]", primals_2: "i64[s2 + 1]", primals_3: "f32[s6, 0]", primals_4: "f32[s7, 0]", primals_5: "Sym(s2)", primals_6: "Sym(s1)", primals_7: "Sym(s1)", primals_8: "Sym(s1)", primals_9: "Sym(s2)", primals_10: "Sym(s3)"):
|
||||
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
def forward(self, primals_1: "Sym(s2)", primals_2: "Sym(s3)", primals_3: "Sym(s1)", primals_4: "f64[s0, s1]", primals_5: "i64[s2 + 1]", primals_6: "f32[s6, 0]", primals_7: "f32[s7, 0]", primals_8: "Sym(s2)", primals_9: "Sym(s1)", primals_10: "Sym(s1)"):
|
||||
clone: "f64[s0, s1]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
cat: "f64[s0, 2*s1]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
|
||||
add_2: "Sym(2*s1)" = primals_8 + primals_8
|
||||
return (cat, primals_2, primals_3, primals_4, primals_9, add_2, add_2, primals_8, primals_9, add_2)
|
||||
add_2: "Sym(2*s1)" = primals_10 + primals_10
|
||||
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2512,12 +2512,12 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_8: "Sym(s1)", primals_9: "Sym(s2)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
|
||||
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_8)
|
||||
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_8, add_2); tangents_1 = add_2 = None
|
||||
def forward(self, primals_8: "Sym(s2)", primals_10: "Sym(s1)", add_2: "Sym(2*s1)", tangents_1: "f64[s0, 2*s1]", tangents_2: "i64[s2 + 1]", tangents_3: "f32[s6, 0]", tangents_4: "f32[s7, 0]"):
|
||||
slice_1: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
|
||||
slice_2: "f64[s0, s1]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
|
||||
|
||||
add_4: "f64[s0, s1]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
|
||||
return (add_4, tangents_2, tangents_3, tangents_4, primals_9, primals_8, primals_8, None, None, None)
|
||||
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -2545,7 +2545,7 @@ class GraphModule(torch.nn.Module):
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f64[9, s2]", arg1_1: "i64[s3 + 1]", arg2_1: "f32[s7, 0]", arg3_1: "f32[s8, 0]", arg4_1: "Sym(s3)", arg5_1: "Sym(s2)", arg6_1: "Sym(s2)", arg7_1: "Sym(s2)", arg8_1: "Sym(s3)", arg9_1: "Sym(s4)"):
|
||||
def forward(self, arg0_1: "Sym(s3)", arg1_1: "Sym(s4)", arg2_1: "Sym(s2)", arg3_1: "f64[9, s2]", arg4_1: "i64[s3 + 1]", arg5_1: "f32[s7, 0]", arg6_1: "f32[s8, 0]", arg7_1: "Sym(s3)", arg8_1: "Sym(s2)", arg9_1: "Sym(s2)"):
|
||||
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
@ -2559,7 +2559,7 @@ class <lambda>(torch.nn.Module):
|
||||
zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False)
|
||||
zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False)
|
||||
|
||||
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg0_1], 1); cat = arg0_1 = None
|
||||
cat_2: "f64[9, s2 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
|
||||
|
||||
sin: "f64[9, s2 + 5]" = torch.ops.aten.sin.default(cat_2)
|
||||
mul: "f64[9, s2 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
|
||||
@ -2722,7 +2722,7 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
|
||||
norm_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_nt_: "f64[3, s1, 5]", s1: "Sym(s1)"):
|
||||
def forward(self, s1: "Sym(s1)", L_nt_: "f64[3, s1, 5]"):
|
||||
l_nt_ = L_nt_
|
||||
|
||||
add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None
|
||||
|
@ -368,15 +368,17 @@ class TestControlFlow(TestCase):
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, sym_size_int)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, sym_size_int)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = sym_size_int = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -409,15 +411,17 @@ def forward(self, pred_1, x_1):
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, sym_size_int)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, sym_size_int)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = sym_size_int = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -518,16 +522,20 @@ def forward(self, pred_1, x_1):
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1, y_1, z_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1, sym_size_int, sym_size_int_1)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1, sym_size_int, sym_size_int_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = sym_size_int = sym_size_int_1 = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
|
||||
getitem_2 = cond_1[1]; getitem_2 = None
|
||||
getitem_3 = cond_1[2]; getitem_3 = None
|
||||
getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -572,12 +580,13 @@ def forward(self, pred_1, x_1, y_1, z_1):
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
_param_constant0 = self._param_constant0
|
||||
_param_constant1 = self._param_constant1
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, sym_size_int, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
@ -585,11 +594,12 @@ def forward(self, pred_1, x_1):
|
||||
_param_constant0_1 = self._param_constant0
|
||||
_param_constant1_1 = self._param_constant1
|
||||
_tensor_constant0_1 = self._tensor_constant0
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = None
|
||||
getitem_1 = cond_1[0]; getitem_1 = None
|
||||
getitem_2 = cond_1[1]
|
||||
getitem_3 = cond_1[2]; getitem_3 = None
|
||||
getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
|
||||
getitem_4 = cond_1[3]; getitem_4 = None
|
||||
getitem_5 = cond_1[4]; cond_1 = getitem_5 = None
|
||||
return (getitem_2,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -692,24 +702,30 @@ def forward(self, x_1):
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, a_1, b_1, c_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(c_1, 0)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, c_1)); true_graph_0 = false_graph_0 = None
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]
|
||||
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
|
||||
getitem_3 = cond_1[2]; getitem_3 = None
|
||||
getitem_4 = cond_1[3]; getitem_4 = None
|
||||
getitem_5 = cond_1[4]; getitem_5 = None
|
||||
getitem_6 = cond_1[5]; cond_1 = getitem_6 = None
|
||||
return (getitem_1, getitem_2)""", # noqa: B950
|
||||
)
|
||||
# Forward
|
||||
self.assertExpectedInline(
|
||||
gm.true_graph_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
return (add,)""",
|
||||
)
|
||||
@ -717,11 +733,11 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
self.assertExpectedInline(
|
||||
gm.true_graph_1.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
|
||||
add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = add = None
|
||||
clone = torch.ops.aten.clone.default(arg0_1)
|
||||
clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
||||
return [clone, clone_1, None]""",
|
||||
return [clone, clone_1, None, None, None, None]""",
|
||||
)
|
||||
|
||||
def test_cond_autograd_pytree_input(self):
|
||||
@ -1050,15 +1066,17 @@ def forward(self, pred_1, x_1):
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, sym_size_int)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, sym_size_int)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = sym_size_int = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -1623,7 +1641,7 @@ def forward(self, pred_1, x_1):
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
# TODO: Does not work because of the usage of vmap witin associative_scan
|
||||
# The parameterization is commented out for the moment and the test is marked with expected fail
|
||||
# The paT206899919 rameterization is commented out for the moment and the test is marked with expected fail
|
||||
# Fails with: AssertionError: scan is not an OpOverload
|
||||
@skipIfRocm(msg="Unsupported on ROCM yet")
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -2370,10 +2388,14 @@ def forward(self, fct_1, init_1, xs_1):
|
||||
select = torch.ops.aten.select.int(xs_1, 0, 0)
|
||||
add = torch.ops.aten.add.Tensor(init_1, select); add = None
|
||||
add_1 = torch.ops.aten.add.Tensor(init_1, select); select = add_1 = None
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
|
||||
clone = torch.ops.aten.clone.default(init_1); clone = None
|
||||
select_copy = torch.ops.aten.select_copy.int(xs_1, 0, 0); select_copy = None
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
|
||||
sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2)
|
||||
scan_combine_graph_0 = self.scan_combine_graph_0
|
||||
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True, []); scan_combine_graph_0 = init_1 = xs_1 = None
|
||||
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True, [sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4]); scan_combine_graph_0 = init_1 = xs_1 = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
|
||||
getitem = scan[0]
|
||||
getitem_1 = scan[1]; scan = None
|
||||
return (getitem, getitem_1)""", # noqa: B950
|
||||
@ -2612,6 +2634,8 @@ class AssociativeScanTests(TestCase):
|
||||
|
||||
num_dims = [random.randint(2, 5) for _ in range(4)]
|
||||
for num_dim in num_dims:
|
||||
# To avoid triggering automatic dynamic shape
|
||||
torch._dynamo.reset()
|
||||
shapes = [random.randint(1, 9) for _ in range(num_dim)]
|
||||
rnd_scan_dim = random.randint(0, num_dim - 1)
|
||||
x = torch.randn(*shapes, device=device)
|
||||
@ -3692,9 +3716,13 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
|
||||
gm.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(arg2_1, 0)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(arg3_1, 0)
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 1)
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
|
||||
getitem = while_loop[0]
|
||||
getitem_1 = while_loop[1]
|
||||
getitem_2 = while_loop[2]
|
||||
@ -3705,10 +3733,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
self.assertExpectedInline(
|
||||
outer_body.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (arg7_1, arg7_1, arg7_1, arg7_1)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = arg7_1 = None
|
||||
getitem = while_loop[0]
|
||||
getitem_1 = while_loop[1]
|
||||
getitem_2 = while_loop[2]
|
||||
@ -3723,10 +3751,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
self.assertExpectedInline(
|
||||
outer_body.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (arg7_1, arg7_1, arg7_1, arg7_1)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = arg7_1 = None
|
||||
getitem = while_loop[0]
|
||||
getitem_1 = while_loop[1]
|
||||
getitem_2 = while_loop[2]
|
||||
@ -3741,7 +3769,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
self.assertExpectedInline(
|
||||
inner_body.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
||||
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
||||
sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
|
||||
add = torch.ops.aten.add.Tensor(arg2_1, 3.14); arg2_1 = None
|
||||
@ -3752,7 +3780,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
self.assertExpectedInline(
|
||||
inner_cond.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
||||
gt = torch.ops.aten.gt.Scalar(arg1_1, 0); arg1_1 = None
|
||||
return gt
|
||||
""",
|
||||
@ -3854,23 +3882,27 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
def forward(self, a_1, b_1):
|
||||
sum_1 = torch.ops.aten.sum.default(a_1)
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
|
||||
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(a_1, 1)
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 0)
|
||||
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
gm.true_graph_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
return (add,)""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
gm.false_graph_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
||||
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
return (mul,)""",
|
||||
)
|
||||
@ -3916,6 +3948,8 @@ def forward(self, arg0_1, arg1_1):
|
||||
if isinstance(val, tuple):
|
||||
for v in val:
|
||||
yield v.fake_mode.shape_env
|
||||
elif isinstance(val, torch.SymInt):
|
||||
yield val.node.shape_env
|
||||
else:
|
||||
yield val.fake_mode.shape_env
|
||||
|
||||
@ -4983,10 +5017,11 @@ def forward(self, arg0_1):
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
eq = sym_size_int == 4; sym_size_int = None
|
||||
eq = sym_size_int == 4
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
@ -5015,11 +5050,12 @@ def forward(self, x_1):
|
||||
nonzero = torch.ops.aten.nonzero.default(x_1)
|
||||
sym_size_int = torch.ops.aten.sym_size.int(nonzero, 0); nonzero = None
|
||||
gt = sym_size_int > 3; sym_size_int = None
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1]); gt = true_graph_0 = false_graph_0 = x_1 = None
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1, sym_size_int_1]); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""",
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
|
||||
def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
|
||||
@ -5092,19 +5128,20 @@ def forward(self, x_1):
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
eq = sym_size_int == 4; sym_size_int = None
|
||||
eq = sym_size_int == 4
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
_tensor_constant1 = self._tensor_constant1
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
gm.true_graph_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
return (add,)""",
|
||||
)
|
||||
@ -5330,10 +5367,11 @@ def forward(self, arg0_1, arg1_1):
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
eq = sym_size_int == 4; sym_size_int = None
|
||||
eq = sym_size_int == 4
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
@ -5341,7 +5379,7 @@ def forward(self, x_1):
|
||||
self.assertExpectedInline(
|
||||
gm.true_graph_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
cos = torch.ops.aten.cos.default(arg0_1)
|
||||
sub = torch.ops.aten.sub.Tensor(arg0_1, cos); arg0_1 = cos = None
|
||||
return (sub,)""",
|
||||
@ -5350,7 +5388,7 @@ def forward(self, arg0_1):
|
||||
self.assertExpectedInline(
|
||||
gm.false_graph_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
sin = torch.ops.aten.sin.default(arg0_1)
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, sin); arg0_1 = sin = None
|
||||
return (add,)""",
|
||||
@ -5689,7 +5727,7 @@ def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L
|
||||
tensor = torch.tensor([True])
|
||||
cond_true_0 = self.cond_true_0
|
||||
cond_false_0 = self.cond_false_0
|
||||
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, [l_a_, l_b_, l_self_num]); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = None
|
||||
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, [l_a_, l_b_, l_self_num, s0]); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
return (getitem,)""", # noqa: B950
|
||||
)
|
||||
|
@ -774,6 +774,16 @@ def skip_if_halide(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
def skip_if_dynamic(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self):
|
||||
if ifdynstaticdefault(True, False) or torch._dynamo.config.dynamic_shapes:
|
||||
raise unittest.SkipTest("associtaive_scan doesn's support lifted SymInts.")
|
||||
return fn(self)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_halide_backend(device):
|
||||
if getattr(device, "type", device) == "cpu":
|
||||
return config.cpu_backend == "halide"
|
||||
@ -2038,6 +2048,7 @@ class CommonTemplate:
|
||||
|
||||
@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
|
||||
@skip_if_halide # scan ops
|
||||
@skip_if_dynamic # TODO: support lifted symints when dynamic
|
||||
def test_custom_scan_op(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("associative_scan only supported on GPU")
|
||||
@ -2063,6 +2074,7 @@ class CommonTemplate:
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
@skip_if_halide # scan ops
|
||||
@skip_if_dynamic # TODO: support lifted symints when dynamic
|
||||
def test_custom_scan_op_compiled(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("associative_scan only supported on GPU")
|
||||
@ -2090,6 +2102,7 @@ class CommonTemplate:
|
||||
|
||||
@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
|
||||
@skip_if_halide # scan ops
|
||||
@skip_if_dynamic # TODO: support lifted symints when dynamic
|
||||
def test_custom_scan_op_multi_input(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("associative_scan only supported on GPU")
|
||||
@ -2114,6 +2127,7 @@ class CommonTemplate:
|
||||
|
||||
@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
|
||||
@skip_if_halide # scan ops
|
||||
@skip_if_dynamic # TODO: support lifted symints when dynamic
|
||||
def test_custom_scan_would_split(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("associative_scan only supported on GPU")
|
||||
|
@ -255,7 +255,7 @@ class OutputGraph:
|
||||
torch_function_mode_stack,
|
||||
):
|
||||
super().__init__()
|
||||
self.tracers = [SubgraphTracer(self, export_root=export)]
|
||||
self.tracers = [SubgraphTracer(self, is_export=export)]
|
||||
# Map from graph input's `Source` to its `VariableTracker` to
|
||||
# de-duplicate graph inputs by source and reuse the tracker
|
||||
self.input_source_to_var: Dict[Source, VariableTracker] = {}
|
||||
@ -577,7 +577,10 @@ class OutputGraph:
|
||||
prior_tracer
|
||||
if prior_tracer
|
||||
else SubgraphTracer(
|
||||
self, parent=self.current_tracer, source_target=source_target
|
||||
self,
|
||||
parent=self.current_tracer,
|
||||
source_target=source_target,
|
||||
is_export=self.current_tracer.is_export,
|
||||
)
|
||||
)
|
||||
self.tracers.append(tracer)
|
||||
@ -657,71 +660,6 @@ class OutputGraph:
|
||||
def current_tx(self):
|
||||
return self.root_tx if not self._current_tx else self._current_tx[-1]
|
||||
|
||||
def add_symbol_bindings(self, arg: GraphArg):
|
||||
# Insert implicit size vars as necessary. With dynamic shapes, we
|
||||
# maintain the invariant that every sizevar gets a direct SymInt input
|
||||
# into the graph. This means downstream graph transforms can assume
|
||||
# every size variable is explicitly bound and accessible, instead of
|
||||
# having to pull it out implicitly from tensors.
|
||||
|
||||
if self.export:
|
||||
return
|
||||
|
||||
assert arg.fake_tensor is not None
|
||||
|
||||
def bind_symint(s: torch.SymInt, prop):
|
||||
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
|
||||
return
|
||||
s0 = s.node.expr
|
||||
if s0 in self.bound_symbols:
|
||||
return
|
||||
log.debug("bind_symint %s %s", s, prop.name())
|
||||
# TODO: don't readd symint if we already have it in graph
|
||||
# (this is harmless because we do remove the unused ones later)
|
||||
proxy = self.root_tracer.create_graph_input(
|
||||
str(s0),
|
||||
type(s),
|
||||
s,
|
||||
before=True,
|
||||
source=prop,
|
||||
)
|
||||
self.root_tracer.bound_symbols[s0] = proxy
|
||||
assert isinstance(s, torch.SymInt)
|
||||
proxy.node.meta["grapharg"] = GraphArg(
|
||||
prop,
|
||||
s,
|
||||
pass_arg_as_tensor=False,
|
||||
fake_tensor=None,
|
||||
is_tensor=False,
|
||||
)
|
||||
|
||||
def handle_tensor(t, src):
|
||||
for i, s in enumerate(t.size()):
|
||||
bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i))
|
||||
if t.layout is torch.strided:
|
||||
for i, s in enumerate(t.stride()):
|
||||
bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i))
|
||||
bind_symint(
|
||||
t.storage_offset(),
|
||||
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
|
||||
)
|
||||
elif t.layout is torch.sparse_coo:
|
||||
handle_tensor(t._indices(), src)
|
||||
handle_tensor(t._values(), src)
|
||||
elif t.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
handle_tensor(t.crow_indices(), src)
|
||||
handle_tensor(t.col_indices(), src)
|
||||
elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
|
||||
handle_tensor(t.ccol_indices(), src)
|
||||
handle_tensor(t.row_indices(), src)
|
||||
if is_traceable_wrapper_subclass(t):
|
||||
attrs, ctx = t.__tensor_flatten__()
|
||||
for attr in attrs:
|
||||
inner_t = getattr(t, attr)
|
||||
handle_tensor(inner_t, AttrSource(src, attr))
|
||||
|
||||
handle_tensor(arg.fake_tensor, arg.source)
|
||||
|
||||
def count_calls(self):
|
||||
return count_calls(self.graph)
|
||||
|
||||
@ -1834,6 +1772,17 @@ def check_pt2_compliant_op(output_graph, kind, target, args, kwargs):
|
||||
_compile_id_counter = itertools.count()
|
||||
|
||||
|
||||
class LazyProxy:
|
||||
def __init__(self, tracer, fn, *args, **kwargs):
|
||||
self.tracer = tracer
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self):
|
||||
return self.fn(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
class SubgraphTracer(fx.Tracer):
|
||||
"""
|
||||
Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
|
||||
@ -1842,19 +1791,13 @@ class SubgraphTracer(fx.Tracer):
|
||||
compiling and executing the graph.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, output_graph, parent=None, export_root=False, source_target=None
|
||||
):
|
||||
def __init__(self, output_graph, parent=None, is_export=False, source_target=None):
|
||||
super().__init__()
|
||||
self.output_graph = weakref.proxy(output_graph)
|
||||
self.graph = torch.fx.Graph()
|
||||
|
||||
# The export is only ever set for the ROOT tracer. It controls
|
||||
# whether or not certain inputs are allowed to be added or not.
|
||||
# Look at call sites of create_graph_input to see how it is used.
|
||||
if export_root:
|
||||
assert parent is None
|
||||
self.export_root = export_root
|
||||
# See note [Export inputs must be explicitly passed in]
|
||||
self.is_export = is_export
|
||||
# Map from graph input name to its placeholder proxy object, where the
|
||||
# map's keys give all current placeholder node names and can be used to
|
||||
# create unique node names
|
||||
@ -1879,8 +1822,11 @@ class SubgraphTracer(fx.Tracer):
|
||||
# Dicts maintain the order of args for the HigherOrderOperator call.
|
||||
self.lifted_freevars = {}
|
||||
|
||||
# map symbols to their bound proxy placeholders.
|
||||
self.bound_symbols: Dict[sympy.Symbol, torch.fx.Proxy] = {}
|
||||
# map basic symbols (unbacked and unbacked) to their bound proxies.
|
||||
# There are only two cases where bound_symbols will be recorded:
|
||||
# 1. when we create_graph_input for a backed SymInt that's basic symbol
|
||||
# 2. when we track_unbacked_symbols for intermediate results that contain unbacked symints.
|
||||
self.bound_symbols: Dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
|
||||
|
||||
self.prev_inst = None
|
||||
# True if this tracer is currently tracing into torch.utils.checkpoint
|
||||
@ -1892,6 +1838,8 @@ class SubgraphTracer(fx.Tracer):
|
||||
# backward recomputation of the checkpoint region doesn't affect its correctness.
|
||||
self.allow_side_effects_under_checkpoint = False
|
||||
|
||||
self.debug_level: int = parent.debug_level + 1 if parent is not None else 0
|
||||
|
||||
self._cur_code = None
|
||||
self._orig_gm_meta = None
|
||||
self._orig_gm_lineno_map = None
|
||||
@ -2139,15 +2087,19 @@ class SubgraphTracer(fx.Tracer):
|
||||
self, name, type_expr, example_value, before=False, source=None
|
||||
):
|
||||
log.debug(
|
||||
"create_graph_input %s %s",
|
||||
"create_graph_input %s %s %s at debug_level %s before=%s",
|
||||
name,
|
||||
source.name() if source is not None else "(none)",
|
||||
example_value,
|
||||
self.debug_level,
|
||||
before,
|
||||
)
|
||||
if source is None:
|
||||
assert (
|
||||
self.parent is not None
|
||||
), "you are required to provide a source for inputs on the root tracer"
|
||||
), f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer"
|
||||
|
||||
# Note [Export inputs must be explicitly passed in]
|
||||
# In eager, we are generally OK with adding graph inputs whenever we
|
||||
# want, because we take care of writing the bytecode that knows how
|
||||
# to source all the inputs.
|
||||
@ -2156,7 +2108,7 @@ class SubgraphTracer(fx.Tracer):
|
||||
# object which only depends on the inputs you explicitly passed to it.
|
||||
# So we are a bit more strict about what sources can become inputs
|
||||
# in export
|
||||
if self.export_root:
|
||||
if self.is_export and self.parent is None:
|
||||
if not is_from_local_source(source, allow_cell_or_freevar=False):
|
||||
self.output_graph.source_to_user_stacks.setdefault(source, []).append(
|
||||
TracingContext.extract_stack()
|
||||
@ -2188,6 +2140,44 @@ class SubgraphTracer(fx.Tracer):
|
||||
self.input_name_to_proxy[k] = v
|
||||
else:
|
||||
self.input_name_to_proxy[name] = proxy
|
||||
|
||||
# NOTE: [Auto lift basic free symbols when create_graph_input]
|
||||
# Whenever we call create_graph_input, we try to also lift the basic symbols in example values
|
||||
# as graph input.
|
||||
# This applies to both top-level graph and subgraphs in higher order ops.
|
||||
# It has several cases:
|
||||
# 1. When create_graph_input for a tensor that has symbolic shapes,
|
||||
# we look for basic symbols in its size and stride, we check if the symbol is bound
|
||||
# in current graph (i.e. bound_symbols), it it's not bound, we'll create a placeholder
|
||||
# for it then recursively check its parent, creates ph if not bound.
|
||||
# Every tracer maintains a mapping (i.e. lifted_freevars)
|
||||
# that maps from parent proxy to proxy in current tracer for the symbol.
|
||||
# 2. When create_graph_input for a tensor with unbacked symbolic shapes,
|
||||
# Backed symbols all come from inputs's symbolic shape. But unbacked symbols
|
||||
# can be created while tracing. So we use track_unbacked_symbols will intercept
|
||||
# at wrap_fx_proxy, and try to bind the unbacked symbols immediately after they're
|
||||
# created.
|
||||
# 3. subgraph will also lifted basic symbols in compound exprs of tensor shape.
|
||||
# For example, if an input to subgraph takes size [s1+s2//8], we'll look for the
|
||||
# the free symbols in the sizes and lift as inputs similar to 1 in _lift_symbols_in_symint)
|
||||
# 4. When create_graph_input for a SymInt, if the symint is a basic symbol, we'll track it
|
||||
# in bound_symbols so that we don't lift the same basic symbol twice. When the symint is a
|
||||
# compound expr, we'll just create the proxy for the compouned expr but not lift its basic symbols.
|
||||
# Also see NOTE: [Export inputs must be explicitly passed in]
|
||||
is_strict_export = self.is_export
|
||||
is_non_strict_export = torch.compiler.is_compiling()
|
||||
if (
|
||||
not is_strict_export
|
||||
and not is_non_strict_export
|
||||
and isinstance(example_value, torch.Tensor)
|
||||
):
|
||||
self._lift_basic_symbols(example_value, source)
|
||||
|
||||
# Bound the symbol to ph if example_value is a SymInt with basic symbol.
|
||||
if isinstance(example_value, torch.SymInt) and isinstance(
|
||||
example_value.node.expr, sympy.Symbol
|
||||
):
|
||||
self.bound_symbols[example_value.node.expr] = proxy
|
||||
return proxy
|
||||
|
||||
# See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
|
||||
@ -2197,6 +2187,21 @@ class SubgraphTracer(fx.Tracer):
|
||||
assert (
|
||||
self.parent is not None
|
||||
), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
|
||||
|
||||
example_value = proxy.node.meta["example_value"]
|
||||
|
||||
# To avoid lifting the same symbol twice, we check whether basic symbols has been tracked.
|
||||
# For example, the basic symbols may have already been lifted for current subgraph when
|
||||
# we automatically lift basic symbols in the sizes/strides of a tensor t.
|
||||
# Suppose parent graph calls sz = t.size()[0], it creates
|
||||
# a proxy in parent and the subgraph accesses sz via closure. sz's proxy is not tracked
|
||||
# in current sub-tracer so we may lift the same symbol twice.
|
||||
if (
|
||||
isinstance(example_value, torch.SymInt)
|
||||
and example_value.node.expr in self.bound_symbols
|
||||
):
|
||||
return self.bound_symbols[example_value.node.expr]
|
||||
|
||||
# Proxys are associated with VariableTracker.
|
||||
# It is possible that we've already lifted the Proxy to be an input.
|
||||
# If that is the case, just return the already lifted Proxy.
|
||||
@ -2228,6 +2233,263 @@ class SubgraphTracer(fx.Tracer):
|
||||
return arg
|
||||
return self.lift_tracked_freevar_to_input(arg)
|
||||
|
||||
# See NOTE: [Auto lift basic free symbols when create_graph_input] for overall design
|
||||
# You MUST call this API every time when creating a proxy in wrap_fx_proxy for a call
|
||||
# that produced unbacked symints or tensors with unbacked symint shapes.
|
||||
# This function is used to track the unbacked symints with its proxies created during
|
||||
# dynamo tracing so that subgraph knows how to bind a symbol input with parent's proxy.
|
||||
# LazyProxy are created for tensor shapes that're unbacked so that we don't create proxies
|
||||
# for symbols that're not going to be used.
|
||||
def track_unbacked_symbols(
|
||||
self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy]
|
||||
):
|
||||
# When binding the symbols in an exmaple_value, we bind the symbols
|
||||
# to the proxy's associatied Tracer instead of current tracer.
|
||||
# This is because:
|
||||
# 1. We may be calling wrap_tensors during speculate_subgraph because
|
||||
# the variables are lazily realized. The proxy are top-level phs but
|
||||
# current tracer is a subtracer.
|
||||
# 2. For autograd.Function, we trace the backward graph with a new tracer
|
||||
# whose parent is the forward tracer, but we're using all the proxies created
|
||||
# in forward tracer to trace the backward.
|
||||
# For example, forward calls save_for_backward for a input tensor t.
|
||||
# Backward calls t.tolist(). In this case, all the proxies that backward tracer
|
||||
# sees are from parent tracer (i.e. the forward tracer). (e.g. t[0].item())
|
||||
# See test_validate_outputs_unbacked for repro on 2.
|
||||
tracer = e_proxy.tracer
|
||||
assert isinstance(tracer, SubgraphTracer)
|
||||
|
||||
def need_bind(s) -> bool:
|
||||
from torch.fx.experimental.symbolic_shapes import is_symbolic
|
||||
|
||||
return (
|
||||
is_symbolic(s)
|
||||
and isinstance(s.node.expr, sympy.Symbol)
|
||||
and s.node.shape_env.is_unbacked_symint(s.node.expr)
|
||||
and s.node.expr not in self.bound_symbols
|
||||
)
|
||||
|
||||
def _proxy_with_example_value(example_value, *args, **kwargs):
|
||||
proxy = tracer.create_proxy(*args, **kwargs)
|
||||
set_example_value(proxy.node, example_value)
|
||||
return proxy
|
||||
|
||||
if isinstance(example_value, torch.Tensor):
|
||||
for i, s in enumerate(example_value.size()):
|
||||
if need_bind(s):
|
||||
log.debug(
|
||||
"_track_unbacked_symbols %s for %s.size()[%s] at debug_level %s",
|
||||
s,
|
||||
e_proxy,
|
||||
i,
|
||||
tracer.debug_level,
|
||||
)
|
||||
lazy_proxy = LazyProxy(
|
||||
tracer,
|
||||
_proxy_with_example_value,
|
||||
s,
|
||||
"call_function",
|
||||
torch.ops.aten.sym_size.int,
|
||||
(e_proxy, i),
|
||||
{},
|
||||
type_expr=type(s),
|
||||
)
|
||||
self.track_unbacked_symbols(s, lazy_proxy)
|
||||
|
||||
if example_value.layout is torch.strided:
|
||||
for i, s in enumerate(example_value.stride()):
|
||||
if need_bind(s):
|
||||
log.debug(
|
||||
"_track_unbacked_symbols %s for %s.stride()[%s] at debug_level %s",
|
||||
s,
|
||||
e_proxy,
|
||||
i,
|
||||
tracer.debug_level,
|
||||
)
|
||||
lazy_proxy = LazyProxy(
|
||||
tracer,
|
||||
_proxy_with_example_value,
|
||||
s,
|
||||
"call_function",
|
||||
torch.ops.aten.sym_stride.int,
|
||||
(e_proxy, i),
|
||||
{},
|
||||
type_expr=type(s),
|
||||
)
|
||||
self.track_unbacked_symbols(s, lazy_proxy)
|
||||
|
||||
elif example_value.layout is torch.sparse_coo:
|
||||
self.track_unbacked_symbols(example_value._indices(), e_proxy)
|
||||
self.track_unbacked_symbols(example_value._values(), e_proxy)
|
||||
elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
self.track_unbacked_symbols(example_value.crow_indices(), e_proxy)
|
||||
self.track_unbacked_symbols(example_value.col_indices(), e_proxy)
|
||||
elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
|
||||
self.track_unbacked_symbols(example_value.ccol_indices(), e_proxy)
|
||||
self.track_unbacked_symbols(example_value.row_indices(), e_proxy)
|
||||
if is_traceable_wrapper_subclass(example_value):
|
||||
attrs, ctx = example_value.__tensor_flatten__()
|
||||
for attr in attrs:
|
||||
inner_t = getattr(example_value, attr)
|
||||
self.track_unbacked_symbols(inner_t, getattr(e_proxy, attr))
|
||||
elif isinstance(example_value, torch.SymInt):
|
||||
# Only bind unbacked symbols. backed symbols are lifted as inputs.
|
||||
if need_bind(example_value):
|
||||
expr = example_value.node.expr
|
||||
tracer.bound_symbols[expr] = e_proxy
|
||||
|
||||
# See Note [Auto lift basic free symbols when create_graph_input]
|
||||
def _lift_basic_symbols(
|
||||
self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source]
|
||||
):
|
||||
# The before arg is for inserting symints in the sizes/strides of a tensor
|
||||
# before the tensor. This odering ensures that when we look at the tensor's
|
||||
# symbols, they're already lifted/tracked. E.g. this assumption is used
|
||||
# in insert_deferred_runtime_asserts.
|
||||
def _lift_symbols_in_symint(
|
||||
s: Union[int, torch.SymInt],
|
||||
source: Optional[Source],
|
||||
before: bool = False,
|
||||
) -> None:
|
||||
if not is_symbolic(s):
|
||||
return
|
||||
|
||||
assert isinstance(s, torch.SymInt)
|
||||
self_to_be_bound = self.lookup_unbound_symbols(s)
|
||||
if len(self_to_be_bound) == 0:
|
||||
return
|
||||
|
||||
# For subgraph
|
||||
if self.parent is not None:
|
||||
# Recursively lift symbols in symint until top-level.
|
||||
self.parent._lift_basic_symbols(s, source)
|
||||
for s0 in self_to_be_bound:
|
||||
parent_proxy = self.parent.bound_symbols[s0]
|
||||
example_val = parent_proxy.node.meta["example_value"]
|
||||
assert isinstance(example_val, torch.SymInt)
|
||||
ph = self.create_graph_input(
|
||||
str(s0),
|
||||
type(example_val),
|
||||
example_val,
|
||||
before=before,
|
||||
source=source,
|
||||
)
|
||||
log.debug(
|
||||
"_lift_symbols_in_symint %s from %s at debug_level %s",
|
||||
s0,
|
||||
source.name() if source is not None else "subgraph inputs",
|
||||
self.debug_level,
|
||||
)
|
||||
self.lifted_freevars[parent_proxy] = ph
|
||||
# For root_tracer:
|
||||
else:
|
||||
assert len(self_to_be_bound) == 1, (
|
||||
f"For root tracer, we only expect to bind basic symbols (compound symbols "
|
||||
f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}"
|
||||
)
|
||||
assert source is not None, (
|
||||
f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, "
|
||||
"this could be because it's not tracked with lazy_bind_unbacked_symbols. "
|
||||
f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer."
|
||||
)
|
||||
s0 = next(iter(self_to_be_bound))
|
||||
ph = self.create_graph_input(
|
||||
str(s0),
|
||||
type(s),
|
||||
s,
|
||||
before=before,
|
||||
source=source,
|
||||
)
|
||||
log.debug(
|
||||
"_lift_symbols_in_symint %s from %s at debug_level %s",
|
||||
s,
|
||||
source.name() if source is not None else "subgraph inputs",
|
||||
self.debug_level,
|
||||
)
|
||||
ph.node.meta["grapharg"] = GraphArg(
|
||||
source,
|
||||
s,
|
||||
pass_arg_as_tensor=False,
|
||||
fake_tensor=None,
|
||||
is_tensor=False,
|
||||
)
|
||||
|
||||
if isinstance(example_value, torch.Tensor):
|
||||
for i, s in enumerate(example_value.size()):
|
||||
_lift_symbols_in_symint(
|
||||
s,
|
||||
(
|
||||
TensorPropertySource(src, TensorProperty.SIZE, i)
|
||||
if src is not None
|
||||
else None
|
||||
),
|
||||
before=True,
|
||||
)
|
||||
if example_value.layout is torch.strided:
|
||||
for i, s in enumerate(example_value.stride()):
|
||||
_lift_symbols_in_symint(
|
||||
s,
|
||||
(
|
||||
TensorPropertySource(src, TensorProperty.STRIDE, i)
|
||||
if src is not None
|
||||
else None
|
||||
),
|
||||
before=True,
|
||||
)
|
||||
_lift_symbols_in_symint(
|
||||
example_value.storage_offset(),
|
||||
(
|
||||
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET)
|
||||
if src is not None
|
||||
else None
|
||||
),
|
||||
before=True,
|
||||
)
|
||||
elif example_value.layout is torch.sparse_coo:
|
||||
self._lift_basic_symbols(example_value._indices(), src)
|
||||
self._lift_basic_symbols(example_value._values(), src)
|
||||
elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
self._lift_basic_symbols(example_value.crow_indices(), src)
|
||||
self._lift_basic_symbols(example_value.col_indices(), src)
|
||||
elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
|
||||
self._lift_basic_symbols(example_value.ccol_indices(), src)
|
||||
self._lift_basic_symbols(example_value.row_indices(), src)
|
||||
if is_traceable_wrapper_subclass(example_value):
|
||||
attrs, ctx = example_value.__tensor_flatten__()
|
||||
for attr in attrs:
|
||||
inner_t = getattr(example_value, attr)
|
||||
self._lift_basic_symbols(
|
||||
inner_t, AttrSource(src, attr) if src is not None else None
|
||||
)
|
||||
elif isinstance(example_value, torch.SymInt):
|
||||
_lift_symbols_in_symint(
|
||||
example_value,
|
||||
src,
|
||||
)
|
||||
|
||||
# Lookup the proxy in current tracer for each symbol in expressions of s,
|
||||
# See Note [Auto lift basic free symbols when create_graph_input]
|
||||
def lookup_unbound_symbols(self, s: torch.SymInt) -> List[sympy.Symbol]:
|
||||
free_symbols = s.node.expr.free_symbols
|
||||
if len(free_symbols) == 0:
|
||||
return []
|
||||
|
||||
to_be_bound = []
|
||||
for s0 in free_symbols:
|
||||
if s0 not in self.bound_symbols:
|
||||
to_be_bound.append(s0)
|
||||
continue
|
||||
|
||||
proxy = self.bound_symbols[s0]
|
||||
if isinstance(proxy, LazyProxy):
|
||||
proxy = proxy()
|
||||
self.bound_symbols[s0] = proxy
|
||||
assert (
|
||||
isinstance(proxy, torch.fx.Proxy) and proxy.tracer is self
|
||||
), f"The proxy of symbol {s0} doesn't belong to current tracer."
|
||||
# Sort the symbols so that we can have a deterministic lifting order
|
||||
return sorted(to_be_bound, key=lambda s: s.name)
|
||||
|
||||
|
||||
# NOTE: [HigherOrderOperator tracing design]
|
||||
# Ignoring HigherOrderOperators for a moment,
|
||||
|
@ -382,6 +382,11 @@ def rand_strided(
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def check_dynamic_shape_capture() -> bool:
|
||||
# This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
|
||||
return not config.assume_static_by_default
|
||||
|
||||
|
||||
def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
|
||||
@functools.wraps(fn)
|
||||
def _fn(*args: Any, **kwargs: Any) -> _T:
|
||||
|
@ -1702,7 +1702,6 @@ class VariableBuilder:
|
||||
|
||||
grapharg = GraphArg(source, value, False, fake_tensor_value)
|
||||
tensor_proxy.node.meta["grapharg"] = grapharg
|
||||
self.tx.output.add_symbol_bindings(grapharg)
|
||||
return tensor_variable
|
||||
|
||||
def wrap_numpy_ndarray(self, value):
|
||||
@ -2281,6 +2280,11 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
||||
# tensor, the stored example value will update too!)
|
||||
example_value = _clone_input(example_value, tx.fake_mode)
|
||||
set_example_value(proxy.node, example_value)
|
||||
# We bind the unbacked symints in sizes/trdies of tensor lazily.
|
||||
# So that subgraphs can access the unbacked symbol's proxy in parent graph
|
||||
# when lifting unbacked symbols of input tensors to subgraph inputs.
|
||||
# We do it lazily because the tensor may not be used in subgraphs.
|
||||
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
|
||||
specialized_props = target_cls.specialize(example_value)
|
||||
# TODO: not sure about this fake mode test
|
||||
if (
|
||||
@ -2368,6 +2372,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
||||
elif example_value is None or proxy.node.target is torch.manual_seed:
|
||||
return ConstantVariable.create(None, **options)
|
||||
elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
||||
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
|
||||
set_example_value(proxy.node, example_value)
|
||||
return SymNodeVariable(proxy, example_value, **options)
|
||||
elif (
|
||||
|
@ -8,7 +8,7 @@ import itertools
|
||||
import logging
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import torch._C
|
||||
import torch.fx
|
||||
@ -539,6 +539,70 @@ def speculate_subgraph(
|
||||
graph.lint()
|
||||
lifted_freevars = subtracer.lifted_freevars
|
||||
|
||||
# NOTE: [HigherOrderOperator subgraph input ordering]
|
||||
# The input ordering of the higher order ops is determined by the order of
|
||||
# the creatation of the placehoder.
|
||||
# Mannually created inputs are created in validate_args_and_maybe_create_graph_inputs before
|
||||
# speculating subgraph.
|
||||
# During subgraph speculation, we may lift closured tensors and free symbols as inputs,
|
||||
# their ordering is determined by the time they are lifted: earlier lifted ones precede later
|
||||
# lifted ones.
|
||||
#
|
||||
# Suppose the placeholders are
|
||||
# O1, O2, X1, O3, O4, X2, X3, O5 where Xs are lifted phs
|
||||
# The following code re-order the placeholders to
|
||||
# O1, O2, O3, O4, O5, X1, X2, X3
|
||||
def move_lifted_freevars_phs_to_end(
|
||||
graph: torch.fx.Graph, lifted_freevars: Tuple[torch.fx.Node]
|
||||
):
|
||||
lifted_ph_set = {
|
||||
child_p.node for child_p in lifted_freevars.values()
|
||||
}
|
||||
|
||||
prev_phs = [n for n in graph.nodes if n.op == "placeholder"]
|
||||
|
||||
# No need to reorder when graph doesn't have args or doesn't
|
||||
# have lifted freevars or all inputs are lifted freevars.
|
||||
if (
|
||||
len(prev_phs) == 0
|
||||
or len(lifted_ph_set) == 0
|
||||
or len(prev_phs) == len(lifted_ph_set)
|
||||
):
|
||||
return
|
||||
|
||||
# Step 1: find first X1
|
||||
for x1 in prev_phs:
|
||||
if x1 in lifted_ph_set:
|
||||
break
|
||||
|
||||
assert x1 is not None and x1.op == "placeholder"
|
||||
# Step 2: starting from the X1, skip Xs and prepend Os before X1.
|
||||
cand_x = x1.next
|
||||
while cand_x is not None and cand_x.op == "placeholder":
|
||||
if cand_x in lifted_ph_set:
|
||||
cand_x = cand_x.next
|
||||
else:
|
||||
nxt = cand_x.next
|
||||
cand_x._remove_from_list()
|
||||
x1.prepend(cand_x)
|
||||
cand_x = nxt
|
||||
|
||||
# Step 3: assert that all placeholders are in the correct order as .
|
||||
# in lifted_freevars
|
||||
after_phs = [
|
||||
node for node in graph.nodes if node.op == "placeholder"
|
||||
][-len(lifted_freevars) :]
|
||||
assert len(after_phs) == len(lifted_freevars)
|
||||
for child_proxy, ph in zip(lifted_freevars.values(), after_phs):
|
||||
assert (
|
||||
child_proxy.node is ph
|
||||
), "The order of placeholders is different from the order of lifted_freevars"
|
||||
|
||||
graph.lint()
|
||||
|
||||
if len(lifted_freevars) > 0:
|
||||
move_lifted_freevars_phs_to_end(graph, lifted_freevars)
|
||||
|
||||
return (
|
||||
(output, treespec),
|
||||
graph,
|
||||
@ -716,7 +780,7 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
f"{operands.python_type()}",
|
||||
)
|
||||
operands_seq = operands.unpack_var_sequence(tx)
|
||||
if not only_consist_of(operands, (TensorVariable,)):
|
||||
if not only_consist_of(operands, (TensorVariable, ConstantVariable)):
|
||||
unimplemented(
|
||||
"Expect operands to be a tuple of pytrees that only consists of tensor leaves."
|
||||
)
|
||||
|
@ -28,6 +28,7 @@ from torch._higher_order_ops.utils import (
|
||||
saved_tensors_and_symints,
|
||||
unique_graph_id,
|
||||
UnsupportedAliasMutationException,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
@ -58,6 +59,7 @@ class CondOp(HigherOrderOperator):
|
||||
super().__init__("cond")
|
||||
|
||||
def __call__(self, pred, true_fn, false_fn, operands):
|
||||
validate_subgraph_args_types(operands)
|
||||
return super().__call__(pred, true_fn, false_fn, operands)
|
||||
|
||||
|
||||
@ -246,9 +248,6 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
assert isinstance(
|
||||
operands, (list, tuple)
|
||||
), f"Cond operands must be a list or tuple of tensors and SymInts {operands}"
|
||||
assert all(
|
||||
isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands
|
||||
), f"Cond operands must be a list of tensors and SymInts {operands}"
|
||||
|
||||
true_graph = reenter_make_fx(true_fn)(*operands)
|
||||
false_graph = reenter_make_fx(false_fn)(*operands)
|
||||
@ -366,6 +365,9 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
|
||||
@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def cond_op_dense(pred, true_fn, false_fn, operands):
|
||||
assert all(
|
||||
isinstance(o, (torch.Tensor, int)) for o in operands
|
||||
), f"Dense implementation operands must be a list of tensors and ints {operands}"
|
||||
mode = _get_current_dispatch_mode()
|
||||
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
|
||||
if pred:
|
||||
|
@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
|
||||
save_tensors_and_symints_for_backward,
|
||||
saved_tensors_and_symints,
|
||||
UnsupportedAliasMutationException,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses import FakeTensorMode
|
||||
@ -85,11 +86,7 @@ class FlexAttentionHOP(HigherOrderOperator):
|
||||
score_mod_other_buffers: Tuple = (),
|
||||
mask_mod_other_buffers: Tuple = (),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if not all(
|
||||
isinstance(buf, (torch.Tensor, torch.SymInt, int))
|
||||
for buf in score_mod_other_buffers + mask_mod_other_buffers
|
||||
):
|
||||
raise RuntimeError("Other buffers must be tensors.")
|
||||
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
|
||||
return super().__call__(
|
||||
query,
|
||||
key,
|
||||
@ -129,11 +126,7 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
|
||||
) -> Tuple[
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
|
||||
]:
|
||||
if not all(
|
||||
isinstance(buf, torch.Tensor)
|
||||
for buf in score_mod_other_buffers + mask_mod_other_buffers
|
||||
):
|
||||
raise RuntimeError("Other buffers must be tensors.")
|
||||
validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
|
||||
return super().__call__(
|
||||
query,
|
||||
key,
|
||||
@ -415,10 +408,6 @@ def flex_attention_functionalize(
|
||||
assert isinstance(block_mask_unwrapped, tuple)
|
||||
assert isinstance(score_mod_other_buffers_unwrapped, tuple)
|
||||
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
|
||||
assert all(
|
||||
isinstance(item, (torch.Tensor, torch.SymInt, int))
|
||||
for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
|
||||
)
|
||||
|
||||
example_vals = (
|
||||
[torch.zeros((), dtype=query.dtype)]
|
||||
@ -531,12 +520,8 @@ def create_fw_bw_graph(
|
||||
unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
|
||||
|
||||
assert all(
|
||||
isinstance(t, (FakeTensor, torch.SymInt, int))
|
||||
for t in unwrapped_score_mod_indexes
|
||||
)
|
||||
assert all(
|
||||
isinstance(t, (FakeTensor, torch.SymInt, int))
|
||||
for t in unwrapped_other_buffers
|
||||
isinstance(t, (FakeTensor, int, torch.SymInt))
|
||||
for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
|
||||
)
|
||||
|
||||
example_flat_out = pytree.tree_map(
|
||||
@ -595,7 +580,9 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
|
||||
*score_mod_other_buffers: Tuple[Any, ...],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
any_buffer_requires_grad = any(
|
||||
buffer.requires_grad for buffer in mask_mod_other_buffers
|
||||
buffer.requires_grad
|
||||
for buffer in mask_mod_other_buffers
|
||||
if isinstance(buffer, torch.Tensor)
|
||||
)
|
||||
assert (
|
||||
not any_buffer_requires_grad
|
||||
@ -777,9 +764,16 @@ def sdpa_dense_backward(
|
||||
actual_grad_query = torch.empty_like(query)
|
||||
actual_grad_key = torch.empty_like(key)
|
||||
actual_grad_value = torch.empty_like(value)
|
||||
|
||||
def _maybe_new_buffer(
|
||||
buffer: Union[torch.Tensor, torch.SymInt, int]
|
||||
) -> Optional[Union[torch.Tensor, torch.SymInt, int]]:
|
||||
if isinstance(buffer, torch.Tensor):
|
||||
return torch.empty_like(buffer) if buffer.requires_grad else None
|
||||
return buffer
|
||||
|
||||
actual_grad_score_mod_captured = [
|
||||
torch.empty_like(buffer) if buffer.requires_grad else None
|
||||
for buffer in score_mod_other_buffers
|
||||
_maybe_new_buffer(buffer) for buffer in score_mod_other_buffers
|
||||
]
|
||||
|
||||
Bq, Bkv = query.size(0), key.size(0)
|
||||
@ -883,7 +877,7 @@ def sdpa_dense_backward(
|
||||
actual_grad_key.copy_(grad_key)
|
||||
actual_grad_value.copy_(grad_value)
|
||||
score_mod_other_buffer_grads = [
|
||||
actual_grad.copy_(grad) if actual_grad is not None else actual_grad
|
||||
actual_grad.copy_(grad) if isinstance(actual_grad, torch.Tensor) else None
|
||||
for actual_grad, grad in zip(
|
||||
actual_grad_score_mod_captured, grad_score_mod_captured
|
||||
)
|
||||
@ -1072,10 +1066,6 @@ def flex_attention_backward_functionalize(
|
||||
assert isinstance(block_mask_unwrapped, tuple)
|
||||
assert isinstance(score_mod_other_buffers_unwrapped, tuple)
|
||||
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
|
||||
assert all(
|
||||
isinstance(item, torch.Tensor)
|
||||
for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
|
||||
)
|
||||
|
||||
with ctx.redispatch_to_next() as m:
|
||||
functional_fw_graph = ctx.functionalize(fw_graph)
|
||||
|
@ -16,6 +16,7 @@ from torch._higher_order_ops.utils import (
|
||||
reenter_make_fx,
|
||||
unique_graph_id,
|
||||
UnsupportedAliasMutationException,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
@ -230,6 +231,7 @@ class ScanOp(HigherOrderOperator):
|
||||
|
||||
def __call__(self, combine_fn, init, xs, dim, reverse, additional_inputs):
|
||||
assert isinstance(additional_inputs, list), "additional_inputs must be a list."
|
||||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(combine_fn, init, xs, dim, reverse, additional_inputs)
|
||||
|
||||
|
||||
@ -335,10 +337,15 @@ def trace_scan(
|
||||
reverse: bool,
|
||||
additional_inputs: List[torch.Tensor],
|
||||
):
|
||||
from torch._dynamo.utils import clone_input
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
sample_inits = [x_init.clone() for x_init in init]
|
||||
sample_inits = [clone_input(x_init) for x_init in init]
|
||||
sample_inputs = [first_slice_copy(x, dim) for x in xs]
|
||||
sample_additional_inputs = [x.clone() for x in additional_inputs]
|
||||
sample_additional_inputs = [
|
||||
clone_input(x) if isinstance(x, torch.Tensor) else x
|
||||
for x in additional_inputs
|
||||
]
|
||||
combine_graph = reenter_make_fx(combine_fn)(
|
||||
*sample_inits, *sample_inputs, *sample_additional_inputs
|
||||
)
|
||||
|
@ -2,7 +2,7 @@
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
@ -481,3 +481,19 @@ def get_dummy_aot_autograd_config():
|
||||
# Slices off the first element of a given dimension
|
||||
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
||||
return torch.select_copy(t, dim, 0)
|
||||
|
||||
|
||||
# Note [lifted arg types in hop]
|
||||
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
|
||||
# This has implications for the types of lifted args for different dispatch keys:
|
||||
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.Symint
|
||||
# lifted args because it's on the path of torch.compile(dynamic=True).
|
||||
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
|
||||
# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
|
||||
# hops may receive int inputs from the shape of outer tensor inputs.
|
||||
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
|
||||
def validate_subgraph_args_types(lifted_args: Union[Tuple[Any], List[Any]]):
|
||||
allowed_types = (torch.Tensor, int, torch.SymInt)
|
||||
assert all(
|
||||
isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
|
||||
), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
|
||||
|
@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
|
||||
autograd_not_implemented,
|
||||
reenter_make_fx,
|
||||
UnsupportedAliasMutationException,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
@ -42,6 +43,7 @@ class WhileLoopOp(HigherOrderOperator):
|
||||
raise RuntimeError(
|
||||
f"additional_inputs must be a tuple, got {type(additional_inputs)}"
|
||||
)
|
||||
|
||||
if not all(
|
||||
isinstance(t, (torch.Tensor, int, float, bool)) for t in carried_inputs
|
||||
):
|
||||
@ -50,13 +52,7 @@ class WhileLoopOp(HigherOrderOperator):
|
||||
f"{carried_inputs}"
|
||||
)
|
||||
|
||||
if not all(
|
||||
isinstance(t, (torch.Tensor, int, float, bool)) for t in additional_inputs
|
||||
):
|
||||
raise RuntimeError(
|
||||
"additional_inputs must be a tuple of tensors, ints, floats, or bools, got "
|
||||
f"{additional_inputs}"
|
||||
)
|
||||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
|
||||
|
||||
|
||||
|
@ -76,7 +76,14 @@ def create_placeholder(
|
||||
|
||||
def maybe_realize(args: List[Optional[IRNode]]):
|
||||
"""Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
|
||||
return tree_map(lambda x: realize_inputs(x) if x is not None else None, args)
|
||||
return tree_map(
|
||||
lambda x: (
|
||||
realize_inputs(x)
|
||||
if x is not None and not isinstance(x, sympy.Symbol)
|
||||
else x
|
||||
),
|
||||
args,
|
||||
)
|
||||
|
||||
|
||||
def get_float32_precision():
|
||||
|
Reference in New Issue
Block a user