import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from .models import get_model
from .data import RemoteSensingDataset, get_train_transforms, get_val_transforms
from .utils.metrics import calculate_metrics
from .utils.losses import get_loss_function

class UAVSegTrainer:
    """Trainer for UAV remote-sensing semantic segmentation models.

    Wraps model construction, data loading, the train/validate loop,
    LR scheduling, checkpointing and training-curve plotting.
    """

    def __init__(self,
                 data_dir,
                 model_type='deeplabv3plus',
                 num_classes=6,
                 batch_size=4,
                 epochs=100,
                 learning_rate=0.001,
                 tile_size=512,
                 stride=None,
                 loss_type='ce',
                 device=None,
                 save_dir='./checkpoints'):
        """
        Args:
            data_dir (str): data root; must contain 'train'/'val' dirs, each
                with 'images' and 'masks' subdirectories.
            model_type (str): model type, 'deeplabv3plus' or 'unetpp'.
            num_classes (int): number of segmentation classes.
            batch_size (int): batch size.
            epochs (int): number of training epochs.
            learning_rate (float): base learning rate; the backbone parameter
                group trains at 0.1x this rate.
            tile_size (int): tile size used to crop large images.
            stride (int): sliding-window stride; when None it defaults to
                tile_size (non-overlapping tiles).
            loss_type (str): loss function type, one of 'ce' (cross-entropy),
                'dice', 'focal', 'combined'.
            device (str): device string; auto-selected when None.
            save_dir (str): directory where checkpoints are written.
        """
        self.data_dir = data_dir
        self.model_type = model_type
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.tile_size = tile_size
        # Fix: honour the documented contract — stride=None means
        # "stride equals tile_size" instead of forwarding None downstream.
        self.stride = stride if stride is not None else tile_size
        self.loss_type = loss_type
        self.save_dir = save_dir

        # Select the compute device.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"使用设备: {self.device}")

        # Create the checkpoint directory.
        os.makedirs(save_dir, exist_ok=True)

        # Build the model and move it to the target device.
        self.model = get_model(model_type, num_classes)
        self.model.to(self.device)

        # Build the loss function.
        self.criterion = get_loss_function(loss_type, num_classes)

        # Optimizer with two parameter groups: the pretrained backbone
        # trains with a 10x smaller learning rate than the decoder.
        self.optimizer = optim.Adam([
            {'params': self.model.get_backbone_params(), 'lr': learning_rate * 0.1},
            {'params': self.model.get_decoder_params(), 'lr': learning_rate}
        ])

        # LR scheduler. Fix: the `verbose` kwarg was deprecated in torch 2.2
        # and removed in 2.4, where passing it raises TypeError — fall back
        # gracefully on newer versions.
        try:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
            )
        except TypeError:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.5, patience=5
            )

        # Training history (per-epoch losses and validation metrics).
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'metrics': {}
        }

    def _get_dataloaders(self):
        """Build and return the (train_loader, val_loader) pair."""
        # Expected on-disk layout: <data_dir>/{train,val}/{images,masks}.
        train_img_dir = os.path.join(self.data_dir, 'train', 'images')
        train_mask_dir = os.path.join(self.data_dir, 'train', 'masks')
        val_img_dir = os.path.join(self.data_dir, 'val', 'images')
        val_mask_dir = os.path.join(self.data_dir, 'val', 'masks')

        # Augmentation / preprocessing pipelines.
        train_transform = get_train_transforms()
        val_transform = get_val_transforms()

        # Training dataset (may use an overlapping stride).
        train_dataset = RemoteSensingDataset(
            image_dir=train_img_dir,
            mask_dir=train_mask_dir,
            transform=train_transform,
            tile_size=self.tile_size,
            stride=self.stride
        )

        # Validation dataset: stride == tile_size so tiles never overlap.
        val_dataset = RemoteSensingDataset(
            image_dir=val_img_dir,
            mask_dir=val_mask_dir,
            transform=val_transform,
            tile_size=self.tile_size,
            stride=self.tile_size
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True
        )

        return train_loader, val_loader

    def _train_epoch(self, train_loader):
        """Run one training epoch and return the mean batch loss."""
        self.model.train()
        epoch_loss = 0

        with tqdm(train_loader, desc="训练") as pbar:
            for images, masks in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                # Forward pass.
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)

                # Backward pass and optimizer step.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Accumulate loss and update the progress bar.
                epoch_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})

        # Guard against an empty loader (ZeroDivisionError otherwise).
        return epoch_loss / max(len(train_loader), 1)

    def _validate(self, val_loader):
        """Evaluate on the validation set.

        Returns:
            tuple: (mean validation loss, dict of sample-weighted mean metrics).
        """
        self.model.eval()
        val_loss = 0
        metrics_sum = None
        samples_count = 0

        with torch.no_grad():
            with tqdm(val_loader, desc="验证") as pbar:
                for images, masks in pbar:
                    images = images.to(self.device)
                    masks = masks.to(self.device)
                    batch_size = images.size(0)

                    # Forward pass.
                    outputs = self.model(images)
                    loss = self.criterion(outputs, masks)

                    # Per-batch metrics (IoU etc., per calculate_metrics).
                    batch_metrics = calculate_metrics(outputs, masks, self.num_classes)

                    # Weight loss and metrics by batch size so uneven final
                    # batches are averaged correctly.
                    val_loss += loss.item() * batch_size
                    if metrics_sum is None:
                        metrics_sum = {k: v * batch_size for k, v in batch_metrics.items()}
                    else:
                        for k, v in batch_metrics.items():
                            metrics_sum[k] += v * batch_size

                    samples_count += batch_size

                    pbar.set_postfix({'loss': loss.item()})

        # Guard: an empty validation set would otherwise divide by zero and
        # dereference a None metrics_sum.
        if samples_count == 0:
            return 0.0, {}

        val_loss /= samples_count
        metrics = {k: v / samples_count for k, v in metrics_sum.items()}

        return val_loss, metrics

    def train(self):
        """Run the full training loop and return the history dict."""
        # Prepare data loaders.
        train_loader, val_loader = self._get_dataloaders()

        # Track the best validation loss for checkpointing.
        best_val_loss = float('inf')

        for epoch in range(1, self.epochs + 1):
            print(f"\n轮次 {epoch}/{self.epochs}")

            # Train one epoch.
            train_loss = self._train_epoch(train_loader)
            self.history['train_loss'].append(train_loss)

            # Validate.
            val_loss, metrics = self._validate(val_loader)
            self.history['val_loss'].append(val_loss)

            # Step the plateau scheduler on the validation loss.
            self.scheduler.step(val_loss)

            # Record per-epoch metrics.
            for k, v in metrics.items():
                if k not in self.history['metrics']:
                    self.history['metrics'][k] = []
                self.history['metrics'][k].append(v)

            # Report.
            print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}")
            for k, v in metrics.items():
                print(f"{k}: {v:.4f}")

            # Checkpoint the best model so far.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self._save_model(f"{self.model_type}_best.pth")
                print(f"保存最佳模型,验证损失: {val_loss:.4f}")

            # Periodic checkpoint every 10 epochs.
            if epoch % 10 == 0:
                self._save_model(f"{self.model_type}_epoch_{epoch}.pth")

        # Final checkpoint and training curves.
        self._save_model(f"{self.model_type}_final.pth")
        self._plot_training_curves()

        return self.history

    def _save_model(self, filename):
        """Save model + optimizer state to <save_dir>/<filename>."""
        save_path = os.path.join(self.save_dir, filename)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_type': self.model_type,
            'num_classes': self.num_classes,
        }, save_path)

    def _plot_training_curves(self):
        """Plot loss and metric curves and save them as a PNG in save_dir."""
        plt.figure(figsize=(15, 5))
        # Use the SimSun font so the Chinese labels render correctly.
        plt.rcParams['font.sans-serif'] = ['SimSun']
        plt.rcParams['axes.unicode_minus'] = False

        # Loss curves.
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='训练损失')
        plt.plot(self.history['val_loss'], label='验证损失')
        plt.xlabel('轮次')
        plt.ylabel('损失')
        plt.legend()
        plt.title('训练和验证损失')

        # Validation-metric curves.
        plt.subplot(1, 2, 2)
        for metric_name, metric_values in self.history['metrics'].items():
            plt.plot(metric_values, label=metric_name)
        plt.xlabel('轮次')
        plt.ylabel('指标值')
        plt.legend()
        plt.title('验证指标')

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, f"{self.model_type}_training_curves.png"))
        plt.close()

    def load_model(self, checkpoint_path):
        """Load model (and, when present, optimizer) state from a checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Warn when the checkpoint was trained with a different architecture.
        if checkpoint['model_type'] != self.model_type:
            print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")

        # Restore model weights.
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Restore the optimizer state when the checkpoint carries it.
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        print(f"模型已从 {checkpoint_path} 加载")