mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[export] Add print_readable to unflattener (#128617)"
This reverts commit 9fee87e4cd9efb55ee5427a8e6b3c57de7c599f9.
Reverted https://github.com/pytorch/pytorch/pull/128617 on behalf of https://github.com/clee2000 due to broke inductor/test_flex_attention https://github.com/pytorch/pytorch/actions/runs/9984688318/job/27595182606 433ef4e444
Not run on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/128617#issuecomment-2236867975))
This commit is contained in:
@ -525,14 +525,14 @@ class GraphModule(torch.nn.Module):
|
||||
autograd_function_apply: "f32[]" = torch._functorch.autograd_function.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
|
||||
return (autograd_function_apply,)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
|
||||
mul: "f32[]" = l_weird_b * l_weird_c
|
||||
clone: "f32[]" = x.clone(); x = None
|
||||
mul_1: "f32[]" = mul * clone; mul = clone = None
|
||||
return (mul_1, [l_weird_b, l_weird_c])
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False)
|
||||
|
||||
|
@ -332,7 +332,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem: "f32[]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"):
|
||||
sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None
|
||||
cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None
|
||||
@ -370,7 +370,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3, 1]"):
|
||||
view: "f32[3]" = l_x_.view(3); l_x_ = None
|
||||
add: "f32[3]" = view + 0.5; view = None
|
||||
@ -390,7 +390,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem: "f32[s0]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"):
|
||||
view: "f32[s0]" = l_x_.view(size); l_x_ = size = None
|
||||
add: "f32[s0]" = view + 0.5; view = None
|
||||
@ -1848,7 +1848,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem_1: "f32[3]" = wrap[1]; wrap = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
|
||||
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
|
||||
|
||||
@ -2047,7 +2047,7 @@ class GraphModule(torch.nn.Module):
|
||||
add: "f32[2, 3]" = a + b; a = b = None
|
||||
return (add,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[2, 3]"):
|
||||
child: "f32[2, 3]" = l_x_.sin()
|
||||
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
|
||||
@ -2082,7 +2082,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]"):
|
||||
child: "f32[3]" = -l_x_; l_x_ = None
|
||||
return (child,)
|
||||
|
@ -1093,7 +1093,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem: "f32[3, 4]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3, 4]"):
|
||||
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
|
||||
return (add_,)
|
||||
@ -1117,7 +1117,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
add_ = l_x_.add_(1.0); l_x_ = None
|
||||
return (add_,)
|
||||
@ -1147,7 +1147,7 @@ class GraphModule(torch.nn.Module):
|
||||
getitem = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
add_ = l_x_.add_(1.0); l_x_ = None
|
||||
return (add_,)
|
||||
|
@ -21,7 +21,6 @@ from torch.export.exported_program import (
|
||||
TensorArgument,
|
||||
)
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
from torch.fx.graph_module import _print_readable
|
||||
from torch.utils._pytree import GetAttrKey, SequenceKey
|
||||
|
||||
from ._remove_effect_tokens_pass import _remove_effect_tokens
|
||||
@ -134,22 +133,6 @@ class InterpreterModule(torch.nn.Module):
|
||||
if node.op == "placeholder":
|
||||
self.arg_names.append(node.target)
|
||||
|
||||
def print_readable(
|
||||
self,
|
||||
print_output=True,
|
||||
include_stride=False,
|
||||
include_device=False,
|
||||
colored=False,
|
||||
):
|
||||
return _print_readable(
|
||||
self,
|
||||
"InterpreterModule",
|
||||
print_output,
|
||||
include_stride,
|
||||
include_device,
|
||||
colored,
|
||||
)
|
||||
|
||||
|
||||
class FlatArgsAdapter(abc.ABC):
|
||||
"""
|
||||
@ -482,22 +465,6 @@ class UnflattenedModule(torch.nn.Module):
|
||||
)
|
||||
return pytree.tree_unflatten(tree_out, signature.out_spec)
|
||||
|
||||
def print_readable(
|
||||
self,
|
||||
print_output=True,
|
||||
include_stride=False,
|
||||
include_device=False,
|
||||
colored=False,
|
||||
):
|
||||
return _print_readable(
|
||||
self,
|
||||
"UnflattenedModule",
|
||||
print_output,
|
||||
include_stride,
|
||||
include_device,
|
||||
colored,
|
||||
)
|
||||
|
||||
|
||||
def unflatten(
|
||||
module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
|
||||
|
@ -257,51 +257,6 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
|
||||
setattr(to_module, field, from_obj)
|
||||
|
||||
|
||||
def _print_readable(
|
||||
module,
|
||||
module_name,
|
||||
print_output=True,
|
||||
include_stride=False,
|
||||
include_device=False,
|
||||
colored=False,
|
||||
):
|
||||
graph = module.graph
|
||||
assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph"
|
||||
|
||||
verbose_python_code = graph.python_code(
|
||||
root_module="self",
|
||||
verbose=True,
|
||||
include_stride=include_stride,
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
)
|
||||
module_code = verbose_python_code.src
|
||||
module_code = module_code.lstrip("\n")
|
||||
module_code = f"class {module_name}(torch.nn.Module):\n" + module_code
|
||||
module_code = _addindent(module_code, 4)
|
||||
|
||||
submodule_code_list = [""]
|
||||
for submodule_name, submodule in module.named_children():
|
||||
if hasattr(submodule, "graph"):
|
||||
submodule_code_list.append(
|
||||
_print_readable(
|
||||
submodule,
|
||||
submodule_name,
|
||||
print_output=False,
|
||||
include_stride=include_stride,
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
)
|
||||
)
|
||||
submodule_code = "\n".join(submodule_code_list)
|
||||
submodule_code = _addindent(submodule_code, 4)
|
||||
|
||||
output = module_code + submodule_code
|
||||
if print_output:
|
||||
print(module_code + submodule_code)
|
||||
return output
|
||||
|
||||
|
||||
class _WrappedCall:
|
||||
def __init__(self, cls, cls_call):
|
||||
self.cls = cls
|
||||
@ -870,14 +825,30 @@ class {module_name}(torch.nn.Module):
|
||||
"""
|
||||
Return the Python code generated for current GraphModule and its children GraphModules
|
||||
"""
|
||||
return _print_readable(
|
||||
self,
|
||||
self._get_name(),
|
||||
print_output,
|
||||
include_stride,
|
||||
include_device,
|
||||
colored,
|
||||
verbose_python_code = self._graph.python_code(
|
||||
root_module="self", verbose=True, include_stride=include_stride, include_device=include_device, colored=colored
|
||||
)
|
||||
module_code = verbose_python_code.src
|
||||
module_code = module_code.lstrip("\n")
|
||||
module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
|
||||
module_code = _addindent(module_code, 4)
|
||||
|
||||
submodule_code_list = [""]
|
||||
for submodule in self.children():
|
||||
if isinstance(submodule, GraphModule):
|
||||
submodule_code_list.append(submodule.print_readable(
|
||||
print_output=False,
|
||||
include_stride=include_stride,
|
||||
include_device=include_device,
|
||||
colored=colored
|
||||
))
|
||||
submodule_code = "\n".join(submodule_code_list)
|
||||
submodule_code = _addindent(submodule_code, 4)
|
||||
|
||||
output = module_code + submodule_code
|
||||
if print_output:
|
||||
print(module_code + submodule_code)
|
||||
return output
|
||||
|
||||
def __str__(self) -> str:
|
||||
orig_str = super().__str__()
|
||||
|
Reference in New Issue
Block a user