mirror of
https://github.com/ZhangXinNan/DL-with-Python-and-PyTorch2.git
synced 2025-10-20 23:34:18 +08:00
550 lines
155 KiB
Plaintext
550 lines
155 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 12.4 微调实例\n",
|
||
"\t微调允许修改预先训练好的网络参数来学习目标任务,因此微调的训练时间要比特征提取的训练时间长,但精度更高。微调的大致过程是在预先训练过的网络上添加新的随机初始化层,同时也会更新预先训练的网络参数,但会使用较小的学习率以防止预先训练好的参数发生较大改变。\n",
|
||
"\t常用的方法是固定底层的参数,调整一些顶层或具体层的参数。这样做的好处是可以减少训练参数的数量,也可以克服过拟合现象的发生。尤其是在目标任务的数据量不足够大时,该方法的实践效果更好。实际上,微调要优于特征提取,因为它能够对迁移过来的预训练网络参数进行优化,使其更加适合新的任务。\n",
|
||
"### 12.4.1 数据预处理\n",
|
||
"\t这里对训练数据添加了几种数据增强方法,如图像裁剪、旋转、颜色改变等方法。测试数据与特征提取的测试数据一样,没有变化。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"import torchvision\n",
|
||
"import torchvision.transforms as transforms\n",
|
||
"from torchvision import models\n",
|
||
"from torchvision.datasets import ImageFolder\n",
|
||
"from datetime import datetime"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"trans_train = transforms.Compose(\n",
|
||
" [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),\n",
|
||
" transforms.RandomRotation(degrees=15),\n",
|
||
" transforms.ColorJitter(),\n",
|
||
" transforms.RandomResizedCrop(224),\n",
|
||
" transforms.RandomHorizontalFlip(), \n",
|
||
" transforms.ToTensor(),\n",
|
||
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
|
||
" std=[0.229, 0.224, 0.225])])\n",
|
||
"\n",
|
||
"trans_valid = transforms.Compose(\n",
|
||
" [transforms.Resize(256),\n",
|
||
" transforms.CenterCrop(224),\n",
|
||
" transforms.ToTensor(),\n",
|
||
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
|
||
" std=[0.229, 0.224, 0.225])])\n",
|
||
"\n",
|
||
"trainset = torchvision.datasets.CIFAR10(root='../data', train=True,\n",
|
||
" download=False, transform=trans_train)\n",
|
||
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,\n",
|
||
" shuffle=True, num_workers=2)\n",
|
||
"\n",
|
||
"testset = torchvision.datasets.CIFAR10(root='../data', train=False,\n",
|
||
" download=False, transform=trans_valid)\n",
|
||
"testloader = torch.utils.data.DataLoader(testset, batch_size=64,\n",
|
||
" shuffle=False, num_workers=2)\n",
|
||
"\n",
|
||
"classes = ('plane', 'car', 'bird', 'cat',\n",
|
||
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "\n",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" car deer ship plane\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"%matplotlib inline\n",
|
||
"\n",
|
||
"# 显示图像\n",
|
||
"\n",
|
||
"def imshow(img):\n",
|
||
" img = img / 2 + 0.5 # unnormalize\n",
|
||
" npimg = img.numpy()\n",
|
||
" plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
|
||
" plt.show()\n",
|
||
"\n",
|
||
"\n",
|
||
"# 随机获取部分训练数据\n",
|
||
"dataiter = iter(trainloader)\n",
|
||
"images, labels = dataiter.next()\n",
|
||
"\n",
|
||
"\n",
|
||
"# 显示图像\n",
|
||
"imshow(torchvision.utils.make_grid(images))\n",
|
||
"# 打印标签\n",
|
||
"print(' '.join('%5s' % classes[labels[j]] for j in range(4)))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"### 12.4.2 加载预训练模型"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 使用预训练的模型\n",
|
||
"net = models.resnet18(pretrained=True)\n",
|
||
"#print(net)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"ResNet(\n",
|
||
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
||
" (layer1): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (layer2): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (downsample): Sequential(\n",
|
||
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (layer3): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (downsample): Sequential(\n",
|
||
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (layer4): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (downsample): Sequential(\n",
|
||
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
|
||
" (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
|
||
")\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(net)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 12.4.3 修改分类器\n",
|
||
"\t修改最后全连接层,把类别数由原来的1000改为10。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"ResNet(\n",
|
||
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
||
" (layer1): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (layer2): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (downsample): Sequential(\n",
|
||
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (layer3): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (downsample): Sequential(\n",
|
||
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (layer4): Sequential(\n",
|
||
" (0): BasicBlock(\n",
|
||
" (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (downsample): Sequential(\n",
|
||
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (1): BasicBlock(\n",
|
||
" (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" (relu): ReLU(inplace=True)\n",
|
||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
|
||
" (fc): Linear(in_features=512, out_features=10, bias=True)\n",
|
||
")"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# 将最后的全连接层改成十分类\n",
|
||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
"net.fc = nn.Linear(512, 10)\n",
|
||
"#net = torch.nn.DataParallel(net)\n",
|
||
"net.to(device)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"torch.cuda.FloatTensor\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 打出第一层的权重\n",
|
||
"print(net.conv1.weight.type())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"### 12.4.4 选择损失函数及优化器\n",
|
||
"\t这里学习率为1e-3。使用微调方法训练模型时,一般选择一个稍大一点的学习率,如果选择的学习率太小,效果要差一些;但学习率也不宜过大,以免预训练好的参数发生较大改变。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def get_acc(output, label):\n",
|
||
" total = output.shape[0]\n",
|
||
" _, pred_label = output.max(1)\n",
|
||
" num_correct = (pred_label == label).sum().item()\n",
|
||
" return num_correct / total"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"criterion = nn.CrossEntropyLoss()\n",
|
||
"optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, weight_decay=1e-3,momentum=0.9)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"### 12.4.5 训练及验证模型\n",
|
||
"\t训练及验证模型的代码如下:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def train(net, train_data, valid_data, num_epochs, optimizer, criterion):\n",
|
||
" \n",
|
||
" prev_time = datetime.now()\n",
|
||
" for epoch in range(num_epochs):\n",
|
||
" train_loss = 0\n",
|
||
" train_acc = 0\n",
|
||
" net = net.train()\n",
|
||
" for im, label in train_data:\n",
|
||
" im = im.to(device) # (bs, 3, h, w)\n",
|
||
" label = label.to(device) # (bs, h, w)\n",
|
||
" # forward\n",
|
||
" output = net(im)\n",
|
||
" loss = criterion(output, label)\n",
|
||
" # backward\n",
|
||
" optimizer.zero_grad()\n",
|
||
" loss.backward()\n",
|
||
" optimizer.step()\n",
|
||
"\n",
|
||
" train_loss += loss.item()\n",
|
||
" train_acc += get_acc(output, label)\n",
|
||
"\n",
|
||
" cur_time = datetime.now()\n",
|
||
" h, remainder = divmod((cur_time - prev_time).seconds, 3600)\n",
|
||
" m, s = divmod(remainder, 60)\n",
|
||
" time_str = \"Time %02d:%02d:%02d\" % (h, m, s)\n",
|
||
" if valid_data is not None:\n",
|
||
" valid_loss = 0\n",
|
||
" valid_acc = 0\n",
|
||
" net = net.eval()\n",
|
||
" for im, label in valid_data:\n",
|
||
" im = im.to(device) # (bs, 3, h, w)\n",
|
||
" label = label.to(device) # (bs, h, w)\n",
|
||
" output = net(im)\n",
|
||
" loss = criterion(output, label)\n",
|
||
" valid_loss += loss.item()\n",
|
||
" valid_acc += get_acc(output, label)\n",
|
||
" epoch_str = (\n",
|
||
" \"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, \"\n",
|
||
" % (epoch, train_loss / len(train_data),\n",
|
||
" train_acc / len(train_data), valid_loss / len(valid_data),\n",
|
||
" valid_acc / len(valid_data)))\n",
|
||
" else:\n",
|
||
" epoch_str = (\"Epoch %d. Train Loss: %f, Train Acc: %f, \" %\n",
|
||
" (epoch, train_loss / len(train_data),\n",
|
||
" train_acc / len(train_data)))\n",
|
||
" prev_time = cur_time\n",
|
||
" print(epoch_str + time_str)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 0. Train Loss: 1.046640, Train Acc: 0.634691, Valid Loss: 0.355521, Valid Acc: 0.880275, Time 00:02:12\n",
|
||
"Epoch 1. Train Loss: 0.689994, Train Acc: 0.758492, Valid Loss: 0.280970, Valid Acc: 0.904757, Time 00:02:25\n",
|
||
"Epoch 2. Train Loss: 0.621406, Train Acc: 0.781630, Valid Loss: 0.248424, Valid Acc: 0.913018, Time 00:02:27\n",
|
||
"Epoch 3. Train Loss: 0.571020, Train Acc: 0.801570, Valid Loss: 0.220384, Valid Acc: 0.925458, Time 00:02:29\n",
|
||
"Epoch 4. Train Loss: 0.540538, Train Acc: 0.812580, Valid Loss: 0.192789, Valid Acc: 0.936007, Time 00:02:29\n",
|
||
"Epoch 5. Train Loss: 0.511390, Train Acc: 0.822530, Valid Loss: 0.182821, Valid Acc: 0.936206, Time 00:02:27\n",
|
||
"Epoch 6. Train Loss: 0.490604, Train Acc: 0.829204, Valid Loss: 0.182214, Valid Acc: 0.936505, Time 00:02:30\n",
|
||
"Epoch 7. Train Loss: 0.480625, Train Acc: 0.831142, Valid Loss: 0.173486, Valid Acc: 0.937799, Time 00:02:30\n",
|
||
"Epoch 8. Train Loss: 0.468967, Train Acc: 0.838175, Valid Loss: 0.167395, Valid Acc: 0.941481, Time 00:02:28\n",
|
||
"Epoch 9. Train Loss: 0.454750, Train Acc: 0.842311, Valid Loss: 0.167120, Valid Acc: 0.943073, Time 00:02:26\n",
|
||
"Epoch 10. Train Loss: 0.444690, Train Acc: 0.845368, Valid Loss: 0.168414, Valid Acc: 0.942576, Time 00:02:26\n",
|
||
"Epoch 11. Train Loss: 0.428961, Train Acc: 0.850384, Valid Loss: 0.159228, Valid Acc: 0.942377, Time 00:02:27\n",
|
||
"Epoch 12. Train Loss: 0.425348, Train Acc: 0.851982, Valid Loss: 0.152707, Valid Acc: 0.946258, Time 00:02:26\n",
|
||
"Epoch 13. Train Loss: 0.414132, Train Acc: 0.855459, Valid Loss: 0.151967, Valid Acc: 0.947253, Time 00:02:26\n",
|
||
"Epoch 14. Train Loss: 0.410402, Train Acc: 0.856298, Valid Loss: 0.152984, Valid Acc: 0.948646, Time 00:02:26\n",
|
||
"Epoch 15. Train Loss: 0.402335, Train Acc: 0.861153, Valid Loss: 0.154637, Valid Acc: 0.947353, Time 00:02:27\n",
|
||
"Epoch 16. Train Loss: 0.391543, Train Acc: 0.863811, Valid Loss: 0.151976, Valid Acc: 0.949045, Time 00:02:26\n",
|
||
"Epoch 17. Train Loss: 0.390811, Train Acc: 0.864390, Valid Loss: 0.157535, Valid Acc: 0.945661, Time 00:02:27\n",
|
||
"Epoch 18. Train Loss: 0.378429, Train Acc: 0.867907, Valid Loss: 0.149089, Valid Acc: 0.951632, Time 00:02:26\n",
|
||
"Epoch 19. Train Loss: 0.374837, Train Acc: 0.869206, Valid Loss: 0.154984, Valid Acc: 0.947253, Time 00:02:26\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"train(net, trainloader, testloader, 20, optimizer, criterion)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"由结果可知,微调的训练时间明显大于特征提取的训练时间,其一个循环(epoch)需要2分半钟左右,但验证准确率最高可达95%左右。而且这里只循环了20次,如果增加循环次数,准确率应该还可再提升几个百分点。"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.7.4"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|