Revert "[export] Add print_readable to unflattener (#128617)"

This reverts commit 9fee87e4cd9efb55ee5427a8e6b3c57de7c599f9.

Reverted https://github.com/pytorch/pytorch/pull/128617 on behalf of https://github.com/clee2000 due to broke inductor/test_flex_attention https://github.com/pytorch/pytorch/actions/runs/9984688318/job/27595182606 433ef4e444 Not run on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/128617#issuecomment-2236867975))
Author: PyTorch MergeBot
Date:   2024-07-18 15:31:51 +00:00
parent 120fdf7ee2
commit d6ae8bbf16
5 changed files with 34 additions and 96 deletions
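For context, `GraphModule.print_readable` pretty-prints the Python that torch.fx generates for a module and, recursively, for its `GraphModule` children. #128617 factored that logic into a shared `_print_readable` helper in `torch/fx/graph_module.py` so the export unflattener's `InterpreterModule` and `UnflattenedModule` could expose the same method; this revert removes the helper and restores the previous inline implementation. A minimal usage sketch of the public API (the printed text shown is approximate and depends on the PyTorch version):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

gm = torch.fx.symbolic_trace(M())

# Prints (and also returns) the generated code, roughly:
#     class M(torch.nn.Module):
#         def forward(self, x):
#             sin = x.sin();  x = None
#             add = sin + 1;  sin = None
#             return add
src = gm.print_readable()
```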


@@ -525,14 +525,14 @@ class GraphModule(torch.nn.Module):
autograd_function_apply: "f32[]" = torch._functorch.autograd_function.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
return (autograd_function_apply,)
-class fwd_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
mul: "f32[]" = l_weird_b * l_weird_c
clone: "f32[]" = x.clone(); x = None
mul_1: "f32[]" = mul * clone; mul = clone = None
return (mul_1, [l_weird_b, l_weird_c])
-class bwd_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
_set_grad_enabled = torch._C._set_grad_enabled(False)
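The hunks in this file and the next two update expected `print_readable` output in Dynamo tests: with the revert, the bodies of higher-order ops render under the generic class name again (`class GraphModule`) instead of the attribute name they are registered under (`fwd_body_0`, `bwd_body_0`, `wrap_body_0`, ...). A rough sketch of how such expected strings are typically produced, assuming the usual `torch._dynamo.testing` helpers (the exact test code may differ):

```python
import torch
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm

def fn(x):
    return torch.sin(x) + 1

backend = EagerAndRecordGraphs()
torch.compile(fn, backend=backend, fullgraph=True)(torch.randn(3))

# The readable form of the captured graph is what the expected strings
# in these tests are compared against.
printed = normalize_gm(backend.graphs[0].print_readable(print_output=False))
print(printed)
```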


@@ -332,7 +332,7 @@ class GraphModule(torch.nn.Module):
getitem: "f32[]" = wrap[0]; wrap = None
return (getitem,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"):
sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None
cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None
@@ -370,7 +370,7 @@ class GraphModule(torch.nn.Module):
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3, 1]"):
view: "f32[3]" = l_x_.view(3); l_x_ = None
add: "f32[3]" = view + 0.5; view = None
@@ -390,7 +390,7 @@ class GraphModule(torch.nn.Module):
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"):
view: "f32[s0]" = l_x_.view(size); l_x_ = size = None
add: "f32[s0]" = view + 0.5; view = None
@@ -1848,7 +1848,7 @@ class GraphModule(torch.nn.Module):
getitem_1: "f32[3]" = wrap[1]; wrap = None
return (getitem, getitem_1)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
@@ -2047,7 +2047,7 @@ class GraphModule(torch.nn.Module):
add: "f32[2, 3]" = a + b; a = b = None
return (add,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[2, 3]"):
child: "f32[2, 3]" = l_x_.sin()
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
@@ -2082,7 +2082,7 @@ class GraphModule(torch.nn.Module):
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]"):
child: "f32[3]" = -l_x_; l_x_ = None
return (child,)


@@ -1093,7 +1093,7 @@ class GraphModule(torch.nn.Module):
getitem: "f32[3, 4]" = wrap[0]; wrap = None
return (getitem,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3, 4]"):
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
return (add_,)
@@ -1117,7 +1117,7 @@ class GraphModule(torch.nn.Module):
getitem = wrap[0]; wrap = None
return (getitem,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_x_):
add_ = l_x_.add_(1.0); l_x_ = None
return (add_,)
@@ -1147,7 +1147,7 @@ class GraphModule(torch.nn.Module):
getitem = wrap[0]; wrap = None
return (getitem,)
-class wrap_body_0(torch.nn.Module):
+class GraphModule(torch.nn.Module):
def forward(self, l_x_):
add_ = l_x_.add_(1.0); l_x_ = None
return (add_,)


@@ -21,7 +21,6 @@ from torch.export.exported_program import (
TensorArgument,
)
from torch.fx._symbolic_trace import is_fx_tracing
-from torch.fx.graph_module import _print_readable
from torch.utils._pytree import GetAttrKey, SequenceKey
from ._remove_effect_tokens_pass import _remove_effect_tokens
@@ -134,22 +133,6 @@ class InterpreterModule(torch.nn.Module):
if node.op == "placeholder":
self.arg_names.append(node.target)
-    def print_readable(
-        self,
-        print_output=True,
-        include_stride=False,
-        include_device=False,
-        colored=False,
-    ):
-        return _print_readable(
-            self,
-            "InterpreterModule",
-            print_output,
-            include_stride,
-            include_device,
-            colored,
-        )
class FlatArgsAdapter(abc.ABC):
"""
@@ -482,22 +465,6 @@ class UnflattenedModule(torch.nn.Module):
)
return pytree.tree_unflatten(tree_out, signature.out_spec)
-    def print_readable(
-        self,
-        print_output=True,
-        include_stride=False,
-        include_device=False,
-        colored=False,
-    ):
-        return _print_readable(
-            self,
-            "UnflattenedModule",
-            print_output,
-            include_stride,
-            include_device,
-            colored,
-        )
def unflatten(
module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
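The two hunks above delete the `print_readable` methods that #128617 added to `InterpreterModule` and `UnflattenedModule` in the export unflattener. As a rough workaround sketch only, not part of either version of the code: the same information can still be dumped by walking the unflattened module tree and calling `torch.fx.Graph.python_code` directly. `dump_readable` below is a hypothetical helper:

```python
import torch

def dump_readable(mod: torch.nn.Module) -> str:
    """Hypothetical helper: collect generated code for every fx.Graph-bearing submodule."""
    chunks = []
    for name, sub in mod.named_modules():
        graph = getattr(sub, "graph", None)
        if isinstance(graph, torch.fx.Graph):
            src = graph.python_code(root_module="self").src.lstrip("\n")
            chunks.append(f"# {name or '<root>'}\n{src}")
    return "\n".join(chunks)

# Usage, assuming `ep` is an ExportedProgram:
#     unflat = torch.export.unflatten(ep)
#     print(dump_readable(unflat))
```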


@@ -257,51 +257,6 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
setattr(to_module, field, from_obj)
-def _print_readable(
-    module,
-    module_name,
-    print_output=True,
-    include_stride=False,
-    include_device=False,
-    colored=False,
-):
-    graph = module.graph
-    assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph"
-    verbose_python_code = graph.python_code(
-        root_module="self",
-        verbose=True,
-        include_stride=include_stride,
-        include_device=include_device,
-        colored=colored,
-    )
-    module_code = verbose_python_code.src
-    module_code = module_code.lstrip("\n")
-    module_code = f"class {module_name}(torch.nn.Module):\n" + module_code
-    module_code = _addindent(module_code, 4)
-    submodule_code_list = [""]
-    for submodule_name, submodule in module.named_children():
-        if hasattr(submodule, "graph"):
-            submodule_code_list.append(
-                _print_readable(
-                    submodule,
-                    submodule_name,
-                    print_output=False,
-                    include_stride=include_stride,
-                    include_device=include_device,
-                    colored=colored,
-                )
-            )
-    submodule_code = "\n".join(submodule_code_list)
-    submodule_code = _addindent(submodule_code, 4)
-    output = module_code + submodule_code
-    if print_output:
-        print(module_code + submodule_code)
-    return output
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
@@ -870,14 +825,30 @@ class {module_name}(torch.nn.Module):
"""
Return the Python code generated for current GraphModule and its children GraphModules
"""
-        return _print_readable(
-            self,
-            self._get_name(),
-            print_output,
-            include_stride,
-            include_device,
-            colored,
+        verbose_python_code = self._graph.python_code(
+            root_module="self", verbose=True, include_stride=include_stride, include_device=include_device, colored=colored
        )
+        module_code = verbose_python_code.src
+        module_code = module_code.lstrip("\n")
+        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
+        module_code = _addindent(module_code, 4)
+        submodule_code_list = [""]
+        for submodule in self.children():
+            if isinstance(submodule, GraphModule):
+                submodule_code_list.append(submodule.print_readable(
+                    print_output=False,
+                    include_stride=include_stride,
+                    include_device=include_device,
+                    colored=colored
+                ))
+        submodule_code = "\n".join(submodule_code_list)
+        submodule_code = _addindent(submodule_code, 4)
+        output = module_code + submodule_code
+        if print_output:
+            print(module_code + submodule_code)
+        return output
def __str__(self) -> str:
orig_str = super().__str__()
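Net effect of the restored implementation above: recursion goes through `submodule.print_readable(...)`, so a nested module's header comes from its class name via `self._get_name()` (typically `GraphModule`), whereas the removed `_print_readable` used the attribute name the submodule is registered under (`wrap_body_0`, `fwd_body_0`, ...). That is exactly why the expected strings in the test hunks above flip back to `class GraphModule`. A small sketch of the naming difference (`_get_name` is the internal helper `nn.Module` uses for its repr):

```python
import torch
from torch.fx import Graph, GraphModule

def identity_graph() -> Graph:
    g = Graph()
    x = g.placeholder("x")
    g.output(x)
    return g

root = GraphModule(torch.nn.Module(), identity_graph())
child = GraphModule(torch.nn.Module(), identity_graph())
root.add_submodule("wrap_body_0", child)

# Name used by the restored recursion (this commit): the class name.
print(child._get_name())               # -> "GraphModule"
# Name used by the reverted _print_readable (#128617): the attribute name.
print(next(root.named_children())[0])  # -> "wrap_body_0"

# So with this commit, root.print_readable() shows the child as
# "class GraphModule(torch.nn.Module):" rather than "class wrap_body_0(...)".
```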