import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from .models import get_model
from .data import RemoteSensingDataset, get_train_transforms, get_val_transforms
from .utils.metrics import calculate_metrics
from .utils.losses import get_loss_function


class UAVSegTrainer:
    def __init__(self, data_dir, model_type='deeplabv3plus', num_classes=6,
                 batch_size=4, epochs=100, learning_rate=0.001, tile_size=512,
                 stride=None, loss_type='ce', device=None, save_dir='./checkpoints'):
        """
        Args:
            data_dir (str): Data directory; must contain 'images' and 'masks' subdirectories.
            model_type (str): Model type, either 'deeplabv3plus' or 'unetpp'.
            num_classes (int): Number of classes.
            batch_size (int): Batch size.
            epochs (int): Number of training epochs.
            learning_rate (float): Learning rate.
            tile_size (int): Tile size.
            stride (int): Sliding-window stride; defaults to tile_size when None.
            loss_type (str): Loss type: 'ce' (cross-entropy), 'dice', 'focal', or 'combined'.
            device (str): Device; selected automatically when None.
            save_dir (str): Directory for model checkpoints.
        """
        self.data_dir = data_dir
        self.model_type = model_type
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.tile_size = tile_size
        self.stride = stride
        self.loss_type = loss_type
        self.save_dir = save_dir

        # Select device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # Create checkpoint directory
        os.makedirs(save_dir, exist_ok=True)

        # Initialize model
        self.model = get_model(model_type, num_classes)
        self.model.to(self.device)

        # Initialize loss function
        self.criterion = get_loss_function(loss_type, num_classes)

        # Initialize optimizer: a reduced learning rate for the backbone,
        # the full learning rate for the decoder
        self.optimizer = optim.Adam([
            {'params': self.model.get_backbone_params(), 'lr': learning_rate * 0.1},
            {'params': self.model.get_decoder_params(), 'lr': learning_rate}
        ])

        # Learning-rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )

        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'metrics': {}
        }

    def _get_dataloaders(self):
        """Build the training and validation data loaders."""
        # Data directories
        train_img_dir = os.path.join(self.data_dir, 'train', 'images')
        train_mask_dir = os.path.join(self.data_dir, 'train', 'masks')
        val_img_dir = os.path.join(self.data_dir, 'val', 'images')
        val_mask_dir = os.path.join(self.data_dir, 'val', 'masks')

        # Data augmentation and preprocessing
        train_transform = get_train_transforms()
        val_transform = get_val_transforms()

        # Training dataset
        train_dataset = RemoteSensingDataset(
            image_dir=train_img_dir,
            mask_dir=train_mask_dir,
            transform=train_transform,
            tile_size=self.tile_size,
            stride=self.stride
        )

        # Validation dataset
        val_dataset = RemoteSensingDataset(
            image_dir=val_img_dir,
            mask_dir=val_mask_dir,
            transform=val_transform,
            tile_size=self.tile_size,
            stride=self.tile_size  # No tile overlap during validation
        )

        # Data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True
        )

        return train_loader, val_loader

    def _train_epoch(self, train_loader):
        """Train for one epoch."""
        self.model.train()
        epoch_loss = 0

        with tqdm(train_loader, desc="Train") as pbar:
            for images, masks in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                # Forward pass
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)

                # Backward pass and optimization
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Update progress bar
                epoch_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})

        return epoch_loss / len(train_loader)

    def _validate(self, val_loader):
        """Evaluate the model on the validation set."""
        self.model.eval()
        val_loss = 0
        metrics_sum = None
        samples_count = 0

        with torch.no_grad():
            with tqdm(val_loader, desc="Val") as pbar:
                for images, masks in pbar:
                    images = images.to(self.device)
                    masks = masks.to(self.device)
                    batch_size = images.size(0)

                    # Forward pass
                    outputs = self.model(images)
                    loss = self.criterion(outputs, masks)

                    # Compute metrics
                    batch_metrics = calculate_metrics(outputs, masks, self.num_classes)

                    # Accumulate loss and metrics, weighted by batch size
                    val_loss += loss.item() * batch_size
                    if metrics_sum is None:
                        metrics_sum = {k: v * batch_size for k, v in batch_metrics.items()}
                    else:
                        for k, v in batch_metrics.items():
                            metrics_sum[k] += v * batch_size
                    samples_count += batch_size

                    # Update progress bar
                    pbar.set_postfix({'loss': loss.item()})

        # Compute sample-weighted averages
        val_loss /= samples_count
        metrics = {k: v / samples_count for k, v in metrics_sum.items()}

        return val_loss, metrics

    def train(self):
        """Run the full training loop."""
        # Prepare data loaders
        train_loader, val_loader = self._get_dataloaders()

        # Best validation loss so far
        best_val_loss = float('inf')

        # Training loop
        for epoch in range(1, self.epochs + 1):
            print(f"\nEpoch {epoch}/{self.epochs}")

            # Train
            train_loss = self._train_epoch(train_loader)
            self.history['train_loss'].append(train_loss)

            # Validate
            val_loss, metrics = self._validate(val_loader)
            self.history['val_loss'].append(val_loss)

            # Update learning rate
            self.scheduler.step(val_loss)

            # Record metrics
            for k, v in metrics.items():
                if k not in self.history['metrics']:
                    self.history['metrics'][k] = []
                self.history['metrics'][k].append(v)

            # Print metrics
            print(f"Train loss: {train_loss:.4f}, val loss: {val_loss:.4f}")
            for k, v in metrics.items():
                print(f"{k}: {v:.4f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self._save_model(f"{self.model_type}_best.pth")
                print(f"Saved best model, val loss: {val_loss:.4f}")

            # Save a periodic checkpoint
            if epoch % 10 == 0:
                self._save_model(f"{self.model_type}_epoch_{epoch}.pth")

        # Save final model
        self._save_model(f"{self.model_type}_final.pth")

        # Plot training curves
        self._plot_training_curves()

        return self.history

    def _save_model(self, filename):
        """Save a model checkpoint."""
        save_path = os.path.join(self.save_dir, filename)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_type': self.model_type,
            'num_classes': self.num_classes,
        }, save_path)

    def _plot_training_curves(self):
        """Plot and save the training curves."""
        # Create figure
        plt.figure(figsize=(15, 5))

        # Loss curves
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='train loss')
        plt.plot(self.history['val_loss'], label='val loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and validation loss')

        # Metric curves
        plt.subplot(1, 2, 2)
        for metric_name, metric_values in self.history['metrics'].items():
            plt.plot(metric_values, label=metric_name)
        plt.xlabel('Epoch')
        plt.ylabel('Metric value')
        plt.legend()
        plt.title('Validation metrics')

        # Save figure
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, f"{self.model_type}_training_curves.png"))
        plt.close()

    def load_model(self, checkpoint_path):
        """Load a model checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Check model type
        if checkpoint['model_type'] != self.model_type:
            print(f"Warning: loaded model type ({checkpoint['model_type']}) does not match "
                  f"the current model type ({self.model_type})")

        # Load model weights
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state (if present)
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        print(f"Model loaded from {checkpoint_path}")
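

# Example usage: a minimal sketch, not part of the original trainer. The dataset
# root './data' (with train/val splits of images/ and masks/) and the argument
# values below are illustrative assumptions. Because this module uses relative
# imports, run it as part of its package (e.g. `python -m <package>.train`,
# package name assumed) rather than as a standalone script.
if __name__ == '__main__':
    trainer = UAVSegTrainer(
        data_dir='./data',           # assumed dataset root
        model_type='deeplabv3plus',  # or 'unetpp'
        num_classes=6,
        batch_size=4,
        epochs=100,
        loss_type='ce',
    )
    history = trainer.train()  # trains, validates, saves checkpoints and curves
    print(f"Best val loss: {min(history['val_loss']):.4f}")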