mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
support condition branch in ao debug handler (#141516)
This diff introduced the supportive of condition statement into ao debug handler generation. Most of code borrowed from ExecuTorch to avoid circle dependency issue. Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516 Approved by: https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
75530885ba
commit
ff059587c6
@ -4,6 +4,8 @@ r"""Importing this file includes common utility methods and base clases for
|
||||
checking quantization api and properties of resulting modules.
|
||||
"""
|
||||
|
||||
from functorch.experimental import control_flow
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -2677,6 +2679,44 @@ class SparseNNModel(nn.Module):
|
||||
return out
|
||||
|
||||
class TestHelperModules:
|
||||
class ControlFlow(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
pred1: torch.Tensor,
|
||||
pred2: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
def true_nested(y: torch.Tensor) -> torch.Tensor:
|
||||
y = y + y
|
||||
y = torch.mm(y, y)
|
||||
return y
|
||||
|
||||
def false_nested(y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.mm(y, y)
|
||||
|
||||
def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
|
||||
z = control_flow.cond(pred2, true_nested, false_nested, [x])
|
||||
return x + z
|
||||
|
||||
def false_fn(x: torch.Tensor, _) -> torch.Tensor:
|
||||
return x.cos()
|
||||
|
||||
def map_fn(
|
||||
x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
x = x.cos()
|
||||
y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
|
||||
x = x + y
|
||||
return x.sin()
|
||||
|
||||
y = torch.mm(y, y)
|
||||
return control_flow.map(map_fn, xs, pred1, pred2, y)
|
||||
|
||||
def example_inputs(self):
|
||||
return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),)
|
||||
|
||||
class Conv2dPropAnnotaton(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
Reference in New Issue
Block a user