炼丹日记-1
使用 CNN 进行 0 ~ 9 的数字识别,数据集来源于谷歌 Kaptcha 生成的纯数字验证码,并复制两份,分别添加 15° 与 -15 ° 的旋转角度。
最终效果,在验证集 (0°|15°|-15°, 240 张图片)上达到了 76.25% 的准确率,在部分验证集 (0°,80张图片)上达到了 100% 的准确率。
使用验证码来测试,几十次内也是完全正确,可以认为也接近 100% 准确率。
缺点:训练的模型在数据集上出现了过拟合,对于宋体等标准数字字体,识别效果较差。一部分原因是 Kaptcha 的数字是有点糊而且分辨度不如标准字体,当然更大的原因应该是模型自身的原因。
不过如果仅用来识别验证码倒是够了。
数据集
依赖
import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
from time import sleep
import base64
from io import BytesIO
from PIL import Image
函数
二值化
# 根据阈值将灰度图进行二值化,并裁剪图片边缘边框
# 185是这里写死的阈值,如果需要泛用性可以将其转移到函数参数
def threshold(img_gray):
_, img_b = cv2.threshold(img_gray, 185, 255, cv2.THRESH_BINARY)
contours, _ = cv2.findContours(img_b, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if contours:
# 找到最大轮廓
largest_contour = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(largest_contour)
# 裁剪图像
img_b = img_b[y:y+h, x:x+w]
return img_b
寻找字符末尾
# 假定为白底黑字,如果某列白色占比大于 95 % 视为空
# 从 start 开始,找到第一列为“空”视为结束位置
# 可以优化为自动判别白底黑字或黑底白字
def find_end(start_, width, white, white_max):
end_ = start_ + 1
for m in range(start_ + 1, width - 1):
if white[m] > 0.95 * white_max:
end_ = m
break
return end_
裁剪
def clip(img_b):
white = []
black = []
height = img_b.shape[0]
width = img_b.shape[1]
white_max = 0
black_max = 0
# 竖着裁剪,去掉左右两侧空白
for i in range(width):
s = 0
t = 0
for j in range(height):
if img_b[j][i] == 255:
s += 1
if img_b[j][i] == 0:
t += 1
white_max = max(white_max, s)
black_max = max(black_max, t)
white.append(s)
black.append(t)
n = 1
start = 1
end = 2
bottom = 0
top = 0
cj = []
while n < width - 2:
n += 1
# 横着裁剪,去掉上下两侧空白
if black[n] > 0.05 * black_max:
start = n
end = find_end(start, width, white, white_max)
n = end
if end - start > 5:
for j in range(0, height):
if np.any(img_b[j, start:end] == 0):
bottom = j
break
for j in range(bottom, height):
if np.all(img_b[j, start:end] == 255):
top = j
break
cj.append(img_b[bottom:top, start:end])
return cj
填充图片
# 填充图片使其成为 28 * 28 的尺寸
# 左右同步、上下同步,最终字符居中
def fill28(img_ori):
if img_ori.shape[0] > 28 or img_ori.shape[1] > 28:
raise Exception(f"图片尺寸过大: {img_ori.shape}")
top = (28 - img_ori.shape[0]) // 2
bottom = 28 - img_ori.shape[0] - top
left = (28 - img_ori.shape[1]) // 2
right = 28 - img_ori.shape[1] - left
centered_image = cv2.copyMakeBorder(img_ori, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255)
return centered_image
缩放
# 缩放图片,将大于 28 * 28 的图片按比例缩小
def resize_and_pad(image, target_size=(28, 28)):
# 缩放图像
scale_factor = min(target_size[0] / image.shape[0], target_size[1] / image.shape[1])
new_size = (int(image.shape[1] * scale_factor), int(image.shape[0] * scale_factor))
resized_image = cv2.resize(image, new_size)
# 创建一个全白的目标图像
padded_image = np.ones(target_size) * 255 # 填充为255(白色)
# 计算填充位置
top = (target_size[0] - new_size[1]) // 2
left = (target_size[1] - new_size[0]) // 2
# 将缩放后的图像放到新图像中
padded_image[top:top + new_size[1], left:left + new_size[0]] = resized_image
return padded_image
准备
offset = 0 # 图片序号
for i in range(100):
sleep(1)
# 调用 Java 验证码接口得到 Base64
res = requests.get('https://xxx/api/kaptCha')
image_data = base64.b64decode(res.json()['result']['image'].replace('data:image/png;base64,', ''))
# 将 Base64 转换为 ndarray
image = Image.open(BytesIO(image_data))
image_np = np.array(image)
# 转换为灰度图
image_gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
# 二值化
image_cliped = threshold(image_gray)
# 裁剪
cj = clip(image_cliped)
for m in cj:
offset += 1
# 填充并保存
plt.imsave('./img/' + str(offset) + '.png', fill28(m), cmap='gray')
这里没有用到缩放,因为从接口得到的验证码尺寸经过裁剪后都是小于 28 * 28 的,如果是其他数据来源,就要添加对应逻辑。
图片保存到本地后,手动打标签,文件命名格式为 offset_label.png
训练
依赖
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import shutil
import random
from io import BytesIO
import requests
import base64
import numpy as np
import cv2
import matplotlib.pyplot as plt
数据集准备
src_dir = 'img'
train_dir = 'train'
val_dir = 'val'
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
# 获取 img 目录下所有 .png 文件
all_images = [f for f in os.listdir(src_dir) if f.endswith('.png')]
# 打乱顺序,确保随机性
random.shuffle(all_images)
# 计算 80% 和 20% 的分割点
split_idx = int(0.8 * len(all_images))
# 分割成训练集和验证集
train_images = all_images[:split_idx]
val_images = all_images[split_idx:]
# 移动文件到训练集目录
for img in train_images:
shutil.move(os.path.join(src_dir, img), os.path.join(train_dir, img))
# 移动文件到验证集目录
for img in val_images:
shutil.move(os.path.join(src_dir, img), os.path.join(val_dir, img))
print(f'Total images: {len(all_images)}')
print(f'Training set: {len(train_images)} images')
print(f'Validation set: {len(val_images)} images')
数据集读取
# 数据处理
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
class CustomDataset(Dataset):
def __init__(self, image_dir, transform=None):
self.image_dir = image_dir
self.transform = transform
self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png')]
self.classes = list(set(f.split('_')[1].split('.')[0] for f in self.image_files)) # 提取类别并去重
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)} # 类别到索引的映射
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
image_file = self.image_files[idx]
img_path = os.path.join(self.image_dir, image_file)
image = Image.open(img_path).convert('L') # 灰度图
label = self.class_to_idx[image_file.split('_')[1].split('.')[0]] # 提取类别并转换为索引
if self.transform:
image = self.transform(image)
return image, label
# 加载数据集
train_dataset = CustomDataset(image_dir='train', transform=transform)
val_dataset = CustomDataset(image_dir='val', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
模型定义
class CNNModel(nn.Module):
def __init__(self):
super(CNNModel, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, len(train_dataset.classes)) # 分类数
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 64 * 7 * 7) # 展平
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
训练
# 初始化模型、损失函数和优化器
device = torch.device('cuda')
model = CNNModel().to(device)
criterion = nn.CrossEntropyLoss()
epochs = 0
# weight_decay 用来减轻过拟合
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# 训练循环
num_epochs = 10
epochs += num_epochs
for epoch in range(num_epochs):
model.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.8f}')
# 验证
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the model on the validation set: {100 * correct / total:.2f}%')
保存
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epochs,
'loss': loss,
}
torch.save(checkpoint, 'checkpoints/model_num_checkpoints_1.pth')
训练历程
在添加旋转前,仅使用 0° 数据集很轻松达到了验证集 100% 的准确度,但发现没有提取到有效特征。
例如数字 7 只提取到了斜着的一竖,而没注意到上面还有一横,当传入的图片是斜着的 1 时,会被误识别为 7。虽说最后添加旋转也没解决这个问题。
因此添加了两种角度的旋转,重新训练
最初使用 lr=0.001
学习率,5 个 epoch 就可以收敛,但准确率只有 7%,判断是初始学习率太低无法跳出局部最优点。
重新开始训练,分别使用 0.002
、0.005
、0.01
训练 20 个 epoch,前两者得到同样结果,快速收敛,但准确率依旧为 7%,而 0.01 的学习率则始终在 0.02 的 loss 徘徊无法收敛,但准确率为 21%
最终的配方是:
lr=0.02, 20 epochs
lr=0.018, 20 epochs * 2, 此时准确率为 75%
lr=0.016, 20 epochs, 准确率为 76.25%
测试
调用接口得到验证码图片,使用数据集准备过程中的二值化、裁剪等步骤,得到单独的各个数字进行识别
# 索引与 label 的映射关系
key_to_class = {v: k for k, v in train_dataset.class_to_idx.items()}
# 接口
res = requests.get('https://xxx/api/kaptCha')
image_data = base64.b64decode(res.json()['result']['image'].replace('data:image/png;base64,', ''))
# 转图片
image = Image.open(BytesIO(image_data))
image_np = np.array(image)
# 灰度图
image_gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
# 二值化
image_cliped = threshold(image_gray)
# 裁剪
cj = clip(image_cliped)
# 预测
answer = ''
for d in cj:
image = transform(Image.fromarray(fill28(d)).convert('L')).to(device)
outputs = model(image)
_, predicted = torch.max(outputs.data, 1)
answer += key_to_class[predicted.item()]
print(answer)
plt.imshow(image_np)