Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49980. Generated by running `./python/libcst/libcst codemod remove_unused_imports.RemoveUnusedImportsWithGlean --no-format caffe2/`. Test Plan: standard sandcastle tests. Reviewed By: xush6528. Differential Revision: D25727359. fbshipit-source-id: c4f60005b10546423dc093d31d46deb418352286
73 lines · 2.1 KiB · Python
#!/usr/bin/env python3
|
|
|
|
import hypothesis.strategies as st
|
|
import numpy as np
|
|
import torch
|
|
from caffe2.python import core
|
|
from caffe2.python.test_util import TestCase
|
|
from hypothesis import given, settings
|
|
from torch import nn
|
|
|
|
|
|
class TestC2LSTM(TestCase):
    """Check caffe2's ``InferenceLSTM`` op against ``torch.nn.LSTM``."""

    @given(
        bsz=st.integers(1, 5),
        seq_lens=st.integers(1, 6),
        emb_lens=st.integers(5, 10),
        hidden_size=st.integers(3, 7),
        num_layers=st.integers(1, 4),
        has_biases=st.booleans(),
        is_bidirectional=st.booleans(),
        batch_first=st.booleans(),
    )
    @settings(deadline=10000)
    def test_c2_lstm(
        self,
        bsz,
        seq_lens,
        emb_lens,
        hidden_size,
        num_layers,
        has_biases,
        is_bidirectional,
        batch_first,
    ):
        """Feed identical weights, inputs and zero state to both LSTMs; compare.

        Args (all drawn by hypothesis):
            bsz: batch size.
            seq_lens: sequence length.
            emb_lens: input feature size of the LSTM.
            hidden_size: hidden-state size of the LSTM.
            num_layers: number of stacked LSTM layers.
            has_biases: passed as ``bias=`` to ``nn.LSTM`` and to the op.
            is_bidirectional: passed as ``bidirectional=`` to ``nn.LSTM``
                and to the op.
            batch_first: selects the input layout for both implementations.
        """
        # Constructed but never read below; kept so test setup is unchanged.
        _unused_net = core.Net("test_net")

        reference = nn.LSTM(
            emb_lens,
            hidden_size,
            batch_first=batch_first,
            bidirectional=is_bidirectional,
            bias=has_biases,
            num_layers=num_layers,
        )

        # Leading dimension holds one state slice per layer per direction.
        directions = 2 if is_bidirectional else 1
        zero_state = np.zeros(
            (num_layers * directions, bsz, hidden_size), dtype=np.float32
        )

        # The input layout follows ``batch_first``.
        if batch_first:
            input_shape = (bsz, seq_lens, emb_lens)
        else:
            input_shape = (seq_lens, bsz, emb_lens)
        inputs = np.random.randn(*input_shape).astype(np.float32)

        # No initial state is passed here, so nn.LSTM starts from zeros: the
        # same state that is handed explicitly to the caffe2 op below.
        ref_out, (ref_h, ref_c) = reference(torch.from_numpy(inputs))

        # Op input list is [inputs, h0, c0, *weights].
        # NOTE(review): relies on the private ``_flat_weights`` ordering matching
        # what InferenceLSTM expects; confirm if nn.LSTM internals change.
        op_inputs = [
            torch.from_numpy(inputs),
            torch.from_numpy(zero_state),
            torch.from_numpy(zero_state),
            *[weight.detach() for weight in reference._flat_weights],
        ]
        op_results = torch.ops._caffe2.InferenceLSTM(
            op_inputs, num_layers, has_biases, batch_first, is_bidirectional
        )

        # Compare output, final hidden state and final cell state, in order.
        for idx, expected in enumerate((ref_out, ref_h, ref_c)):
            np.testing.assert_array_almost_equal(
                expected.detach().numpy(), op_results[idx].detach().numpy()
            )