Compare commits

...

1 Commit

10 changed files with 379 additions and 85 deletions


@@ -442,13 +442,22 @@ class DTensorExportTest(TestCase):
# Run model to verify it works
output = model(*inputs)
with torch._dynamo.config.patch(install_free_tensors=True):
with torch._dynamo.config.patch(
install_free_tensors=(export_fn is _dynamo_graph_capture_for_export)
):
# TODO: switch to use the official graph_capture API once it is ready
gm = export_fn(model)(*inputs)
output_gm = gm(*inputs)
self.assertEqual(output, output_gm)
def test_flex_attention_dtensor_export(self):
@parametrize(
"export_fn",
[
graph_capture_and_aot_export_joint_with_descriptors_v2,
graph_capture_and_aot_export_joint_with_descriptors,
],
)
def test_flex_attention_dtensor_export(self, export_fn):
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
model = FlexAttentionModel(self.device_type)
@@ -485,9 +494,7 @@ class DTensorExportTest(TestCase):
flex_kwargs = {"block_mask": block_mask}
joint_gm = graph_capture_and_aot_export_joint_with_descriptors(
tp_model, inputs, flex_kwargs
)
joint_gm = export_fn(tp_model, inputs, flex_kwargs)
self.assertTrue(
_count_op(joint_gm, torch.ops.higher_order.flex_attention),


@@ -3,16 +3,19 @@
import copy
import types
import unittest
from dataclasses import dataclass
from typing import Dict, List, Tuple
import torch
import torch._dynamo
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
from torch._dynamo.test_case import run_tests, TestCase
from torch._functorch.aot_autograd import aot_export_module
from torch.export import export
from torch.export.experimental import _export_forward_backward, _sticky_export
from torch.export.graph_signature import OutputKind
from torch.testing import FileCheck
from torch.testing._internal.common_utils import TEST_CUDA
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
@@ -403,8 +406,6 @@ def forward(self, x):
self.assertEqual(res_export, res_eager)
def test_dynamo_graph_capture(self):
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
class Foo(torch.nn.Module):
def forward(self, dct, lst, bleh):
x = dct["a"] * lst[1][0]
@@ -439,6 +440,151 @@ def forward(self, x):
test_inputs = make_inputs()
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
def test_dynamo_graph_capture_custom_pytree_type(self):
import torch.utils._pytree as pytree
@dataclass
class Bar:
x: torch.Tensor
y: torch.Tensor
class Foo(torch.nn.Module):
def forward(self, bar: Bar):
return bar.x + bar.y
foo = Foo()
def make_inputs():
return (Bar(torch.randn(2, 3), torch.randn(2, 3)),)
pytree.register_dataclass(Bar)
try:
trace_inputs = make_inputs()
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
test_inputs = make_inputs()
self.assertExpectedInline(
gm._in_shuffle_graph.code.strip("\r\n "),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
return (arg1_1, arg2_1)""",
)
self.assertExpectedInline(
gm.code.strip("\r\n "),
"""\
def forward(self, args_0):
_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0,))
L_bar_x , L_bar_y , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
l_bar_x = L_bar_x
l_bar_y = L_bar_y
add = l_bar_x + l_bar_y; l_bar_x = l_bar_y = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, add), self._out_spec)""",
)
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
finally:
pytree._deregister_pytree_node(Bar)
def test_dynamo_graph_capture_closure(self):
from torch.export import Dim
N = 3
outer = torch.randn(10, 32)
class MyModel(torch.nn.Module):
def forward(self, x):
z = x + outer
y = z[:-1, :] # [s0 - 1, 32]
stacked = torch.stack([y] * N, dim=0) # [N * (s0 - 1), 32]
reshaped = stacked.reshape(-1, N, 32) # [(s0 - 1), N, 32]
return reshaped
inps = (torch.randn(10, 32),)
ep = dynamo_graph_capture_for_export(MyModel())(*inps)
self.assertExpectedInline(
ep._in_shuffle_graph.code.strip("\r\n "),
"""\
def forward(self, arg0_1, arg1_1):
_tensor_constant0 = self._tensor_constant0
return (arg1_1, _tensor_constant0)""",
)
self.assertExpectedInline(
ep.code.strip("\r\n "),
"""\
def forward(self, args_0):
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
L_x_ , L_outer_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
l_x_ = L_x_
l_outer_ = L_outer_
z = l_x_ + l_outer_; l_x_ = l_outer_ = None
y = z[(slice(None, -1, None), slice(None, None, None))]; z = None
stacked = torch.stack([y, y, y], dim = 0); y = None
reshaped = stacked.reshape(-1, 3, 32); stacked = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, reshaped), self._out_spec)""",
)
self.assertEqual(ep(*inps), MyModel()(*inps))
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
class DummyOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scalar):
ctx.save_for_backward(x)
return x + scalar
@staticmethod
def backward(ctx, grad_out):
return grad_out, None
def mock_fw_compute(x):
with fx_traceback.annotate({"compute": 0}):
return DummyOp.apply(x, 10)
def mock_bw_comm(x):
with fx_traceback.annotate({"comm": 0}):
return DummyOp.apply(x, 20)
def mock_bw_compute(x):
return DummyOp.apply(x, 30)
class Model(torch.nn.Module):
def forward(self, fw_in, bw_in):
fw_out = mock_fw_compute(fw_in)
# bw_in blocks bw_out
bw_in = mock_bw_comm(bw_in)
bw_out = mock_bw_compute(bw_in)
return fw_out, bw_out
def input_fn():
inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
grad_ins = (torch.rand(2, 128, device="cuda"),)
return (
*inputs,
*grad_ins,
)
with torch.device("meta"):
model = Model()
import torch.fx.traceback as fx_traceback
with fx_traceback.preserve_node_meta():
gm = dynamo_graph_capture_for_export(model)(*input_fn())
"""
def forward(self, args_0, args_1):
_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, = pytree.tree_leaves((self, args_0, args_1,))
L_fw_in_ , L_bw_in_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2)
l_fw_in_ = L_fw_in_
l_bw_in_ = L_bw_in_
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
fw_out = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_fw_in_, args_tensor_mask = [True, False], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_fw_in_ = None
bw_in = l_bw_in_ + 20; l_bw_in_ = None
bw_out = bw_in + 30; bw_in = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, _tree_leaf_2, fw_out, bw_out), self._out_spec)
"""
test_inputs = input_fn()
self.assertEqual(gm(*test_inputs), model(*test_inputs))
if __name__ == "__main__":
run_tests()


@@ -896,6 +896,7 @@ class DynamoOutput:
output_graph.import_sources,
output_graph.traced_code,
self.bytecode,
self.tracer_output.closure,
)
@@ -927,6 +928,7 @@ class GraphCaptureOutput:
import_sources: dict[str, str]
traced_code: list[CodeType]
bytecode: CodeType
closure: Optional[tuple[Any, ...]]
def build_guards(
self,
@@ -981,7 +983,7 @@ class CaptureOutput:
return types.FunctionType(
self.graph_capture_output.bytecode,
f_globals,
closure=(),
closure=self.graph_capture_output.closure,
)
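The change from closure=() to the captured closure matters because CPython requires the closure tuple passed to types.FunctionType to have one cell per free variable of the code object; an empty tuple is rejected when the traced bytecode has free variables. A minimal standalone sketch of that constraint (illustrative only, not part of this diff):

import types

def outer():
    captured = 41

    def inner():
        return captured + 1

    return inner

fn = outer()
code = fn.__code__  # code.co_freevars == ('captured',)
# Rebuilding the function needs one cell per free variable, which is why the
# captured closure is now threaded through the tracer output above.
rebuilt = types.FunctionType(code, globals(), closure=fn.__closure__)
assert rebuilt() == 42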


@@ -1211,7 +1211,9 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
# Make dynamo graph to have same input/output spec as user code
def argument_names(
f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
f_sig: inspect.Signature,
args: Union[list[Any], tuple[Any, ...]],
kwargs: dict[str, Any],
) -> list[str]:
def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
# Get a list of Parameter objects from the Signature object


@@ -1,10 +1,9 @@
import copy
import inspect
import logging
import traceback
import types
from collections import namedtuple
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, Union
import sympy
@@ -19,17 +18,19 @@ from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._export.utils import _compiling_state_context
from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
from torch.fx import Node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
DimDynamic,
ShapeEnv,
StatelessSymbolicContext,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.graph import _ExportCodeGen, _PyTreeCodeGen, _PyTreeInfo
from torch.utils._pytree import TreeSpec
if TYPE_CHECKING:
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._pytree import TreeSpec
log = logging.getLogger(__name__)
@@ -448,9 +449,20 @@ def _suggest_or_raise_constraint_violation(
raise constraint_violation_error
@dataclass(frozen=True)
class PyTreeifyOutput:
graph_module: torch.fx.GraphModule
in_spec: TreeSpec
in_shuffle_graph: torch.fx.GraphModule
num_flat_args: int
out_spec: TreeSpec
out_shuffle_graph: torch.fx.GraphModule
root: Optional[torch.nn.Module] = None
def pytreeify(
out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Any:
) -> PyTreeifyOutput:
"""
Given a dynamo capture output, return a callable graph module that
contains the following information:
@@ -468,10 +480,13 @@ def pytreeify(
backend_input = out.backend_input
backend = out.backend_input.graph_module
root = None
if isinstance(mod, torch.nn.Module):
args = (mod,) + args
root = mod
elif inspect.ismethod(mod):
args = (mod.__self__,) + args
root = mod.__self__
flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
@@ -504,15 +519,21 @@ def pytreeify(
backend_input.graph_module = backend
raise RuntimeError
in_shuffle_graph = torch.fx.symbolic_trace(InShuffle())
fake_mode = torch._dynamo.utils.detect_fake_mode(flat_real_args)
if fake_mode and fake_mode.shape_env is None:
fake_mode.shape_env = ShapeEnv()
in_shuffle_graph = make_fx(
InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True
)(*flat_real_args)
output_node = next(iter(reversed(backend_input.graph_module.graph.nodes)))
class OutShuffle(torch.nn.Module):
def __init__(self):
super().__init__()
self.num_inputs = len(flat_real_args)
self.num_outputs = len(
next(iter(reversed(backend_input.graph_module.graph.nodes))).args[0]
)
self.num_outputs = len(output_node.args[0])
self.out_spec: Optional[TreeSpec] = None
def forward(self, *flat_proxy_args):
@@ -535,49 +556,101 @@ def pytreeify(
return ret
out_shuffle = OutShuffle()
out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
flat_out_shuffle_args = [
*flat_real_args,
*pytree.tree_map_only(
torch.fx.Node,
lambda x: fake_mode.from_tensor(x.meta["example_value"])
if fake_mode
else x.meta["example_value"],
output_node.args[0],
),
]
fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args)
if fake_mode and fake_mode.shape_env is None:
fake_mode.shape_env = ShapeEnv()
out_shuffle_graph = make_fx(
out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
)(*flat_out_shuffle_args)
def pytree_call(*args, **kwargs):
import torch.export._unlift
assert out_shuffle.out_spec is not None
return PyTreeifyOutput(
backend_input.graph_module,
in_spec,
in_shuffle_graph,
len(flat_real_args),
out_shuffle.out_spec,
out_shuffle_graph,
root=root, # type: ignore[arg-type]
)
flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
raise RuntimeError(
f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
)
flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
assert out_shuffle.out_spec is not None
return pytree.tree_unflatten(
out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
)
if isinstance(mod, torch.nn.Module):
compiled_mod = copy.copy(mod)
compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
if not hasattr(compiled_mod, "meta"):
compiled_mod.meta = {} # type: ignore[attr-defined]
if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
return compiled_mod
elif inspect.ismethod(mod):
return types.MethodType(pytree_call, mod.__self__)
else:
return pytree_call
def normalize_graph_module(gm):
for node in gm.graph.nodes:
if node.op == "placeholder":
node.meta["val"] = node.meta["example_value"]
def dynamo_graph_capture_for_export(
mod: Callable[..., Any],
constraints: Optional[list[Constraint]] = None,
) -> Callable[..., Any]:
def inner(*args: Any, **kwargs: Any) -> Any:
assert not torch._dynamo.config.install_free_tensors
with (
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
):
out = fullgraph_capture(mod, args, kwargs)
out = fullgraph_capture(
mod,
args,
kwargs,
constraints=constraints,
)
# TODO filter out side effects.
pyt = pytreeify(out, mod, args, kwargs)
return pytreeify(out, mod, args, kwargs)
graph_module = pyt.graph_module
tree_leaf_names = [
graph_module.graph._graph_namespace.create_name(f"_tree_leaf_{i}", None)
for i in range(pyt.num_flat_args)
]
graph_module.graph._codegen = _ExportCodeGen(
_PyTreeInfo(
# TODO we should be able to use the names from dynamo graph directly.
argument_names(inspect.signature(mod), args, kwargs),
pyt.in_spec,
pyt.out_spec,
),
pyt.in_shuffle_graph,
pyt.out_shuffle_graph,
tree_leaf_names,
pyt.root,
) # type: ignore[attr-defined]
normalize_graph_module(graph_module)
if pyt.root is not None:
graph_module._parameters = pyt.root._parameters.copy()
graph_module._buffers = pyt.root._buffers.copy()
assert all(not hasattr(graph_module, m) for m in pyt.root._modules)
graph_module._modules.update(pyt.root._modules)
graph_module._non_persistent_buffers_set = (
pyt.root._non_persistent_buffers_set.copy()
)
graph_module._in_spec = pyt.in_spec
graph_module._out_spec = pyt.out_spec
assert not hasattr(graph_module, "_in_shuffle_graph")
assert not hasattr(graph_module, "_out_shuffle_graph")
graph_module._in_shuffle_graph = pyt.in_shuffle_graph
graph_module._out_shuffle_graph = pyt.out_shuffle_graph
delattr(graph_module, "_param_name_to_source")
graph_module.recompile()
graph_module.meta["module_call_specs"] = (
out.graph_capture_output.output_graph.export_metadata.module_call_spec
)
assert out.backend_input is not None
graph_module.meta["fake_mode"] = out.backend_input.fake_mode # type: ignore[attr-defined]
return graph_module
return inner
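End to end, the new entry point is used the way the tests earlier in this diff use it; a minimal sketch with a made-up module (model and shapes are illustrative only):

import torch
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1

model = Toy()
inputs = (torch.randn(2, 3),)

# Capture returns a GraphModule whose generated forward (see the _ExportCodeGen
# changes later in this diff) flattens the inputs through _in_shuffle_graph,
# runs the captured graph, and unflattens the outputs via _out_shuffle_graph.
gm = dynamo_graph_capture_for_export(model)(*inputs)
assert torch.equal(gm(*inputs), model(*inputs))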


@@ -2747,12 +2747,14 @@ class DynamoTracerOutput:
error_on_graph_break: bool
is_tracing_resume_prologue: bool
output_graph: Optional[OutputGraph]
closure: Optional[tuple[Any, ...]]
def __init__(
self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
) -> None:
self.error_on_graph_break = tracer.error_on_graph_break
self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
self.closure = tracer.closure
if error:
self.output_graph = None
else:


@@ -4120,6 +4120,7 @@ class InstructionTranslatorBase(
self.f_builtins: dict[str, Any] = f_builtins
self.code_options: dict[str, Any] = code_options
self.f_code: types.CodeType = f_code
self.closure = closure
# Execution record for replaying errors
if closure is not None and config.replay_record_enabled:


@@ -97,7 +97,7 @@ from torch.fx.experimental.symbolic_shapes import (
GuardOnDataDependentSymNode,
ShapeEnv,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.graph import _PyTreeInfo
from torch.utils._pytree import TreeSpec
from torch.utils._sympy.value_ranges import ValueRangeError
@@ -1537,12 +1537,10 @@ def _strict_export(
orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined]
gm_torch_level.graph._codegen = _PyTreeCodeGen(
_PyTreeInfo(
orig_arg_names,
gm_torch_level._in_spec,
out_spec,
)
gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo(
orig_arg_names,
gm_torch_level._in_spec,
out_spec,
)
gm_torch_level.recompile()


@@ -1489,12 +1489,20 @@ def wrap_key(
@functools.wraps(f)
def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
nonlocal tensors
flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
assert len(flat_proxies) == len(flat_tensors)
with disable_proxy_modes_tracing() as m:
assert isinstance(m, ProxyTorchDispatchMode)
track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
if getattr(tracer, "proxy_module_inputs", False):
tensors = [ # type: ignore[assignment, var-annotated]
p if isinstance(t, torch.nn.Module) else t
for t, p in zip(tensors, proxies) # type: ignore[arg-type]
]
def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
return get_proxy_slot(t, tracer, t, lambda x: x.proxy) # type: ignore[attr-defined]
@@ -2208,6 +2216,7 @@ class _MakefxTracer:
_error_on_data_dependent_ops: bool,
record_stack_traces: bool = False,
parent_tracer: Optional[_MakefxTracer] = None,
proxy_module_inputs: bool = False,
) -> None:
# Configurations that are used to initialize the context managers and their states.
# Should not modify them during tracing.
@@ -2240,6 +2249,7 @@ class _MakefxTracer:
)
self.record_stack_traces = record_stack_traces
self.parent_tracer: Optional[_MakefxTracer] = parent_tracer
self.proxy_module_inputs = proxy_module_inputs
def _checkpoint_modes(self) -> list[Any]:
return [
@@ -2349,6 +2359,7 @@ class _MakefxTracer:
self.python_dispatcher_mode = enable_python_dispatcher()
self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)
fx_tracer.proxy_module_inputs = self.proxy_module_inputs # type: ignore[union-attr]
@contextmanager
def _init_modes_from_parent(
@@ -2551,6 +2562,7 @@ def make_fx(
_allow_fake_constant: bool = False,
_error_on_data_dependent_ops: bool = True,
record_stack_traces: bool = False,
proxy_module_inputs: bool = False,
) -> Callable[..., GraphModule]:
"""
Given a function f, return a new function which when executed with valid
@@ -2574,6 +2586,7 @@ def make_fx(
_error_on_data_dependent_ops,
record_stack_traces=record_stack_traces
or config.trace.provenance_tracking_level == 1,
proxy_module_inputs=proxy_module_inputs,
)
@functools.wraps(f)


@@ -930,6 +930,42 @@ class _PyTreeCodeGen(CodeGen):
else:
return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
# when kwargs is present, in_spec is tuple(args, kwargs)
has_args_kwargs_tuple = (
self.pytree_info.in_spec.type is tuple
and self.pytree_info.in_spec.num_children == 2
and self.pytree_info.in_spec.children_specs[0].type is tuple
and self.pytree_info.in_spec.children_specs[1].type is dict
)
fn_kwargs = "{}"
fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
if has_args_kwargs_tuple:
count_args = self.pytree_info.in_spec.children_specs[0].num_children
fn_args = self.pytree_info.orig_args[:count_args]
fn_kwargs = (
"{"
+ ", ".join(
f"'{k}':{v}"
for k, v in zip(
self.pytree_info.in_spec.children_specs[1].context,
self.pytree_info.orig_args[count_args:],
)
)
+ "}"
)
fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
# in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
# we need to split it to two lines:
# one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
# one for code: `var1, var2, = function_call()`
without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
bindings = self._format_annotations(free_vars, expanded_def)
bindings += f"""
{", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
return bindings
def gen_fn_def(
self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
):
@@ -962,39 +998,7 @@ class _PyTreeCodeGen(CodeGen):
)
if len(free_vars) > 0: # pytree has placeholders in it
# when kwargs is present, in_spec is tuple(args, kwargs)
has_args_kwargs_tuple = (
self.pytree_info.in_spec.type is tuple
and self.pytree_info.in_spec.num_children == 2
and self.pytree_info.in_spec.children_specs[0].type is tuple
and self.pytree_info.in_spec.children_specs[1].type is dict
)
fn_kwargs = "{}"
fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
if has_args_kwargs_tuple:
count_args = self.pytree_info.in_spec.children_specs[0].num_children
fn_args = self.pytree_info.orig_args[:count_args]
fn_kwargs = (
"{"
+ ", ".join(
f"'{k}':{v}"
for k, v in zip(
self.pytree_info.in_spec.children_specs[1].context,
self.pytree_info.orig_args[count_args:],
)
)
+ "}"
)
fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
# in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
# we need to split it to two lines:
# one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
# one for code: `var1, var2, = function_call()`
without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
fn_definition += self._format_annotations(free_vars, expanded_def)
fn_definition += f"""
{", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
fn_definition += self.gen_var_bindings(fn_args, free_vars, expanded_def)
return fn_definition
def generate_output(self, output_args, *, descs: Optional[Any] = None):
@@ -1014,6 +1018,52 @@ class _PyTreeCodeGen(CodeGen):
return super().generate_output(output_args, descs=descs)
class _ExportCodeGen(_PyTreeCodeGen):
def __init__(
self,
pytree_info: _PyTreeInfo,
in_shuffle_graph: "GraphModule",
out_shuffle_graph: "GraphModule",
tree_leaf_names: list[str],
root: Optional[torch.nn.Module],
):
super().__init__(pytree_info)
self.in_shuffle_graph = in_shuffle_graph
self.out_shuffle_graph = out_shuffle_graph
self.tree_leaf_names = tree_leaf_names
self.root = root
def process_inputs(self, *inputs: Any) -> Any:
flat_args = super().process_inputs(*inputs)
if self.root is not None:
flat_args = (self.root, *flat_args)
self.flat_args = flat_args
return self.in_shuffle_graph(*flat_args)
def process_outputs(self, out: Any) -> Any:
flat_outs = self.out_shuffle_graph(*self.flat_args, *out)
del self.flat_args
ret = super().process_outputs(flat_outs)
return ret
def gen_fn_def(self, *args, **kwargs) -> str:
fn_def = super().gen_fn_def(*args, **kwargs)
return fn_def
def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str:
without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
fn_signature: str = f"{', '.join(fn_args)}"
if self.root is not None:
fn_signature = f"self, {fn_signature}"
return f"""
{", ".join(self.tree_leaf_names)}, = pytree.tree_leaves(({fn_signature},))
{", ".join(without_annotation)}, = self._in_shuffle_graph({", ".join(self.tree_leaf_names)})"""
def generate_output(self, output_args, *args, **kwargs) -> str:
output = f"self._out_shuffle_graph({', '.join(self.tree_leaf_names)}, {', '.join([str(a) for a in output_args])})"
return f"return pytree.tree_unflatten({output}, self._out_spec)"
class _FindNodesLookupTable:
"""
Side table for the graph for the purpose of doing fast queries