311 lines
11 KiB
Python
311 lines
11 KiB
Python
|
import os
|
|||
|
import torch
|
|||
|
import torch.nn as nn
|
|||
|
import torch.optim as optim
|
|||
|
from torch.utils.data import DataLoader
|
|||
|
import numpy as np
|
|||
|
from tqdm import tqdm
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
|
|||
|
from .models import get_model
|
|||
|
from .data import RemoteSensingDataset, get_train_transforms, get_val_transforms
|
|||
|
from .utils.metrics import calculate_metrics
|
|||
|
from .utils.losses import get_loss_function
|
|||
|
|
|||
|
class UAVSegTrainer:
    """Trainer for UAV remote-sensing semantic-segmentation models.

    Wraps model construction, loss/optimizer/scheduler setup, the train/val
    loop, checkpointing, and training-curve plotting.
    """

    def __init__(self,
                 data_dir,
                 model_type='deeplabv3plus',
                 num_classes=6,
                 batch_size=4,
                 epochs=100,
                 learning_rate=0.001,
                 tile_size=512,
                 stride=None,
                 loss_type='ce',
                 device=None,
                 save_dir='./checkpoints'):
        """
        Args:
            data_dir (str): Data root; must contain 'images' and 'masks' subdirectories.
            model_type (str): Model type, one of 'deeplabv3plus' or 'unetpp'.
            num_classes (int): Number of segmentation classes.
            batch_size (int): Batch size.
            epochs (int): Number of training epochs.
            learning_rate (float): Base learning rate (backbone uses 0.1x of it).
            tile_size (int): Tile size for image chipping.
            stride (int): Sliding-window stride; treated as tile_size when None.
            loss_type (str): Loss type: 'ce' (cross-entropy), 'dice', 'focal', or 'combined'.
            device (str): Device string; auto-selects CUDA/CPU when None.
            save_dir (str): Directory for model checkpoints.
        """
        self.data_dir = data_dir
        self.model_type = model_type
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.tile_size = tile_size
        self.stride = stride
        self.loss_type = loss_type
        self.save_dir = save_dir

        # Select device (prefer CUDA when available).
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"使用设备: {self.device}")

        # Ensure the checkpoint directory exists.
        os.makedirs(save_dir, exist_ok=True)

        # Build the model and move it to the chosen device.
        self.model = get_model(model_type, num_classes)
        self.model.to(self.device)

        # Loss function.
        self.criterion = get_loss_function(loss_type, num_classes)

        # Optimizer: the (pretrained) backbone gets a 10x smaller LR than the decoder.
        self.optimizer = optim.Adam([
            {'params': self.model.get_backbone_params(), 'lr': learning_rate * 0.1},
            {'params': self.model.get_decoder_params(), 'lr': learning_rate}
        ])

        # LR scheduler: halve the LR when the validation loss plateaus.
        # NOTE: the `verbose=True` kwarg was deprecated in torch 2.2 and
        # removed in torch >= 2.4 (it raised TypeError there), so it is
        # intentionally not passed.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

        # Training history: per-epoch losses and validation metrics.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'metrics': {}
        }
