利用Resnet18训练Flower102数据集 - 大白社区
- 发帖时间:
- 2023-04-29 08:29:44
摘要:利用Resnet18训练Flower102数据集,多少有点大材小用了,效果还是蛮不错的,准确率高于95%。
本着学习完进行记录,方便后续复习,同时也希望分享出来给大家在学习过程中提供一些思路,特此按照逻辑路线梳理代码过程,本人也是初学者,希望大家本着学习原则互相探讨,有问题可以指出。本次内容只涉及代码训练以及预测部分,数据集网上很多,也可以去官网下载,大概300多MB的样子。
导入库
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision as tv
读入标签
import scipy.io
# Read the .mat label file (replace 'path...' with the real dataset path).
mat_file = scipy.io.loadmat('path.../imagelabels.mat')
# Inspect the variables stored in the .mat file.
print(mat_file.keys())
print(mat_file)
# 'labels' holds one 1-based class id per image: shape (1, 8189), dtype uint8
# (see the printed output below).
labels = mat_file['labels']
结果:dict_keys(['__header__', '__version__', '__globals__', 'labels']) {'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNX86, Created on: Thu Feb 19 15:43:33 2009', '__version__': '1.0', '__globals__': [], 'labels': array([[77, 77, 77, ..., 62, 62, 62]], dtype=uint8)}可知'labels'属性包含所有的数据标签结果
# Distinct class ids, then flatten the (1, 8189) array into a 1-D tensor.
uq = np.unique(labels)
labels = torch.from_numpy(labels.reshape(-1))
# Notebook-style echo: 102 classes, 8189 samples.
uq, uq.shape, labels.shape
观察标签类别结果:一共102个类别,8189个样本。
(array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52,
53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78,
79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102], dtype=uint8),
(102,),
torch.Size([8189]))
下面导入图片数据集
#1. Data augmentation pipelines: heavy random augmentation for training,
# deterministic resize/crop for validation. Both use ImageNet mean/std.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

_train_steps = [
    tv.transforms.RandomRotation(45),            # random rotation in [-45, +45] degrees
    tv.transforms.CenterCrop(448),               # 448x448 patch from the image center
    tv.transforms.RandomHorizontalFlip(p=0.5),   # random horizontal flip
    tv.transforms.RandomVerticalFlip(p=0.5),     # random vertical flip
    # Optional color jitter (brightness/contrast/saturation/hue), disabled:
    # tv.transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    tv.transforms.ToTensor(),
    tv.transforms.Resize(224, antialias=True),   # down to the 224x224 ResNet input size
    tv.transforms.Normalize(_norm_mean, _norm_std),
]
_valid_steps = [
    tv.transforms.Resize(256),
    tv.transforms.CenterCrop(224),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(_norm_mean, _norm_std),
]

