import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pyswarm import pso
# 定义简单的 CNN 模型
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(64 * 8 * 8, 512)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(512, num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.pool2(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu3(x)
x = self.fc2(x)
return x
# 定义目标函数
def objective_function(params):
# 将 PSO 的参数映射到 CNN 模型的相关参数
learning_rate = params[0]
weight_decay = params[1]
# 构建并初始化 CNN 模型
cnn_model = SimpleCNN(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# 训练 CNN 模型
num_epochs = 5
for epoch in range(num_epochs):
# 训练过程略(根据实际情况填写)
# 在验证集上评估性能
validation_accuracy = evaluate_cnn_model(cnn_model, validation_data_loader) # 自行定义验证集评估函数
# 返回验证准确率,PSO 将尝试最小化这个值
return -validation_accuracy
# 定义参数的搜索范围
lb = [1e-5, 1e-6] # 学习率和权重衰减的下限
ub = [1e-2, 1e-3] # 学习率和权重衰减的上限
# 使用 PSO 进行参数优化
best_params, _ = pso(objective_function, lb, ub, swarmsize=10, maxiter=10)
# 输出最优参数
print("最优参数:", best_params)