2025-07-10 09:41:26 +08:00

311 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from .models import get_model
from .data import RemoteSensingDataset, get_train_transforms, get_val_transforms
from .utils.metrics import calculate_metrics
from .utils.losses import get_loss_function
class UAVSegTrainer:
    """Trainer for UAV remote-sensing semantic-segmentation models.

    Wires together model construction, loss selection, a two-group Adam
    optimizer (backbone at 0.1x LR), plateau-based LR scheduling, tiled
    train/val dataloaders, the epoch loop, checkpointing, and
    training-curve plotting.
    """

    def __init__(self,
                 data_dir,
                 model_type='deeplabv3plus',
                 num_classes=6,
                 batch_size=4,
                 epochs=100,
                 learning_rate=0.001,
                 tile_size=512,
                 stride=None,
                 loss_type='ce',
                 device=None,
                 save_dir='./checkpoints'):
        """
        Args:
            data_dir (str): Data root; expected to contain
                'train'/'val' subtrees with 'images' and 'masks' dirs.
            model_type (str): Model type, e.g. 'deeplabv3plus' or 'unetpp'.
            num_classes (int): Number of segmentation classes.
            batch_size (int): Batch size for both loaders.
            epochs (int): Number of training epochs.
            learning_rate (float): Base LR (decoder); backbone uses 0.1x.
            tile_size (int): Tile size used to chop large rasters.
            stride (int): Sliding-window stride; if None the dataset is
                expected to default it to tile_size.
            loss_type (str): One of 'ce' (cross-entropy), 'dice',
                'focal', 'combined'.
            device (str): Torch device string; if None, CUDA is used
                when available, else CPU.
            save_dir (str): Directory for checkpoints and plots.
        """
        self.data_dir = data_dir
        self.model_type = model_type
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.tile_size = tile_size
        self.stride = stride
        self.loss_type = loss_type
        self.save_dir = save_dir

        # Resolve the device: explicit string wins, otherwise auto-select.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"使用设备: {self.device}")

        # Make sure the checkpoint directory exists.
        os.makedirs(save_dir, exist_ok=True)

        # Build the model and move it to the target device.
        self.model = get_model(model_type, num_classes)
        self.model.to(self.device)

        # Loss function (factory handles per-type construction).
        self.criterion = get_loss_function(loss_type, num_classes)

        # Two parameter groups: pretrained backbone trains 10x slower
        # than the freshly initialized decoder head.
        self.optimizer = optim.Adam([
            {'params': self.model.get_backbone_params(), 'lr': learning_rate * 0.1},
            {'params': self.model.get_decoder_params(), 'lr': learning_rate}
        ])

        # Halve the LR after 5 epochs without val-loss improvement.
        # NOTE(review): `verbose` is deprecated in recent PyTorch
        # releases — confirm against the pinned torch version.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )

        # Per-epoch history; 'metrics' maps metric name -> list of values.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'metrics': {}
        }

    def _get_dataloaders(self):
        """Build and return (train_loader, val_loader) over tiled datasets."""
        # Expected on-disk layout: <data_dir>/{train,val}/{images,masks}.
        train_img_dir = os.path.join(self.data_dir, 'train', 'images')
        train_mask_dir = os.path.join(self.data_dir, 'train', 'masks')
        val_img_dir = os.path.join(self.data_dir, 'val', 'images')
        val_mask_dir = os.path.join(self.data_dir, 'val', 'masks')

        # Augmentation for training, deterministic preprocessing for val.
        train_transform = get_train_transforms()
        val_transform = get_val_transforms()

        train_dataset = RemoteSensingDataset(
            image_dir=train_img_dir,
            mask_dir=train_mask_dir,
            transform=train_transform,
            tile_size=self.tile_size,
            stride=self.stride
        )
        val_dataset = RemoteSensingDataset(
            image_dir=val_img_dir,
            mask_dir=val_mask_dir,
            transform=val_transform,
            tile_size=self.tile_size,
            stride=self.tile_size  # non-overlapping tiles for validation
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True
        )
        return train_loader, val_loader

    def _train_epoch(self, train_loader):
        """Run one training epoch; return the mean per-batch loss."""
        self.model.train()
        epoch_loss = 0
        with tqdm(train_loader, desc="训练") as pbar:
            for images, masks in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                # Forward pass.
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)

                # Backward pass and optimizer step.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
        # Guard against an empty loader (would otherwise divide by zero).
        return epoch_loss / max(len(train_loader), 1)

    def _validate(self, val_loader):
        """Evaluate on the validation set.

        Returns:
            tuple: (mean per-sample loss, dict of per-sample-averaged metrics).
        """
        self.model.eval()
        val_loss = 0
        metrics_sum = None
        samples_count = 0
        with torch.no_grad():
            with tqdm(val_loader, desc="验证") as pbar:
                for images, masks in pbar:
                    images = images.to(self.device)
                    masks = masks.to(self.device)
                    batch_size = images.size(0)

                    outputs = self.model(images)
                    loss = self.criterion(outputs, masks)
                    batch_metrics = calculate_metrics(outputs, masks, self.num_classes)

                    # Weight by batch size so the final averages are
                    # per-sample, not per-batch (last batch may be short).
                    val_loss += loss.item() * batch_size
                    if metrics_sum is None:
                        metrics_sum = {k: v * batch_size for k, v in batch_metrics.items()}
                    else:
                        for k, v in batch_metrics.items():
                            metrics_sum[k] += v * batch_size
                    samples_count += batch_size
                    pbar.set_postfix({'loss': loss.item()})

        # Empty validation set: avoid ZeroDivisionError and iterating None.
        if samples_count == 0:
            return 0.0, {}
        val_loss /= samples_count
        metrics = {k: v / samples_count for k, v in metrics_sum.items()}
        return val_loss, metrics

    def train(self):
        """Run the full training loop; return the history dict."""
        train_loader, val_loader = self._get_dataloaders()
        best_val_loss = float('inf')

        for epoch in range(1, self.epochs + 1):
            print(f"\n轮次 {epoch}/{self.epochs}")

            train_loss = self._train_epoch(train_loader)
            self.history['train_loss'].append(train_loss)

            val_loss, metrics = self._validate(val_loader)
            self.history['val_loss'].append(val_loss)

            # Plateau scheduler steps on the validation loss.
            self.scheduler.step(val_loss)

            # Accumulate per-metric history.
            for k, v in metrics.items():
                if k not in self.history['metrics']:
                    self.history['metrics'][k] = []
                self.history['metrics'][k].append(v)

            print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}")
            for k, v in metrics.items():
                print(f"{k}: {v:.4f}")

            # Track and save the best model by validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self._save_model(f"{self.model_type}_best.pth")
                print(f"保存最佳模型,验证损失: {val_loss:.4f}")

            # Periodic checkpoint every 10 epochs.
            if epoch % 10 == 0:
                self._save_model(f"{self.model_type}_epoch_{epoch}.pth")

        # Final model after the last epoch, regardless of val loss.
        self._save_model(f"{self.model_type}_final.pth")
        self._plot_training_curves()
        return self.history

    def _save_model(self, filename):
        """Serialize model + optimizer state to <save_dir>/<filename>."""
        save_path = os.path.join(self.save_dir, filename)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_type': self.model_type,
            'num_classes': self.num_classes,
        }, save_path)

    def _plot_training_curves(self):
        """Plot loss and metric curves and save them as a PNG."""
        plt.figure(figsize=(15, 5))
        # Use the SimSun font so CJK labels render correctly.
        plt.rcParams['font.sans-serif'] = ['SimSun']
        plt.rcParams['axes.unicode_minus'] = False

        # Left panel: train/val loss.
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='训练损失')
        plt.plot(self.history['val_loss'], label='验证损失')
        plt.xlabel('轮次')
        plt.ylabel('损失')
        plt.legend()
        plt.title('训练和验证损失')

        # Right panel: all validation metrics.
        plt.subplot(1, 2, 2)
        for metric_name, metric_values in self.history['metrics'].items():
            plt.plot(metric_values, label=metric_name)
        plt.xlabel('轮次')
        plt.ylabel('指标值')
        plt.legend()
        plt.title('验证指标')

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, f"{self.model_type}_training_curves.png"))
        plt.close()

    def load_model(self, checkpoint_path):
        """Load model (and, if present, optimizer) state from a checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Warn (but proceed) on a model-type mismatch.
        if checkpoint['model_type'] != self.model_type:
            print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Optimizer state is optional in older checkpoints.
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"模型已从 {checkpoint_path} 加载")