mirror of
https://github.com/ZhangXinNan/DL-with-Python-and-PyTorch2.git
synced 2025-10-21 07:43:46 +08:00
552 lines
103 KiB
Plaintext
552 lines
103 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 7.8 用LSTM预测股票行情\n",
|
||
"这里采用沪深300指数数据,时间跨度为2010-01-01至今,选取每天的最高价。假设当天最高价依赖于此前n天(如30天)的沪深300最高价。用LSTM模型来捕捉最高价的时序信息,通过训练模型,使之学会用前n天的最高价预测当天的最高价(当天最高价作为训练的标签值)。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 7.8.1 导入数据\n",
|
||
"这里使用tushare来下载沪深300指数数据。可以用pip 安装tushare。\n",
|
||
"### <font color = blue>注意:目前下载沪深300指数数据需要先注册并拥有一定积分。如果你无法下载数据,可以跳过这一步,直接使用我们已下载的数据文件sh300.csv</font>"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import tushare as ts "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"cons = ts.get_apis()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#获取沪深300指数(000300)的信息,包括交易日期(datetime)、开盘价(open)、收盘价(close),\n",
|
||
"#最高价(high)、最低价(low)、成交量(vol)、成交金额(amount)、涨跌幅(p_change)\n",
|
||
"df = ts.bar('000300', conn=cons, asset='INDEX', start_date='2010-01-01', end_date='')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"df = df.dropna()\n",
|
||
"df.to_csv('sh300.csv')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Index(['code', 'open', 'close', 'high', 'low', 'vol', 'amount', 'p_change'], dtype='object')"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"df.columns"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 7.8.2 数据概览"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>open</th>\n",
|
||
" <th>close</th>\n",
|
||
" <th>high</th>\n",
|
||
" <th>low</th>\n",
|
||
" <th>vol</th>\n",
|
||
" <th>amount</th>\n",
|
||
" <th>p_change</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>count</th>\n",
|
||
" <td>2295.000000</td>\n",
|
||
" <td>2295.000000</td>\n",
|
||
" <td>2295.000000</td>\n",
|
||
" <td>2295.000000</td>\n",
|
||
" <td>2.295000e+03</td>\n",
|
||
" <td>2.295000e+03</td>\n",
|
||
" <td>2295.000000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>mean</th>\n",
|
||
" <td>3100.514637</td>\n",
|
||
" <td>3103.181503</td>\n",
|
||
" <td>3128.213684</td>\n",
|
||
" <td>3073.658757</td>\n",
|
||
" <td>1.090221e+06</td>\n",
|
||
" <td>1.296155e+11</td>\n",
|
||
" <td>0.012397</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>std</th>\n",
|
||
" <td>627.888776</td>\n",
|
||
" <td>628.060844</td>\n",
|
||
" <td>634.870454</td>\n",
|
||
" <td>618.306225</td>\n",
|
||
" <td>9.284048e+05</td>\n",
|
||
" <td>1.268450e+11</td>\n",
|
||
" <td>1.483714</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>min</th>\n",
|
||
" <td>2079.870000</td>\n",
|
||
" <td>2086.970000</td>\n",
|
||
" <td>2118.790000</td>\n",
|
||
" <td>2023.170000</td>\n",
|
||
" <td>2.190120e+05</td>\n",
|
||
" <td>2.120044e+10</td>\n",
|
||
" <td>-8.750000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>25%</th>\n",
|
||
" <td>2534.185000</td>\n",
|
||
" <td>2534.185000</td>\n",
|
||
" <td>2558.015000</td>\n",
|
||
" <td>2514.585000</td>\n",
|
||
" <td>5.634255e+05</td>\n",
|
||
" <td>6.092613e+10</td>\n",
|
||
" <td>-0.650000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>50%</th>\n",
|
||
" <td>3160.800000</td>\n",
|
||
" <td>3165.910000</td>\n",
|
||
" <td>3193.820000</td>\n",
|
||
" <td>3134.380000</td>\n",
|
||
" <td>8.055400e+05</td>\n",
|
||
" <td>9.127102e+10</td>\n",
|
||
" <td>0.020000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>75%</th>\n",
|
||
" <td>3484.665000</td>\n",
|
||
" <td>3486.080000</td>\n",
|
||
" <td>3510.940000</td>\n",
|
||
" <td>3461.215000</td>\n",
|
||
" <td>1.202926e+06</td>\n",
|
||
" <td>1.382787e+11</td>\n",
|
||
" <td>0.705000</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>max</th>\n",
|
||
" <td>5379.470000</td>\n",
|
||
" <td>5353.750000</td>\n",
|
||
" <td>5380.430000</td>\n",
|
||
" <td>5283.090000</td>\n",
|
||
" <td>6.864391e+06</td>\n",
|
||
" <td>9.494980e+11</td>\n",
|
||
" <td>6.710000</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" open close high low vol \\\n",
|
||
"count 2295.000000 2295.000000 2295.000000 2295.000000 2.295000e+03 \n",
|
||
"mean 3100.514637 3103.181503 3128.213684 3073.658757 1.090221e+06 \n",
|
||
"std 627.888776 628.060844 634.870454 618.306225 9.284048e+05 \n",
|
||
"min 2079.870000 2086.970000 2118.790000 2023.170000 2.190120e+05 \n",
|
||
"25% 2534.185000 2534.185000 2558.015000 2514.585000 5.634255e+05 \n",
|
||
"50% 3160.800000 3165.910000 3193.820000 3134.380000 8.055400e+05 \n",
|
||
"75% 3484.665000 3486.080000 3510.940000 3461.215000 1.202926e+06 \n",
|
||
"max 5379.470000 5353.750000 5380.430000 5283.090000 6.864391e+06 \n",
|
||
"\n",
|
||
" amount p_change \n",
|
||
"count 2.295000e+03 2295.000000 \n",
|
||
"mean 1.296155e+11 0.012397 \n",
|
||
"std 1.268450e+11 1.483714 \n",
|
||
"min 2.120044e+10 -8.750000 \n",
|
||
"25% 6.092613e+10 -0.650000 \n",
|
||
"50% 9.127102e+10 0.020000 \n",
|
||
"75% 1.382787e+11 0.705000 \n",
|
||
"max 9.494980e+11 6.710000 "
|
||
]
|
||
},
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"df.describe()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 7.8.3 预处理数据"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import datetime\n",
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"import numpy as np\n",
|
||
"from torch.utils.data import Dataset, DataLoader\n",
|
||
"import torchvision\n",
|
||
"import torchvision.transforms as transforms\n",
|
||
"\n",
|
||
"%matplotlib inline"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"n = 30\n",
|
||
"LR = 0.001\n",
|
||
"EPOCH = 200\n",
|
||
"batch_size=20\n",
|
||
"train_end =-600\n",
|
||
"\n",
|
||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#由一个长度为N的序列生成一个(N-n)行、(n+1)列的矩阵(用于处理时序数据),n=30时为31列\n",
|
||
"#其中最后一列为标签数据。即把当天之前n天的数据作为输入特征,当天的数据作为label\n",
|
||
"def generate_data_by_n_days(series, n, index=False):\n",
|
||
" if len(series) <= n:\n",
|
||
" raise Exception(\"The Length of series is %d, while affect by (n=%d).\" % (len(series), n))\n",
|
||
" df = pd.DataFrame()\n",
|
||
" for i in range(n):\n",
|
||
" df['c%d' % i] = series.tolist()[i:-(n - i)] \n",
|
||
" df['y'] = series.tolist()[n:]\n",
|
||
" \n",
|
||
" if index:\n",
|
||
" df.index = series.index[n:]\n",
|
||
" return df\n",
|
||
"\n",
|
||
"#参数n与上相同。train_end为负数索引,表示最后|train_end|条数据留作测试集,其余作为训练集。\n",
|
||
"def readData(column='high', n=30, all_too=True, index=False, train_end=-500):\n",
|
||
" df = pd.read_csv(\"sh300.csv\", index_col=0)\n",
|
||
" #以日期为索引\n",
|
||
" df.index = list(map(lambda x: datetime.datetime.strptime(x, \"%Y-%m-%d\"), df.index))\n",
|
||
" #获取每天的最高价\n",
|
||
" df_column = df[column].copy()\n",
|
||
" #拆分为训练集和测试集\n",
|
||
" df_column_train, df_column_test = df_column[:train_end], df_column[train_end - n:]\n",
|
||
" #生成训练数据\n",
|
||
" df_generate_train = generate_data_by_n_days(df_column_train, n, index=index)\n",
|
||
" if all_too:\n",
|
||
" return df_generate_train, df_column, df.index.tolist()\n",
|
||
" return df_generate_train\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 7.8.4 定义模型"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"class RNN(nn.Module):\n",
|
||
" def __init__(self, input_size):\n",
|
||
" super(RNN, self).__init__()\n",
|
||
" self.rnn = nn.LSTM(\n",
|
||
" input_size=input_size,\n",
|
||
" hidden_size=64,\n",
|
||
" num_layers=1,\n",
|
||
" batch_first=True\n",
|
||
" )\n",
|
||
" self.out = nn.Sequential(\n",
|
||
" nn.Linear(64, 1)\n",
|
||
" )\n",
|
||
"\n",
|
||
" def forward(self, x):\n",
|
||
" r_out, (h_n, h_c) = self.rnn(x, None) #None即隐层状态用0初始化\n",
|
||
" out = self.out(r_out)\n",
|
||
" return out\n",
|
||
"\n",
|
||
"\n",
|
||
"class mytrainset(Dataset):\n",
|
||
" def __init__(self, data): \n",
|
||
" self.data, self.label = data[:, :-1].float(), data[:, -1].float()\n",
|
||
" \n",
|
||
" def __getitem__(self, index):\n",
|
||
" return self.data[index], self.label[index]\n",
|
||
"\n",
|
||
" def __len__(self):\n",
|
||
" return len(self.data)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 7.8.5 训练模型"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "\n",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from pandas.plotting import register_matplotlib_converters\n",
|
||
"register_matplotlib_converters()\n",
|
||
"# 获取训练数据、原始数据、索引等信息\n",
|
||
"df, df_all, df_index = readData('high', n=n, train_end=train_end)\n",
|
||
"\n",
|
||
"#可视化原始最高价数据\n",
|
||
"df_all = np.array(df_all.tolist())\n",
|
||
"plt.plot(df_index, df_all, label='real-data')\n",
|
||
"plt.legend(loc='upper right') \n",
|
||
"\n",
|
||
"\n",
|
||
"#对数据进行预处理,规范化及转换为Tensor\n",
|
||
"df_numpy = np.array(df)\n",
|
||
"\n",
|
||
"df_numpy_mean = np.mean(df_numpy)\n",
|
||
"df_numpy_std = np.std(df_numpy)\n",
|
||
"\n",
|
||
"df_numpy = (df_numpy - df_numpy_mean) / df_numpy_std\n",
|
||
"df_tensor = torch.Tensor(df_numpy)\n",
|
||
"\n",
|
||
"\n",
|
||
"trainset = mytrainset(df_tensor)\n",
|
||
"trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"#记录损失值,并用tensorboardX在web上展示\n",
|
||
"from tensorboardX import SummaryWriter\n",
|
||
"writer = SummaryWriter(log_dir='logs')\n",
|
||
"\n",
|
||
"rnn = RNN(n).to(device)\n",
|
||
"optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) \n",
|
||
"loss_func = nn.MSELoss()\n",
|
||
"\n",
|
||
"for step in range(EPOCH):\n",
|
||
" for tx, ty in trainloader:\n",
|
||
" tx=tx.to(device)\n",
|
||
" ty=ty.to(device)\n",
|
||
" #在第1个维度上添加一个维度为1的维度,形状变为[batch,seq_len,input_size]\n",
|
||
" output = rnn(torch.unsqueeze(tx, dim=1)).to(device)\n",
|
||
" loss = loss_func(torch.squeeze(output), ty)\n",
|
||
" optimizer.zero_grad() \n",
|
||
" loss.backward() \n",
|
||
" optimizer.step()\n",
|
||
" writer.add_scalar('sh300_loss', loss, step) "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 7.8.6 测试模型"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "\n",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"generate_data_train = []\n",
|
||
"generate_data_test = []\n",
|
||
"\n",
|
||
"test_index = len(df_all) + train_end\n",
|
||
"\n",
|
||
"df_all_normal = (df_all - df_numpy_mean) / df_numpy_std\n",
|
||
"df_all_normal_tensor = torch.Tensor(df_all_normal)\n",
|
||
"for i in range(n, len(df_all)):\n",
|
||
" x = df_all_normal_tensor[i - n:i].to(device)\n",
|
||
" #rnn的输入必须是3维,故需添加两个1维的维度,最后成为[1,1,input_size]\n",
|
||
" x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=0)\n",
|
||
" \n",
|
||
" y = rnn(x).to(device)\n",
|
||
" if i < test_index:\n",
|
||
" generate_data_train.append(torch.squeeze(y).detach().cpu().numpy() * df_numpy_std + df_numpy_mean)\n",
|
||
" else:\n",
|
||
" generate_data_test.append(torch.squeeze(y).detach().cpu().numpy() * df_numpy_std + df_numpy_mean)\n",
|
||
"plt.plot(df_index[n:train_end], generate_data_train, label='generate_train')\n",
|
||
"plt.plot(df_index[train_end:], generate_data_test, label='generate_test')\n",
|
||
"plt.plot(df_index[train_end:], df_all[train_end:], label='real-data')\n",
|
||
"plt.legend()\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "\n",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.clf()\n",
|
||
"plt.plot(df_index[train_end:-500], df_all[train_end:-500], label='real-data')\n",
|
||
"plt.plot(df_index[train_end:-500], generate_data_test[-600:-500], label='generate_test')\n",
|
||
"plt.legend()\n",
|
||
"plt.show()"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.7.4"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|