support condition branch in ao debug handler (#141516)

This diff introduced the supportive of condition statement into ao debug handler generation.

Most of code borrowed from ExecuTorch to avoid circle dependency issue.

Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516
Approved by: https://github.com/jerryzh168
This commit is contained in:
gasoonjia
2024-12-09 20:44:43 -08:00
committed by PyTorch MergeBot
parent 75530885ba
commit ff059587c6
4 changed files with 139 additions and 26 deletions

View File

@ -4,6 +4,8 @@ r"""Importing this file includes common utility methods and base clases for
checking quantization api and properties of resulting modules.
"""
from functorch.experimental import control_flow
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -2677,6 +2679,44 @@ class SparseNNModel(nn.Module):
return out
class TestHelperModules:
class ControlFlow(torch.nn.Module):
def forward(
self,
xs: torch.Tensor,
pred1: torch.Tensor,
pred2: torch.Tensor,
y: torch.Tensor,
) -> torch.Tensor:
def true_nested(y: torch.Tensor) -> torch.Tensor:
y = y + y
y = torch.mm(y, y)
return y
def false_nested(y: torch.Tensor) -> torch.Tensor:
return torch.mm(y, y)
def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
z = control_flow.cond(pred2, true_nested, false_nested, [x])
return x + z
def false_fn(x: torch.Tensor, _) -> torch.Tensor:
return x.cos()
def map_fn(
x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
) -> torch.Tensor:
x = x.cos()
y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
x = x + y
return x.sin()
y = torch.mm(y, y)
return control_flow.map(map_fn, xs, pred1, pred2, y)
def example_inputs(self):
return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),)
class Conv2dPropAnnotaton(torch.nn.Module):
def __init__(self) -> None:
super().__init__()