Compare commits

...

5 Commits

SHA1 Message Date
1d12665c70 CI with mkldnn op back 2025-08-22 09:43:55 -07:00
28d71df43f Pass the test with disable higher mode 2025-08-21 16:20:18 -07:00
14577d73cd improving support 2025-08-21 15:44:55 -07:00
afb84da3e3 Fix equality 2025-08-21 11:58:04 -07:00
7316a05024 Support Decomp 2025-08-19 14:51:30 -07:00
4 changed files with 122 additions and 10 deletions

test/test_decomp.py

@@ -675,11 +675,20 @@ class TestDecomp(TestCase):
            )
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            print(m)
            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )
            for arg in args:
                if type(arg) is tuple:
                    for idx, subarg in enumerate(arg):
                        print(f"subarg {idx} is {subarg} with shape {subarg.shape}")
                else:
                    print(f"arg is: {arg} and shape: {arg.shape}")
            print(kwargs)
            non_decomp_out = m(*args, **kwargs)
            print(non_decomp_out)
            with (
                self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all=True
@@ -688,10 +697,10 @@ class TestDecomp(TestCase):
            ):
                decomp_out = m(*args, **kwargs)

            non_decomp_out = m(*args, **kwargs)
            # without this check, incorrect decomps at the python dispatcher level can still pass because
            # they're checking aten decomps at the torch_dispatch level
            self.assertEqual(decomp_out, non_decomp_out)
            print("Passed this first one!")

    def test_batch_norm_unflatten_weight_bias(self, device):
        # https://github.com/pytorch/pytorch/issues/100970
@@ -778,6 +787,7 @@ def forward(self, scores_1, mask_1, value_1):
    class DecompCrossRefMode(TorchDispatchMode):
        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
            self.supports_higher_order_operators = True
            self.test_case = test_case
            self.saved_precision = saved_precision
            self.saved_rel_tol = saved_rel_tol
@@ -972,7 +982,11 @@ def forward(self, scores_1, mask_1, value_1):
            # for each region
            with (
                self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                    self,
                    self.precision,
                    self.rel_tol,
                    dtype,
                    run_all,
                ) as mode,
                enable_python_dispatcher(),
            ):
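The one-line addition in DecompCrossRefMode.__init__ above, self.supports_higher_order_operators = True, opts this TorchDispatchMode into intercepting higher-order operators such as while_loop, which the new LSTM decomposition emits; without the flag, running a higher-order op under the mode errors out. A minimal standalone sketch of the same opt-in (LoggingMode is illustrative, not part of this diff):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    # Opt in so higher-order ops (e.g. while_loop) are routed to
    # __torch_dispatch__ instead of being rejected under the mode.
    supports_higher_order_operators = True

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # log every op that passes through the mode
        return func(*args, **(kwargs or {}))

with LoggingMode():
    torch.randn(2).sin()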

test_lstm.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import torch


@torch._dynamo.config.patch(allow_rnn=True)
def main():
    class LSTM(torch.nn.Module):
        def __init__(self, h):
            super().__init__()
            self.lstm = torch.nn.LSTM(2, 3, num_layers=2, bias=False)

        def forward(self, x, h0, c0):
            out, (_, _) = self.lstm(x, (h0, c0))
            return out

    torch._dynamo.config.recompile_limit = 1
    device = "cuda"
    lstm = LSTM(2).to(device)
    comp_lstm = torch.compile(lstm)

    x = torch.rand(3, 2).to(device)
    torch._dynamo.mark_dynamic(x, 0)
    h_0 = torch.rand(2, 3).to(device)
    c_0 = torch.rand(2, 3).to(device)

    base = lstm(x, h_0, c_0)
    output = comp_lstm(x, h_0, c_0)
    assert torch.all(torch.isclose(base, output, rtol=1e-7, atol=1e-7)), (
        f"Failed: base size {base.size()} vs output size {output.size()}; "
        + f"base checksum: {base.sum().item()} vs output sum: {output.sum().item()}"
    )
    print("Passed, yay!")

    x2 = torch.rand(8, 2).to(device)  # now try with different seqlen
    assert torch.all(
        torch.isclose(lstm(x2, h_0, c_0), comp_lstm(x2, h_0, c_0), rtol=1e-7, atol=1e-7)
    )


if __name__ == "__main__":
    main()
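A note on the repro above: torch._dynamo.mark_dynamic(x, 0) asks the compiler to treat the sequence dimension symbolically, and recompile_limit = 1 leaves no budget for a second specialization, so the final assert (seqlen 8 vs. 3) only passes if a single compiled graph covers both lengths. Running the script with TORCH_LOGS=recompiles is a quick way to confirm no recompile occurs.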

torch/_decomp/decompositions.py

@@ -3311,7 +3311,9 @@ def _rnn_helper(
            params, hidden, i, bidirectional
        )
        dropout = dropout if (train and num_layers < i - 1) else 0.0
        fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases)
        fwd_inp, fwd_hidden = layer_fn(
            input.clone(), cur_hidden, cur_params, has_biases
        )
        final_hiddens.append(fwd_hidden)

        if bidirectional:
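The input.clone() added to _rnn_helper above matches the weight clones in the new layer function below: the while_loop rewrite presumably needs its operands not to alias graph inputs or outputs, so each layer feeds the loop a fresh tensor.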
@@ -3478,6 +3480,58 @@ def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
    return hy, cy


def one_layer_while_loop_lstm(inp, hidden, params, has_biases, reverse=False):
    """
    One-layer forward for the while_loop-based LSTM decomposition.
    """
    from torch._higher_order_ops import while_loop

    ih_weight = params[0].clone()  # Clone to avoid aliasing
    hh_weight = params[1].clone()  # Clone to avoid aliasing
    ih_bias = params[2] if has_biases else None
    hh_bias = params[3] if has_biases else None
    hr_weight = (
        params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
    )

    hx = hidden[0].unsqueeze(0)
    cx = hidden[1].unsqueeze(0)

    precomputed_input = F.linear(inp, ih_weight, ih_bias)
    precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input

    # while loop rewrite
    step_output = torch.empty(
        precomputed_input.size(0),
        *tuple(hx.shape[1:]),
        dtype=hx.dtype,
        device=hx.device,
    )

    def cond_fn(i, out, hx, cx):
        return i < precomputed_input.size(0)

    def body_fn(idx, out, hx, cx):
        i = idx.item()
        torch._check_is_size(i)
        torch._check_is_size(i, max=precomputed_input.size(0) - 1)
        hx, cx = lstm_cell(
            precomputed_input.select(0, idx),
            hx,
            cx,
            hh_weight,
            hh_bias,
            hr_weight,
            chunk_dim=2,
        )
        out = out.clone()  # this clone to avoid aliasing is annoying
        out[i] = hx
        return idx + 1, out, hx, cx

    cnt = torch.full((), 0, dtype=torch.int64, device=step_output.device)
    _, out, hx, cx = while_loop(cond_fn, body_fn, [cnt, step_output, hx, cx])

    if reverse:
        out = out.flip(0)
    return out, (hx.squeeze(1), cx.squeeze(1))


def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
    ih_weight = params[0]
    hh_weight = params[1]
@@ -3616,7 +3670,7 @@ def select_one_layer_lstm_function(input, hx, params):
    if use_mkldnn(input, hx, params):
        return mkldnn_one_layer_lstm
    else:
        return one_layer_lstm
        return one_layer_while_loop_lstm
@register_decomposition(aten.lstm.input)
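
For readers unfamiliar with the higher-order op this decomposition is built on: while_loop(cond_fn, body_fn, carried_inputs) repeatedly applies body_fn to the carried tensors for as long as cond_fn returns a true scalar, and body_fn must return tensors matching the carry's metadata. A minimal standalone sketch (not part of this diff; it sums 0..4):

import torch
from torch._higher_order_ops import while_loop

def cond_fn(i, acc):
    # Must return a scalar bool tensor.
    return i < 5

def body_fn(i, acc):
    # Must return new tensors matching the carry's shape/dtype.
    return i + 1, acc + i.to(acc.dtype)

i0 = torch.zeros((), dtype=torch.int64)
acc0 = torch.zeros(())
i, acc = while_loop(cond_fn, body_fn, (i0, acc0))
print(i.item(), acc.item())  # 5 10.0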

torch/_higher_order_ops/utils.py

@@ -148,9 +148,11 @@ def _maybe_reenter_make_fx(fn):
        fake_mode = detect_fake_mode(args)
        if fake_mode is None:
            # we creaeta a fake_mode here to make sure we could
            # we create a fake_mode here to make sure we could
            # trace the graph with data-dependent calls e.g. .item()
            return make_fx(fn, tracing_mode="fake")(*args)
            return make_fx(
                fn, tracing_mode="fake", _allow_non_fake_inputs=True
            )(*args)
        # Tracing with real if all inputs have been fakefied
        return make_fx(fn)(*args)
@@ -1136,8 +1138,10 @@ def materialize_as_graph(
    @torch._dynamo.disable(recursive=True, reason=None)
    def _materialize_as_graph_inner():
        with suspend_functionalization(), disable_functional_mode():
            with disable_proxy_modes_tracing():
        with (
            suspend_functionalization(),
            torch.utils._python_dispatch._disable_current_modes(),
        ):
            unfunc_t = [_from_fun(arg) for arg in args]
        with contextlib.ExitStack() as stack:
            stack.enter_context(