|
|||
|
|
|||
|
def _get_dataloaders(self):
|
|||
|
"""准备数据加载器"""
|
|||
|
# 数据目录
|
|||
|
train_img_dir = os.path.join(self.data_dir, 'train', 'images')
|
|||
|
train_mask_dir = os.path.join(self.data_dir, 'train', 'masks')
|
|||
|
val_img_dir = os.path.join(self.data_dir, 'val', 'images')
|
|||
|
val_mask_dir = os.path.join(self.data_dir, 'val', 'masks')
|
|||
|
|
|||
|
# 数据增强和预处理
|
|||
|
train_transform = get_train_transforms()
|
|||
|
val_transform = get_val_transforms()
|
|||
|
|
|||
|
# 训练数据集
|
|||
|
train_dataset = RemoteSensingDataset(
|
|||
|
image_dir=train_img_dir,
|
|||
|
mask_dir=train_mask_dir,
|
|||
|
transform=train_transform,
|
|||
|
tile_size=self.tile_size,
|
|||
|
stride=self.stride
|
|||
|
)
|
|||
|
|
|||
|
# 验证数据集
|
|||
|
val_dataset = RemoteSensingDataset(
|
|||
|
image_dir=val_img_dir,
|
|||
|
mask_dir=val_mask_dir,
|
|||
|
transform=val_transform,
|
|||
|
tile_size=self.tile_size,
|
|||
|
stride=self.tile_size # 验证时不重叠
|
|||
|
)
|
|||
|
|
|||
|
# 数据加载器
|
|||
|
train_loader = DataLoader(
|
|||
|
train_dataset,
|
|||
|
batch_size=self.batch_size,
|
|||
|
shuffle=True,
|
|||
|
num_workers=0,
|
|||
|
pin_memory=True
|
|||
|
)
|
|||
|
|
|||
|
val_loader = DataLoader(
|
|||
|
val_dataset,
|
|||
|
batch_size=self.batch_size,
|
|||
|
shuffle=False,
|
|||
|
num_workers=0,
|
|||
|
pin_memory=True
|
|||
|
)
|
|||
|
|
|||
|
return train_loader, val_loader
|
|||
|
|
|||
|
def _train_epoch(self, train_loader):
|
|||
|
"""训练一个轮次"""
|
|||
|
self.model.train()
|
|||
|
epoch_loss = 0
|
|||
|
|
|||
|
with tqdm(train_loader, desc="训练") as pbar:
|
|||
|
for images, masks in pbar:
|
|||
|
images = images.to(self.device)
|
|||
|
masks = masks.to(self.device)
|
|||
|
|
|||
|
# 前向传播
|
|||
|
outputs = self.model(images)
|
|||
|
loss = self.criterion(outputs, masks)
|
|||
|
|
|||
|
# 反向传播和优化
|
|||
|
self.optimizer.zero_grad()
|
|||
|
loss.backward()
|
|||
|
self.optimizer.step()
|
|||
|
|
|||
|
# 更新进度条
|
|||
|
epoch_loss += loss.item()
|
|||
|
pbar.set_postfix({'loss': loss.item()})
|
|||
|
|
|||
|
return epoch_loss / len(train_loader)
|
|||
|
|
|||
|
def _validate(self, val_loader):
|
|||
|
"""验证模型"""
|
|||
|
self.model.eval()
|
|||
|
val_loss = 0
|
|||
|
metrics_sum = None
|
|||
|
samples_count = 0
|
|||
|
|
|||
|
with torch.no_grad():
|
|||
|
with tqdm(val_loader, desc="验证") as pbar:
|
|||
|
for images, masks in pbar:
|
|||
|
images = images.to(self.device)
|
|||
|
masks = masks.to(self.device)
|
|||
|
batch_size = images.size(0)
|
|||
|
|
|||
|
# 前向传播
|
|||
|
outputs = self.model(images)
|
|||
|
loss = self.criterion(outputs, masks)
|
|||
|
|
|||
|
# 计算指标
|
|||
|
batch_metrics = calculate_metrics(outputs, masks, self.num_classes)
|
|||
|
|
|||
|
# 累加损失和指标
|
|||
|
val_loss += loss.item() * batch_size
|
|||
|
if metrics_sum is None:
|
|||
|
metrics_sum = {k: v * batch_size for k, v in batch_metrics.items()}
|
|||
|
else:
|
|||
|
for k, v in batch_metrics.items():
|
|||
|
metrics_sum[k] += v * batch_size
|
|||
|
|
|||
|
samples_count += batch_size
|
|||
|
|
|||
|
# 更新进度条
|
|||
|
pbar.set_postfix({'loss': loss.item()})
|
|||
|
|
|||
|
# 计算平均值
|
|||
|
val_loss /= samples_count
|
|||
|
metrics = {k: v / samples_count for k, v in metrics_sum.items()}
|
|||
|
|
|||
|
return val_loss, metrics
|
|||
|
|
|||
|
def train(self):
|
|||
|
"""训练模型"""
|
|||
|
# 准备数据加载器
|
|||
|
train_loader, val_loader = self._get_dataloaders()
|
|||
|
|
|||
|
# 最佳验证损失
|
|||
|
best_val_loss = float('inf')
|
|||
|
|
|||
|
# 训练循环
|
|||
|
for epoch in range(1, self.epochs + 1):
|
|||
|
print(f"\n轮次 {epoch}/{self.epochs}")
|
|||
|
|
|||
|
# 训练
|
|||
|
train_loss = self._train_epoch(train_loader)
|
|||
|
self.history['train_loss'].append(train_loss)
|
|||
|
|
|||
|
# 验证
|
|||
|
val_loss, metrics = self._validate(val_loader)
|
|||
|
self.history['val_loss'].append(val_loss)
|
|||
|
|
|||
|
# 更新学习率
|
|||
|
self.scheduler.step(val_loss)
|
|||
|
|
|||
|
# 保存指标
|
|||
|
for k, v in metrics.items():
|
|||
|
if k not in self.history['metrics']:
|
|||
|
self.history['metrics'][k] = []
|
|||
|
self.history['metrics'][k].append(v)
|
|||
|
|
|||
|
# 打印指标
|
|||
|
print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}")
|
|||
|
for k, v in metrics.items():
|
|||
|
print(f"{k}: {v:.4f}")
|
|||
|
|
|||
|
# 保存最佳模型
|
|||
|
if val_loss < best_val_loss:
|
|||
|
best_val_loss = val_loss
|
|||
|
self._save_model(f"{self.model_type}_best.pth")
|
|||
|
print(f"保存最佳模型,验证损失: {val_loss:.4f}")
|
|||
|
|
|||
|
# 保存最新模型
|
|||
|
if epoch % 10 == 0:
|
|||
|
self._save_model(f"{self.model_type}_epoch_{epoch}.pth")
|
|||
|
|
|||
|
# 保存最终模型
|
|||
|
self._save_model(f"{self.model_type}_final.pth")
|
|||
|
|
|||
|
# 绘制训练曲线
|
|||
|
self._plot_training_curves()
|
|||
|
|
|||
|
return self.history
|
|||
|
|
|||
|
def _save_model(self, filename):
|
|||
|
"""保存模型"""
|
|||
|
save_path = os.path.join(self.save_dir, filename)
|
|||
|
torch.save({
|
|||
|
'model_state_dict': self.model.state_dict(),
|
|||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
|||
|
'model_type': self.model_type,
|
|||
|
'num_classes': self.num_classes,
|
|||
|
}, save_path)
|
|||
|
|
|||
|
def _plot_training_curves(self):
|
|||
|
"""绘制训练曲线"""
|
|||
|
# 创建图形
|
|||
|
plt.figure(figsize=(15, 5))
|
|||
|
#设置宋体
|
|||
|
plt.rcParams['font.sans-serif'] = ['SimSun']
|
|||
|
plt.rcParams['axes.unicode_minus'] = False
|
|||
|
# 绘制损失曲线
|
|||
|
plt.subplot(1, 2, 1)
|
|||
|
plt.plot(self.history['train_loss'], label='训练损失')
|
|||
|
plt.plot(self.history['val_loss'], label='验证损失')
|
|||
|
plt.xlabel('轮次')
|
|||
|
plt.ylabel('损失')
|
|||
|
plt.legend()
|
|||
|
plt.title('训练和验证损失')
|
|||
|
|
|||
|
# 绘制指标曲线
|
|||
|
plt.subplot(1, 2, 2)
|
|||
|
for metric_name, metric_values in self.history['metrics'].items():
|
|||
|
plt.plot(metric_values, label=metric_name)
|
|||
|
plt.xlabel('轮次')
|
|||
|
plt.ylabel('指标值')
|
|||
|
plt.legend()
|
|||
|
plt.title('验证指标')
|
|||
|
|
|||
|
# 保存图形
|
|||
|
plt.tight_layout()
|
|||
|
plt.savefig(os.path.join(self.save_dir, f"{self.model_type}_training_curves.png"))
|
|||
|
plt.close()
|
|||
|
|
|||
|
def load_model(self, checkpoint_path):
|
|||
|
"""加载模型"""
|
|||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
|||
|
|
|||
|
# 检查模型类型
|
|||
|
if checkpoint['model_type'] != self.model_type:
|
|||
|
print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")
|
|||
|
|
|||
|
# 加载模型参数
|
|||
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|||
|
|
|||
|
# 加载优化器参数(如果存在)
|
|||
|
if 'optimizer_state_dict' in checkpoint:
|
|||
|
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|||
|
|
|||
|
print(f"模型已从 {checkpoint_path} 加载")
|