data_transforms = {
    'train': tv.transforms.Compose(_train_steps),
    'valid': tv.transforms.Compose(_valid_steps),
}
#2.训练集图片读取
from torchvision.datasets import ImageFolder
import os
from PIL import Image
from torch.utils.data import Dataset
class ImageFolderDataset(Dataset):
    """Flat-directory image dataset that yields transformed images (no labels).

    Labels are attached later by position (zip with the label tensor), so the
    file list is sorted: ``os.listdir`` returns entries in arbitrary order,
    which would silently misalign images with their labels. Flowers-102 files
    are named ``image_00001.jpg`` ... so lexicographic order matches label order.
    """

    # Accepted image extensions (matched case-insensitively).
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # Deterministic, sorted list of image paths in `root`.
        self.img_paths = sorted(
            os.path.join(root, filename)
            for filename in os.listdir(root)
            if filename.lower().endswith(self.IMG_EXTENSIONS)
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        # Force RGB so grayscale/palette files still produce 3 channels.
        img = Image.open(self.img_paths[index]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img
#3. Build the dataset; img_dir is the dataset directory (set it yourself).
# NOTE(review): the *train* transform (with random augmentation) is applied
# here — presumably intentional since this set is split for training below.
dataset = ImageFolderDataset(img_dir, transform=data_transforms['train'])
#4. One-hot encoding helper (labels are 1-based class ids, 102 classes).
def to_onehot(label):
    # Build an (N, 102) matrix with a single 1 per row; `label - 1` shifts
    # the dataset's 1-based ids to 0-based column positions.
    count = len(label)
    onehot = np.zeros((count, 102))
    onehot[np.arange(count), label - 1] = 1
    return torch.from_numpy(onehot)
我们查看下其中的一张图片,图片大小缩放到了224x224
# Preview one sample: tensors are CHW, matplotlib wants HWC, hence permute.
# NOTE(review): values were normalized above, so the displayed colors will
# look distorted/clipped.
plt.imshow(dataset[251].permute(1,2,0).numpy())
之后我们把标签和数据集zip到一起,构建一个新的数据集,方便后续创建dataloader
# Attach the labels to the dataset
# Pair each image with its label (matched by position).
# NOTE(review): the list below loads and transforms EVERY image eagerly, and
# torch.stack materializes the whole set in RAM; the random train-time
# augmentations are therefore applied once and frozen — confirm intended.
data = [(img, label) for img, label in zip(dataset, labels)]
dataset = torch.utils.data.TensorDataset(torch.stack([d[0] for d in data]), torch.tensor([d[1] for d in data]))
如果我们不想用官方划分的那1020个测试样本进行测试,可以直接对这8000多个样本自行划分训练集和测试集。batch别选太大,太大显卡容易out of memory。
# Split into train/test sets and wrap them in DataLoaders.
# NOTE: the former `import torch.utils.data as data` alias was removed — it
# was shadowed by loop variables named `data` further down in this script.
batch_size = 64
train_size = int(len(dataset) * 0.95)  # 95% of the data for training
test_size = len(dataset) - train_size  # remaining 5% for testing
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loaders = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
# shuffle=False: evaluation gains nothing from shuffling, and a fixed order
# makes test runs reproducible batch-for-batch.
test_loaders = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
来看一下划分完之后的结果:
# Notebook-style echo of the loaders and the split sizes.
train_loaders,test_loaders,len(train_set),len(test_set)
7779个训练样本,好像够了,样本太少,模型收敛速度会很慢。
(<torch.utils.data.dataloader.DataLoader at 0x1d74d0ac880>,
<torch.utils.data.dataloader.DataLoader at 0x1d74d093490>,
7779,
410)
看一下每一个batch吧:
# Walk the training loader and show each batch's tensor shape and labels.
for batch_idx, (batch_imgs, batch_labels) in enumerate(train_loaders):
    print(f"Batch {batch_idx}:")
    print("Data: ", batch_imgs.shape)
    print("Target: ", batch_labels)
Batch 0:
Data: torch.Size([64, 3, 224, 224])
Target: tensor([ 18, 17, 27, 21, 40, 35, 31, 77, 82, 83, 11, 98, 69, 102,
9, 75, 95, 91, 46, 18, 92, 58, 89, 21, 74, 46, 83, 95,
11, 69, 83, 48, 12, 71, 2, 12, 55, 4, 63, 42, 75, 80,
91, 19, 66, 45, 55, 74, 46, 19, 11, 72, 91, 46, 18, 65,
37, 84, 99, 3, 29, 88, 51, 24], dtype=torch.uint8)
Batch 1:
Data: torch.Size([64, 3, 224, 224])
Target: tensor([49, 71, 73, 86, 73, 11, 50, 89, 83, 3, 52, 85, 73, 45, 1, 51, 81, 46,
94, 52, 91, 6, 12, 23, 8, 80, 70, 81, 56, 92, 37, 75, 85, 11, 89, 26,
37, 58, 23, 53, 5, 54, 48, 30, 61, 72, 33, 58, 82, 2, 2, 25, 41, 4,
2, 56, 79, 51, 43, 51, 28, 42, 56, 81], dtype=torch.uint8)
...
...
...
按照ResNet18的结构编写一个resnet18(也不是很难,按照结构图一点一点写就可以),或者直接导入pytorch自带的resnet18模型。模型我自己写了一个,很长,要定义resnet block,再指定不同block的输入通道数、输出通道数、卷积核的大小、步长,别忘了残差网络最后全连接的输出是102。
# Sanity-check the network: push one dummy 224x224 RGB batch through and
# print the output shape after every top-level stage.
x = torch.rand(1, 3, 224, 224)
for stage in resnet18:
    x = stage(x)
    print(stage.__class__.__name__, 'output shape:\t', x.shape)
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 128, 28, 28])
Sequential output shape: torch.Size([1, 256, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape: torch.Size([1, 512, 1, 1])
Flatten output shape: torch.Size([1, 512])
Linear output shape: torch.Size([1, 102])
定义自己的cuda:
# Select the compute device: first CUDA GPU when available, else the CPU.
has_gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if has_gpu else 'cpu')
进行训练:
# Train the model.
epoches = 30
lr = 0.01
resnet18 = resnet18.to(device)
# Loss: CrossEntropyLoss accepts either class indices or float probability
# targets; one-hot targets are used here.
criterion = nn.CrossEntropyLoss()
# SGD with momentum + Nesterov and L2 weight decay.
optimizer = torch.optim.SGD(resnet18.parameters(), lr=lr, weight_decay=5e-4,
                            momentum=0.9, nesterov=True)
# Decay lr by 0.85 every 2 *epochs*. It must be stepped once per epoch —
# stepping it per batch (as before) shrinks the lr to ~0 within one epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, 0.85)
for epoch in range(epoches):
    resnet18.train()  # switch to training mode once per epoch, not per batch
    for img, label in train_loaders:
        img = img.to(device)
        # Cast the one-hot target to float32: to_onehot returns float64,
        # which does not match the model's float32 logits in the loss.
        label = to_onehot(label).float().to(device)
        optimizer.zero_grad()
        pred = resnet18(img)
        loss = criterion(pred, label)
        # retain_graph=True removed: the graph is rebuilt every iteration,
        # so retaining it only wasted GPU memory.
        loss.backward()
        optimizer.step()
    scheduler.step()  # per-epoch step so the lr follows the intended schedule
    resnet18.eval()
    if epoch % 2 == 0:
        print('epoch:', epoch, 'loss:', loss.item())
epoch: 0 loss: 2.2539112329483033
epoch: 2 loss: 1.1994134024849958
epoch: 4 loss: 0.9377995965636468
epoch: 6 loss: 1.4816364160273223
epoch: 8 loss: 0.9987844148883596
epoch: 10 loss: 0.26387754171072236
epoch: 12 loss: 0.4316956601895591
epoch: 14 loss: 0.10870306638867727
epoch: 16 loss: 0.05006525443708856
epoch: 18 loss: 0.07943305657577834
epoch: 20 loss: 0.20609229835182694
epoch: 22 loss: 0.04521459138923092
epoch: 24 loss: 0.09232093384032071
epoch: 26 loss: 0.03266934813613521
epoch: 28 loss: 0.034536676809407904
可以看出迭代到20几轮,损失就很小了。
看下测试的准确率:
def accuracy(preds, labels):
    # Fraction of correct predictions. argmax yields 0-based class positions;
    # +1 restores the dataset's 1-based label ids before comparing.
    predicted_ids = (torch.argmax(preds, dim=1) + 1).numpy()
    true_ids = labels.numpy()
    correct = np.sum(predicted_ids == true_ids)
    return correct / len(true_ids)
# Evaluate on the held-out test batches.
resnet18.eval()  # inference mode (batchnorm uses running stats)
with torch.no_grad():  # no autograd bookkeeping needed during evaluation
    for img, label in test_loaders:
        print(img.shape)
        img = img.to(device)
        label = label.to(device)
        # Show the first image of the batch (CHW -> HWC for matplotlib).
        plt.imshow(img[0].cpu().permute(1, 2, 0).numpy())
        # FIX: the model variable is `resnet18`; `net` was never defined.
        pred = resnet18(img)
        print(torch.argmax(pred.cpu(), dim=1) + 1)
        print(label.cpu())
        accuracy0 = accuracy(pred.cpu(), label.cpu())
        print(accuracy0)
torch.Size([64, 3, 224, 224])
tensor([ 43, 43, 76, 36, 61, 78, 11, 12, 9, 31, 95, 81, 75, 50,
11, 22, 64, 95, 23, 37, 22, 51, 88, 37, 60, 51, 99, 94,
12, 82, 58, 73, 47, 56, 100, 73, 30, 34, 59, 93, 50, 66,
65, 47, 18, 81, 37, 72, 70, 81, 28, 60, 45, 80, 90, 71,
51, 86, 3, 51, 77, 81, 86, 65])
tensor([ 43, 43, 76, 36, 61, 78, 11, 12, 9, 30, 95, 81, 75, 50,
11, 22, 64, 95, 23, 37, 22, 51, 88, 37, 60, 51, 99, 94,
12, 82, 58, 73, 47, 56, 100, 73, 30, 34, 59, 93, 50, 66,
65, 47, 18, 81, 37, 72, 70, 81, 28, 60, 45, 80, 90, 71,
51, 86, 3, 51, 77, 81, 86, 65], dtype=torch.uint8)
0.984375
torch.Size([64, 3, 224, 224])
tensor([96, 80, 42, 36, 43, 46, 40, 22, 18, 78, 98, 61, 77, 38, 78, 81, 33, 91,
49, 81, 84, 18, 8, 85, 85, 51, 5, 77, 48, 58, 92, 58, 61, 74, 37, 78,
65, 30, 16, 12, 82, 2, 51, 37, 21, 81, 43, 27, 14, 78, 91, 83, 89, 88,
23, 3, 31, 20, 39, 70, 56, 51, 83, 60])
tensor([96, 80, 42, 36, 43, 46, 40, 22, 18, 78, 98, 61, 77, 38, 78, 81, 33, 91,
49, 81, 84, 18, 8, 85, 85, 51, 5, 77, 48, 58, 92, 58, 61, 74, 37, 78,
65, 30, 16, 12, 82, 2, 51, 37, 21, 81, 43, 27, 14, 78, 91, 83, 89, 88,
23, 3, 31, 20, 39, 70, 56, 51, 83, 60], dtype=torch.uint8)
1.0
torch.Size([64, 3, 224, 224])
tensor([93, 4, 76, 65, 6, 73, 5, 91, 44, 56, 16, 42, 58, 42, 43, 86, 43, 78,
94, 84, 53, 11, 57, 23, 48, 6, 95, 23, 77, 80, 90, 24, 75, 90, 8, 72,
72, 86, 77, 78, 8, 69, 55, 27, 39, 85, 10, 2, 29, 8, 44, 44, 33, 52,
46, 89, 75, 93, 33, 23, 95, 82, 23, 90])
tensor([93, 4, 76, 65, 6, 73, 5, 91, 44, 56, 16, 42, 58, 42, 43, 86, 43, 78,
94, 84, 53, 11, 57, 23, 48, 6, 95, 23, 77, 80, 90, 24, 75, 90, 8, 72,
72, 86, 77, 78, 8, 69, 55, 27, 39, 85, 10, 2, 29, 8, 44, 44, 33, 52,
46, 89, 75, 93, 33, 23, 95, 82, 23, 90], dtype=torch.uint8)
1.0
torch.Size([64, 3, 224, 224])
tensor([ 73, 80, 41, 48, 43, 94, 53, 63, 58, 43, 24, 3, 95, 43,
73, 80, 8, 89, 73, 51, 46, 91, 56, 37, 56, 31, 37, 40,
98, 58, 82, 87, 89, 77, 85, 72, 43, 78, 70, 99, 71, 101,
79, 60, 78, 25, 35, 70, 73, 90, 3, 47, 98, 82, 77, 43,
50, 73, 102, 93, 83, 89, 41, 43])
tensor([ 73, 80, 41, 48, 43, 94, 53, 63, 58, 43, 24, 3, 95, 43,
73, 80, 8, 89, 73, 51, 46, 91, 56, 37, 56, 31, 37, 40,
98, 58, 19, 87, 89, 77, 85, 72, 43, 78, 70, 99, 71, 44,
79, 60, 78, 25, 35, 70, 73, 90, 3, 47, 98, 82, 77, 43,
50, 73, 102, 93, 83, 89, 41, 43], dtype=torch.uint8)
0.96875
torch.Size([64, 3, 224, 224])
tensor([ 46, 97, 55, 75, 82, 27, 83, 27, 12, 39, 94, 48, 82, 74,
12, 71, 88, 80, 95, 89, 52, 59, 91, 78, 76, 102, 72, 51,
31, 22, 75, 15, 1, 98, 102, 77, 45, 77, 42, 2, 38, 4,
90, 20, 76, 15, 91, 85, 60, 83, 99, 81, 46, 51, 6, 73,
10, 73, 51, 46, 90, 53, 41, 94])
tensor([ 46, 97, 55, 75, 82, 27, 83, 27, 12, 39, 94, 48, 82, 74,
12, 71, 88, 80, 95, 89, 52, 59, 91, 78, 76, 102, 72, 51,
31, 22, 75, 15, 1, 98, 102, 77, 45, 77, 42, 2, 38, 4,
90, 20, 76, 15, 91, 85, 60, 83, 99, 81, 46, 51, 6, 73,
10, 73, 51, 46, 90, 53, 41, 94], dtype=torch.uint8)
1.0
torch.Size([64, 3, 224, 224])
tensor([ 81, 73, 72, 70, 77, 78, 50, 59, 11, 1, 76, 60, 66, 29,
79, 17, 101, 90, 98, 102, 50, 97, 83, 58, 88, 59, 19, 87,
33, 58, 44, 47, 47, 74, 51, 39, 36, 95, 68, 4, 58, 73,
41, 16, 65, 93, 71, 92, 56, 24, 31, 98, 75, 37, 50, 14,
96, 55, 74, 24, 46, 37, 2, 87])
tensor([ 81, 73, 72, 70, 77, 78, 50, 59, 11, 1, 76, 60, 66, 29,
79, 17, 101, 90, 98, 102, 50, 97, 83, 58, 88, 59, 19, 87,
33, 58, 44, 47, 47, 74, 51, 39, 36, 95, 68, 4, 58, 73,
41, 16, 65, 93, 71, 92, 56, 24, 31, 98, 75, 37, 50, 14,
96, 55, 74, 24, 46, 37, 2, 87], dtype=torch.uint8)
1.0
torch.Size([26, 3, 224, 224])
tensor([89, 52, 96, 56, 73, 42, 26, 90, 83, 17, 74, 77, 74, 94, 30, 51, 93, 72,
68, 51, 40, 66, 73, 94, 90, 17])
tensor([89, 52, 96, 56, 73, 42, 26, 90, 83, 17, 74, 77, 74, 94, 30, 51, 93, 72,
68, 51, 40, 66, 73, 94, 90, 17], dtype=torch.uint8)
1.0
可以观察到,测试集每一个batch的准确率差不多96%~100%左右,效果还是非常不错的。
残差网络学习到的权重是啥,我们可视化第一个权重结果:
# Visualize the learned filters of the first layer on an 8x8 grid.
# NOTE(review): assumes weight0 is a (64, C, k, k) array-like whose entries
# support .transpose(1, 2, 0) — verify against where weight0 is extracted.
fig, axs = plt.subplots(8, 8)
for idx, kernel in enumerate(weight0):
    axs[idx // 8, idx % 8].imshow(kernel.transpose(1, 2, 0))
plt.show()
其实来讲,很抽象,网络学习到的权重只用于优化他的损失,而不具有一般的可解释性,俗称“炼丹”。