"""
完整参数版 RSMamba 训练脚本
支持与baseline完全一致的参数，用于公平对比实验


使用方法 (对比实验):
cd /home/yuan/data/Kurosiwo_Honduras/KuroSiwoMamba && \
export CUDA_VISIBLE_DEVICES=0 && \
export PYTHONPATH=/home/yuan/data/Kurosiwo_Honduras/KuroSiwoMamba:$PYTHONPATH && \
mkdir -p checkpoints/UNetRSMamba_FloodFocus && \
nohup python -u train_rsmamba_full.py \
    --model_type rsmamba_unet \
    --task segmentation \
    --epochs 100 \
    --batch_size 32 \
    --lr 1e-4 \
    --embed_dim 96 \
    --depth 6 \
    --weight_decay 0.05 \
    --loss_function focal+dice \
    --focal_weight 0.35 \
    --dice_weight 0.65 \
    --focal_gamma 2.5 \
    --class_weights '[0.2, 2.5, 4.0]' \
    --scheduler cosine \
    --warmup_epochs 10 \
    --min_lr 1e-6 \
    --gradient_accumulation_steps 1 \
    --grad_clip 1.0 \
    --output_dir checkpoints/UNetRSMamba_FloodFocus \
    > checkpoints/UNetRSMamba_FloodFocus/train.log 2>&1 &
"""

import argparse
import json
import os
import sys
import random
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("⚠️ wandb未安装，将禁用W&B日志功能")

sys.path.insert(0, str(Path(__file__).parent))

from models.rs_mamba import RSMambaSegmentation, MultiScaleRSMamba
from models.rs_mamba_unet import UNetRSMamba
from utilities.utilities import prepare_loaders
from torchmetrics import Accuracy, F1Score, Precision, Recall
from torchmetrics.classification import MulticlassJaccardIndex

# Display names for the three segmentation classes, keyed by label id.
# Label 3 is intentionally absent: it is excluded everywhere via ignore_index=3 / `mask != 3`.
CLASS_LABELS = {0: "No water", 1: "Permanent Waters", 2: "Floods"}


class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    ``shadow`` maps each trainable parameter name to its averaged tensor.
    ``apply_shadow()`` swaps the averaged tensors into ``self.model`` (e.g. for
    validation or checkpointing) and ``restore()`` swaps the live training
    tensors back; the two calls must be paired.

    Args:
        model: module whose trainable parameters are tracked.
        decay: weight kept by the previous average on every ``update`` call.
    """

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        self.register()

    def register(self):
        """(Re)initialise the shadow copy from the model's current trainable parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    @torch.no_grad()
    def update(self, model):
        """Blend *model*'s parameters into the shadow copy.

        ``shadow = decay * shadow + (1 - decay) * param``, computed in place.

        Raises:
            KeyError: if *model* has a trainable parameter that was never registered.
        """
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name not in self.shadow:
                raise KeyError(f"EMA has no shadow entry for parameter {name!r}")
            # In place: avoids allocating two temporaries plus a clone for every
            # parameter on every optimizer step.
            self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=1.0 - self.decay)

    def apply_shadow(self):
        """Swap the averaged weights into ``self.model``, keeping the live ones in ``backup``.

        Raises:
            RuntimeError: if called again before ``restore()``; a second call would
                overwrite ``backup`` with the shadow and lose the training weights.
            KeyError: if a trainable parameter has no shadow entry.
        """
        if self.backup:
            raise RuntimeError("EMA.apply_shadow() called twice without restore()")
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if name not in self.shadow:
                    raise KeyError(f"EMA has no shadow entry for parameter {name!r}")
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        """Put back the training weights saved by ``apply_shadow()`` and clear ``backup``.

        Raises:
            KeyError: if a trainable parameter has no backup (``apply_shadow`` not called).
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if name not in self.backup:
                    raise KeyError(f"EMA has no backup for parameter {name!r}")
                param.data = self.backup[name]
        self.backup = {}


class EarlyStopping:
    """Signal a stop once a monitored score stops improving.

    Call the instance once per evaluation with the latest score. It returns
    True (and sets ``early_stop``) after ``patience`` consecutive calls without
    improvement. A NaN score (e.g. diverged training) is counted as "no
    improvement" and is never stored as the best score.

    Args:
        patience: number of non-improving calls tolerated before stopping.
        min_delta: margin a score must reach beyond the best so far to count as
            an improvement (reaching ``best + min_delta`` / ``best - min_delta``
            exactly counts as improving).
        mode: 'max' if larger scores are better, 'min' if smaller are better.

    Raises:
        ValueError: if ``mode`` is neither 'max' nor 'min'.
    """

    def __init__(self, patience=10, min_delta=0.0, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def _is_improvement(self, score):
        """Return True if *score* beats ``best_score`` by at least ``min_delta``."""
        if self.mode == 'max':
            return score >= self.best_score + self.min_delta
        return score <= self.best_score - self.min_delta

    def __call__(self, score):
        """Record *score*; return True when training should stop."""
        if score != score:
            # NaN compares False with everything; without this branch it would be
            # taken as an improvement, stored as best, and block stopping forever.
            self.counter += 1
        elif self.best_score is None:
            self.best_score = score
            return False
        elif self._is_improvement(score):
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True
            return True
        return False


def parse_args():
    """Declare the command-line interface of the comparison experiment and parse ``sys.argv``.

    Options are listed in a table of ``(flag, add_argument kwargs)`` pairs and
    registered in table order, which is also the order shown by ``--help``.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    option_table = (
        # Model selection
        ('--model_type', dict(type=str, default='rsmamba_unet',
                              choices=['rsmamba', 'rsmamba_multiscale', 'rsmamba_unet'],
                              help='模型类型')),
        ('--task', dict(type=str, default='segmentation',
                        choices=['segmentation', 'change_detection'],
                        help='任务类型')),
        # Basic training hyper-parameters
        ('--batch_size', dict(type=int, default=32)),
        ('--epochs', dict(type=int, default=80)),
        ('--lr', dict(type=float, default=1e-4)),
        ('--weight_decay', dict(type=float, default=0.08)),
        ('--seed', dict(type=int, default=42)),
        # Model size
        ('--img_size', dict(type=int, default=224)),
        ('--embed_dim', dict(type=int, default=96, help='基础嵌入维度')),
        ('--depth', dict(type=int, default=6, help='基础深度')),
        ('--embed_dims', dict(type=str, default='[96,192,384,768]',
                              help='多尺度维度 (rsmamba_unet)')),
        ('--depths', dict(type=str, default='[2,2,6,2]',
                          help='多尺度深度 (rsmamba_unet)')),
        # Loss function
        ('--loss_function', dict(type=str, default='focal+dice',
                                 choices=['ce', 'focal', 'dice', 'focal+dice', 'ce+dice'])),
        ('--focal_weight', dict(type=float, default=0.4)),
        ('--dice_weight', dict(type=float, default=0.6)),
        ('--focal_gamma', dict(type=float, default=2.5)),
        ('--class_weights', dict(type=str, default='[0.2, 2.5, 3.5]')),
        # Learning-rate schedule
        ('--scheduler', dict(type=str, default='cosine',
                             choices=['onecycle', 'cosine', 'plateau', 'step'])),
        ('--warmup_epochs', dict(type=int, default=5)),
        ('--min_lr', dict(type=float, default=1e-6)),
        ('--patience', dict(type=int, default=10)),
        # Advanced training options
        ('--gradient_accumulation_steps', dict(type=int, default=1)),
        ('--grad_clip', dict(type=float, default=1.0)),
        ('--early_stopping_patience', dict(type=int, default=20)),
        ('--ema_decay', dict(type=float, default=0.999)),
        ('--use_tta', dict(action='store_true')),
        # Output locations
        ('--output_dir', dict(type=str, default='checkpoints/UNetRSMamba_Compare')),
        ('--checkpoint', dict(type=str, default=None)),
        ('--experiment_name', dict(type=str, default=None)),
        # Weights & Biases logging
        ('--use_wandb', dict(action='store_true',
                             help='使用 Weights & Biases 记录训练过程')),
        ('--wandb_project', dict(type=str, default='KuroSiwo-RSMamba-Flood',
                                 help='W&B 项目名称')),
        # Evaluation-only mode
        ('--test_only', dict(action='store_true',
                             help='仅进行测试评估（需要提供checkpoint）')),
    )

    cli = argparse.ArgumentParser(description='完整参数RSMamba训练 - 对比实验')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs with *seed*.

    When CUDA is available it also seeds every GPU and switches cuDNN to
    deterministic, non-benchmarking mode for reproducible runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def create_model(configs):
    """Build the segmentation network selected by ``configs['model_type']``.

    Reads ``model_type``, ``embed_dim``, ``depth``, ``img_size``,
    ``num_channels`` and ``num_classes`` from *configs*. For ``rsmamba_unet`` it
    also reads ``auto_scale`` (default True) and, when that is False,
    ``embed_dims`` / ``depths``. Prints the chosen architecture and the
    parameter counts, then returns the model.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    supported_types = ('rsmamba', 'rsmamba_multiscale', 'rsmamba_unet')
    model_type = configs['model_type']
    if model_type not in supported_types:
        # Fail fast with a clear message; previously this fell through every
        # branch and crashed with an UnboundLocalError on `model`.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {supported_types}"
        )

    embed_dim = configs['embed_dim']
    depth = configs['depth']

    print(f"\n{'='*70}")
    print(f"创建模型: {model_type}")
    print(f"{'='*70}")
    print(f"  基础参数: embed_dim={embed_dim}, depth={depth}")

    if model_type == 'rsmamba':
        print("  ⚠️  原始单尺度RSMamba")
        model = RSMambaSegmentation(
            img_size=configs['img_size'],
            patch_size=16,
            in_channels=configs['num_channels'],
            embed_dim=embed_dim * 8,  # keep the same relative scale as the multi-scale variants
            depth=depth * 2,
            num_classes=configs['num_classes'],
            d_state=16,
            d_conv=4,
            expand=2,
        )

    elif model_type == 'rsmamba_multiscale':
        print("  ✨ 多尺度RSMamba")
        # Scale the three stage widths relative to the reference embed_dim of 96.
        scale = embed_dim / 96.0
        base_dims = [int(192*scale), int(384*scale), int(768*scale)]
        print(f"  多尺度维度: {base_dims}")

        model = MultiScaleRSMamba(
            img_size=configs['img_size'],
            in_channels=configs['num_channels'],
            num_classes=configs['num_classes'],
            embed_dims=base_dims,
            depths=[max(1, depth//3), max(2, depth//2), depth],
            d_state=16,
            d_conv=4,
            expand=2,
            drop_path_rate=0.1,
        )

    else:  # 'rsmamba_unet'
        print("  🔥 UNet风格RSMamba")

        if configs.get('auto_scale', True):
            # Derive the four stage widths from embed_dim (reference width 96) ...
            scale = embed_dim / 96.0
            base_dims = [
                int(96 * scale),
                int(192 * scale),
                int(384 * scale),
                int(768 * scale)
            ]
            # ... and split a total depth of 2*depth as 1/8, 1/8, 1/2, 1/8 per stage.
            total_depth = depth * 2
            base_depths = [
                max(1, total_depth // 8),
                max(1, total_depth // 8),
                max(2, total_depth // 2),
                max(1, total_depth // 8)
            ]
        else:
            # Use the explicitly configured per-stage widths / depths.
            base_dims = configs['embed_dims']
            base_depths = configs['depths']

        print(f"  多尺度维度: {base_dims}")
        print(f"  多尺度深度: {base_depths}")

        model = UNetRSMamba(
            img_size=configs['img_size'],
            in_channels=configs['num_channels'],
            num_classes=configs['num_classes'],
            embed_dims=base_dims,
            depths=base_depths,
            d_state=16,
            drop_path_rate=0.1,
        )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\n参数统计:")
    print(f"  总参数: {total_params/1e6:.2f}M")
    print(f"  可训练: {trainable_params/1e6:.2f}M")
    print(f"{'='*70}\n")

    return model


def create_loss_function(configs, device):
    """Build the training criterion named by ``configs['loss_function']``.

    Every loss ignores label 3. ``configs['class_weights']`` is used as the
    CrossEntropy ``weight`` ('ce'), the focal ``alpha`` ('focal', 'focal+dice')
    or the ``weights`` of BCEandDiceLoss ('ce+dice'); plain 'dice' ignores it.

    Args:
        configs: needs ``loss_function`` and ``class_weights``; the focal
            variants also read ``focal_gamma`` (and ``focal_weight`` /
            ``dice_weight`` for 'focal+dice').
        device: device the criterion (and its weight tensor) is moved to.

    Raises:
        ValueError: for an unsupported ``loss_function`` name.
    """
    loss_type = configs['loss_function']
    ignore_index = 3
    # Force float32: integer literals such as [1, 2, 3] would otherwise produce an
    # int64 tensor, which nn.CrossEntropyLoss rejects at forward time.
    class_weights = torch.tensor(configs['class_weights'], dtype=torch.float32).to(device)
    weight_list = class_weights.cpu().tolist()

    print(f"\n{'='*60}")
    print(f"损失函数: {loss_type}")
    print(f"类别权重: {configs['class_weights']}")
    print(f"{'='*60}\n")

    if loss_type == 'focal+dice':
        from utilities.focal_loss import FocalDiceLoss
        criterion = FocalDiceLoss(
            focal_weight=configs['focal_weight'],
            dice_weight=configs['dice_weight'],
            focal_alpha=weight_list,
            focal_gamma=configs['focal_gamma'],
            ignore_index=ignore_index
        ).to(device)
    elif loss_type == 'focal':
        from utilities.focal_loss import FocalLoss
        criterion = FocalLoss(
            alpha=weight_list,
            gamma=configs['focal_gamma'],
            reduction='mean',
            ignore_index=ignore_index
        ).to(device)
    elif loss_type == 'ce+dice':
        from utilities.bce_and_dice import BCEandDiceLoss
        criterion = BCEandDiceLoss(
            weights=weight_list,
            ignore_index=ignore_index,
            use_softmax=True
        ).to(device)
    elif loss_type == 'dice':
        from utilities.dice import DiceLoss
        criterion = DiceLoss(ignore_index=ignore_index).to(device)
    elif loss_type == 'ce':
        criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_index)
    else:
        raise ValueError(f"Unsupported loss_function: {loss_type!r}")

    return criterion


def unpack_batch(batch):
    """Pick ``(image, mask, pre_event, pre_event_2)`` out of a data-loader batch.

    The layout is recognised from the number of items in the batch:

    * 6 or 7 items (the 7-item layout also carries a ``dem`` entry): the four
      wanted entries sit at positions 0, 1, 2, 3;
    * 12 or 13 items (13 also carries ``dem``): they sit at positions 2, 3, 6, 9.

    Raises:
        ValueError: if the batch length matches none of the known layouts.
    """
    positions_by_length = {
        6: (0, 1, 2, 3),
        7: (0, 1, 2, 3),
        12: (2, 3, 6, 9),
        13: (2, 3, 6, 9),
    }
    item_count = len(batch)
    wanted = positions_by_length.get(item_count)
    if wanted is None:
        raise ValueError(f"Unexpected batch size: {item_count}")

    items = tuple(batch)
    return tuple(items[pos] for pos in wanted)


def train_epoch(model, train_loader, criterion, optimizer, scheduler, configs, epoch, ema_model=None, metrics=None):
    """Run one training epoch and return the sample-weighted mean training loss.

    The model input is ``torch.cat([pre_event_2, pre_event, image], dim=1)``.
    Gradients are accumulated over ``configs['gradient_accumulation_steps']``
    batches and clipped to ``configs['grad_clip']`` (when > 0) before every
    optimizer step. A 'onecycle' scheduler and the optional EMA are updated after
    every optimizer step; other schedulers are left to the caller.

    Args:
        metrics: optional ``(accuracy, f1, precision, recall, iou)`` torchmetrics
            tuple used only for the live progress-bar / W&B readout (pixels with
            label 3 are excluded). Its accumulated state is reset at epoch end.

    Returns:
        Sum of ``batch_loss * batch_size`` divided by ``len(train_loader.dataset)``.
    """
    model.train()
    train_loss = 0.0
    device = configs['device']
    grad_accum_steps = configs['gradient_accumulation_steps']
    grad_clip = configs['grad_clip']
    use_wandb = configs.get('use_wandb', False)

    if metrics is not None:
        accuracy, fscore, _precision, _recall, iou = metrics

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{configs['epochs']}")
    for batch_idx, batch in enumerate(pbar):
        if batch_idx % grad_accum_steps == 0:
            optimizer.zero_grad()

        try:
            image, mask, pre_event, pre_event_2 = unpack_batch(batch)
        except ValueError:
            continue

        image = image.to(device)
        mask = mask.to(device)
        pre_event = pre_event.to(device)
        pre_event_2 = pre_event_2.to(device)
        x = torch.cat([pre_event_2, pre_event, image], dim=1)

        # Forward / backward; the loss is scaled so the accumulated gradient is the
        # mean over the accumulation window.
        output = model(x)
        loss = criterion(output, mask) / grad_accum_steps
        loss.backward()

        if (batch_idx + 1) % grad_accum_steps == 0:
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            if scheduler is not None and configs['scheduler'] == 'onecycle':
                scheduler.step()
            if ema_model is not None:
                ema_model.update(model)

        # .item() synchronises with the device, so read the unscaled loss only once.
        batch_loss = loss.item() * grad_accum_steps
        train_loss += batch_loss * image.size(0)
        current_lr = optimizer.param_groups[0]['lr']

        if metrics is None:
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'lr': f'{current_lr:.6f}'
            })
            continue

        # Live metrics for the progress bar (label 3 is the ignore label, as in the losses).
        with torch.no_grad():
            predictions = output.argmax(1)
            valid_mask = mask != 3
            if valid_mask.any():
                valid_predictions = predictions[valid_mask]
                valid_targets = mask[valid_mask]
                acc = accuracy(valid_predictions, valid_targets)
                score = fscore(valid_predictions, valid_targets)
                mean_iou = iou(valid_predictions, valid_targets).mean()
            else:
                # No labelled pixels in this batch: report zeros instead of reusing the
                # previous batch's values (or raising NameError on the first batch).
                acc = torch.zeros(1, device=device)
                score = torch.zeros(1, device=device)
                mean_iou = torch.tensor(0.0, device=device)

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'mIoU': f'{mean_iou.item()*100:.2f}%',
            'lr': f'{current_lr:.6f}'
        })

        # W&B logging every 10 batches.
        if use_wandb and batch_idx % 10 == 0:
            wandb.log({
                'train/loss': batch_loss,
                'train/mIoU': mean_iou.item() * 100,
                'train/accuracy': acc.mean().item() * 100,
                'train/f1': score.mean().item() * 100,
                'train/lr': current_lr,
                'epoch': epoch,
            })

    # Keep the metric objects' accumulated state from spilling into the next epoch.
    if metrics is not None:
        for metric in metrics:
            metric.reset()

    return train_loss / len(train_loader.dataset)


def validate(model, val_loader, criterion, configs, metrics, use_tta=False):
    """Evaluate *model* on every batch of *val_loader* without gradients.

    Args:
        model: network fed ``torch.cat([pre_event_2, pre_event, image], dim=1)``.
        val_loader: loader whose batches ``unpack_batch`` understands; batches it
            rejects with ValueError are skipped.
        criterion: loss applied to the (possibly TTA-averaged) logits.
        configs: needs ``'device'``; ``'current_epoch'`` and ``'use_wandb'`` only
            affect the printed / logged output.
        metrics: ``(accuracy, f1, precision, recall, iou)`` torchmetrics objects,
            updated on pixels whose label is not 3, computed, then reset.
        use_tta: if True, average the logits of the plain input and of inputs
            flipped along dim 3 and dim 2 (each output flipped back first).

    Returns:
        ``(avg_loss, mean_iou)``: the batch-size-weighted loss sum divided by
        ``len(val_loader.dataset)``, and the class-averaged IoU as a float.
    """
    def _infer(inputs):
        # One forward pass, or the mean over the plain and the two flipped views.
        if not use_tta:
            return model(inputs)
        views = [model(inputs)]
        for axis in (3, 2):
            flipped_out = model(torch.flip(inputs, dims=[axis]))
            views.append(torch.flip(flipped_out, dims=[axis]))
        return torch.stack(views).mean(dim=0)

    model.eval()
    loss_sum = 0.0
    target_device = configs['device']
    acc_m, f1_m, prec_m, rec_m, iou_m = metrics
    metric_objs = (acc_m, f1_m, prec_m, rec_m, iou_m)

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="验证"):
            try:
                image, mask, pre_event, pre_event_2 = unpack_batch(batch)
            except ValueError:
                continue
            image, mask, pre_event, pre_event_2 = [
                tensor.to(target_device) for tensor in (image, mask, pre_event, pre_event_2)
            ]

            logits = _infer(torch.cat([pre_event_2, pre_event, image], dim=1))
            batch_loss = criterion(logits, mask)
            predicted = logits.argmax(1)
            loss_sum += batch_loss.item() * image.size(0)

            # Label 3 is the ignore label; score only the remaining pixels.
            labelled = mask != 3
            if labelled.sum() > 0:
                kept_pred = predicted[labelled]
                kept_true = mask[labelled]
                for metric in metric_objs:
                    metric(kept_pred, kept_true)

    avg_loss = loss_sum / len(val_loader.dataset)
    acc, score, prec, rec, ious = [metric.compute() for metric in metric_objs]
    mean_iou = ious.mean()
    for metric in metric_objs:
        metric.reset()

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"验证结果 - Epoch {configs.get('current_epoch', 0)+1}")
    print(f"{rule}")
    print(f"损失: {avg_loss:.4f}")
    print(f"平均 IoU: {mean_iou.item()*100:.2f}%")
    print(f"\n各类别指标:")
    report_rows = (
        ("准确率:  ", acc),
        ("F1分数:  ", score),
        ("精确率:  ", prec),
        ("召回率:  ", rec),
        ("IoU:     ", ious),
    )
    for cls_idx in range(len(CLASS_LABELS)):
        print(f"  {CLASS_LABELS[cls_idx]}:")
        for caption, values in report_rows:
            print(f"    {caption}{values[cls_idx].item()*100:.2f}%")
    print(f"{rule}\n")

    if configs.get('use_wandb', False):
        payload = {
            'val/loss': avg_loss,
            'val/mIoU': mean_iou.item() * 100,
            'val/accuracy': acc.mean().item() * 100,
            'val/f1': score.mean().item() * 100,
            'val/precision': prec.mean().item() * 100,
            'val/recall': rec.mean().item() * 100,
            'epoch': configs.get('current_epoch', 0),
        }
        per_class_keys = (
            ('accuracy', acc),
            ('f1', score),
            ('precision', prec),
            ('recall', rec),
            ('iou', ious),
        )
        for cls_idx in range(len(CLASS_LABELS)):
            label = CLASS_LABELS[cls_idx]
            for key, values in per_class_keys:
                payload[f'val/{key}_{label}'] = values[cls_idx].item() * 100
        wandb.log(payload)

    return avg_loss, mean_iou.item()


def _parse_literal_list(text, arg_name):
    """Parse a CLI list literal such as ``'[0.2, 2.5, 4.0]'`` into a Python list.

    ``ast.literal_eval`` accepts only Python literals, unlike ``eval``, which
    would execute any expression passed on the command line.

    Raises:
        ValueError: if *text* is not a list or tuple literal.
    """
    import ast

    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError) as err:
        raise ValueError(f"无法解析 --{arg_name}: {text!r}") from err
    if not isinstance(value, (list, tuple)):
        raise ValueError(f"--{arg_name} 必须是列表: {text!r}")
    return list(value)


def _build_scheduler(optimizer, configs, steps_per_epoch):
    """Create the LR scheduler selected by ``configs['scheduler']``.

    * 'onecycle': OneCycleLR over all optimizer steps; train_epoch steps it
      after every optimizer step.
    * 'cosine': linear warm-up for ``warmup_epochs`` epochs, then a single
      cosine decay from ``learning_rate`` towards ``min_lr``. There are no
      restarts, so the LR never jumps back up near the end of training.
    * 'plateau': ReduceLROnPlateau on validation mIoU (mode='max'), using
      ``patience`` and ``min_lr``.
    * 'step': StepLR dividing the LR by 10 every third of training. There are
      no CLI options for its step size or gamma, so fixed defaults are used.

    Every scheduler except 'onecycle' is stepped once per epoch by main().
    Returns None for an unrecognised name.
    """
    import math

    name = configs['scheduler']
    epochs = configs['epochs']
    base_lr = configs['learning_rate']

    if name == 'onecycle':
        total_steps = max(1, steps_per_epoch * epochs // configs['gradient_accumulation_steps'])
        return optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=base_lr,
            total_steps=total_steps, pct_start=0.1
        )

    if name == 'cosine':
        # Keep at least one epoch of decay, so warmup_epochs >= epochs cannot
        # produce a zero or negative period.
        warmup = max(0, min(configs['warmup_epochs'], epochs - 1))
        decay_epochs = max(1, epochs - warmup)
        floor = min(1.0, configs['min_lr'] / base_lr) if base_lr > 0 else 0.0

        def lr_factor(epoch):
            # Multiplier applied to base_lr for the given (0-based) epoch.
            if epoch < warmup:
                return (epoch + 1) / (warmup + 1)
            progress = min(1.0, (epoch - warmup) / decay_epochs)
            return floor + (1.0 - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))

        return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)

    if name == 'plateau':
        # Stepped with validation mIoU, where higher is better.
        return optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max',
            patience=configs['patience'], min_lr=configs['min_lr']
        )

    if name == 'step':
        return optim.lr_scheduler.StepLR(optimizer, step_size=max(1, epochs // 3), gamma=0.1)

    return None


def _make_checkpoint(epoch, model, optimizer, scheduler, ema_model, best_miou, configs):
    """Collect everything needed to resume training into one dict for ``torch.save``."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'ema_shadow': dict(ema_model.shadow) if ema_model is not None else None,
        'best_miou': best_miou,
        'config': configs
    }


def main():
    """Entry point: parse the CLI, then train and test, or only test.

    Steps:
    1. Seed the RNGs and warn if ``mamba_ssm`` is missing.
    2. Build ``configs`` and save it to ``<output_dir>/config.json``.
    3. Create the data loaders, model, loss, AdamW optimizer and LR scheduler.
    4. Optionally resume from ``--checkpoint``, restoring model, optimizer,
       scheduler and EMA state where present.
    5. With ``--test_only``, evaluate on the test split and stop.
    6. Otherwise train with EMA-weight validation, best-model and periodic
       checkpoints, and early stopping.
    7. Finally evaluate the best checkpoint on the test split.
    """
    args = parse_args()
    set_seed(args.seed)

    # Check whether the official mamba_ssm implementation is importable.
    try:
        from mamba_ssm import Mamba  # noqa: F401  (imported only to probe availability)
        print("\n✅ mamba_ssm库已安装，将使用官方实现")
    except ImportError:
        print("\n" + "="*70)
        print("⚠️  警告: mamba_ssm库未安装!")
        print("   当前使用降级实现，性能可能不稳定")
        print("   强烈建议安装官方库:")
        print("   pip install mamba-ssm==1.2.0")
        print("="*70 + "\n")
        import time
        time.sleep(3)  # give the user time to notice the warning

    num_classes = 3
    class_weights = _parse_literal_list(args.class_weights, 'class_weights')
    if len(class_weights) != num_classes:
        raise ValueError(
            f"--class_weights 需要 {num_classes} 个值, 实际为 {len(class_weights)}: {args.class_weights}"
        )

    # Experiment configuration
    configs = {
        'model_type': args.model_type,
        'task': args.task,
        'img_size': args.img_size,
        'embed_dim': args.embed_dim,
        'depth': args.depth,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'learning_rate': args.lr,
        'weight_decay': args.weight_decay,
        'num_classes': num_classes,
        'num_channels': 6,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'seed': args.seed,
        'auto_scale': True,  # create_model derives per-stage dims/depths from embed_dim/depth

        # Loss function
        'loss_function': args.loss_function,
        'focal_weight': args.focal_weight,
        'dice_weight': args.dice_weight,
        'focal_gamma': args.focal_gamma,
        'class_weights': class_weights,

        # LR scheduler
        'scheduler': args.scheduler,
        'warmup_epochs': args.warmup_epochs,
        'min_lr': args.min_lr,
        'patience': args.patience,

        # Training options
        'gradient_accumulation_steps': args.gradient_accumulation_steps,
        'grad_clip': args.grad_clip,
        'early_stopping_patience': args.early_stopping_patience,
        'ema_decay': args.ema_decay,
        'use_tta': args.use_tta,

        # W&B logging
        'use_wandb': args.use_wandb and WANDB_AVAILABLE,
        'wandb_project': args.wandb_project,
        'experiment_name': args.experiment_name,

        # Data configuration
        'root_path': str(Path(__file__).parent),
        'train_pickle': 'pickle/grid_dict_full.pkl',
        'test_pickle': 'pickle/grid_with_water_dict.pkl',
        'train_acts': [118, 427, 324, 411],
        'val_acts': [279, 417, 445],
        'test_acts': [497, 421, 502],
        'track': 'RandomEvents',
        'num_workers': 4,
        'inputs': ['pre_event_1', 'pre_event_2', 'post_event'],
        'channels': ['vv', 'vh'],
        'water_percentage': '[0,100]',
        'clamp_input': 0.15,
        'scale_input': 'normalize',
        'data_mean': [0.0953, 0.0264],
        'data_std': [0.0427, 0.0215],
        'dem': False,
        'slope': False,
        'dem_mean': 93.4313,
        'dem_std': 1410.8382,
        'slope_mean': 2.1277,
        'slope_std': 67.5048,
        'reverse_scaling': False,
        'uint8': False,
        'oversampling': False,
        'evaluate_water': True,
        'data_augmentations': True,
        'augmentation_config': 'configs/augmentations/augmentation_simple.json',
    }

    # Multi-scale settings. NOTE: create_model only reads these when
    # configs['auto_scale'] is False, which this script never sets.
    try:
        configs['embed_dims'] = _parse_literal_list(args.embed_dims, 'embed_dims')
        configs['depths'] = _parse_literal_list(args.depths, 'depths')
    except ValueError as err:
        print(f"⚠️  {err}，使用默认多尺度配置")
        configs['embed_dims'] = [96, 192, 384, 768]
        configs['depths'] = [2, 2, 6, 2]

    # Load the data-augmentation configuration.
    if configs.get('data_augmentations', False):
        aug_config_path = os.path.join(
            configs['root_path'],
            configs.get('augmentation_config', 'configs/augmentations/augmentation_simple.json')
        )

        if os.path.exists(aug_config_path):
            try:
                with open(aug_config_path, 'r') as f:
                    aug_config = json.load(f)
                    configs['augmentations'] = aug_config.get('augmentations', {})
                    print(f"✓ 成功加载数据增强配置: {aug_config_path}")
                    print(f"  包含 {len(configs['augmentations'])} 个增强类型")
            except Exception as e:
                print(f"⚠️  加载增强配置失败: {e}")
                configs['augmentations'] = {}
        else:
            print(f"⚠️  增强配置文件不存在: {aug_config_path}")
            configs['augmentations'] = {}
    else:
        print("数据增强已禁用")
        configs['augmentations'] = {}

    # Output directory and saved configuration
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    configs['checkpoint_path'] = str(output_dir)

    with open(output_dir / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(configs, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*70}")
    print(f"完整参数 RSMamba 训练 - 对比实验")
    print(f"{'='*70}")
    print(f"模型: {configs['model_type']}")
    print(f"基础参数: embed_dim={configs['embed_dim']}, depth={configs['depth']}")
    print(f"批次大小: {configs['batch_size']}")
    print(f"梯度累积: {configs['gradient_accumulation_steps']} (有效batch={configs['batch_size']*configs['gradient_accumulation_steps']})")
    print(f"训练轮数: {configs['epochs']}")
    print(f"学习率: {configs['learning_rate']}")
    print(f"权重衰减: {configs['weight_decay']}")
    print(f"损失函数: {configs['loss_function']}")
    print(f"调度器: {configs['scheduler']}")
    print(f"EMA: {configs['ema_decay']}")
    print(f"TTA: {configs['use_tta']}")
    print(f"Early Stop: {configs['early_stopping_patience']}")
    print(f"种子: {configs['seed']}")
    print(f"输出: {output_dir}")
    print(f"{'='*70}\n")

    # Data loaders
    print("准备数据加载器...")
    train_loader, val_loader, test_loader = prepare_loaders(configs)
    print(f"✓ 训练: {len(train_loader.dataset)}")
    print(f"✓ 验证: {len(val_loader.dataset)}")
    print(f"✓ 测试: {len(test_loader.dataset)}")

    # Weights & Biases
    if configs['use_wandb']:
        experiment_name = configs['experiment_name'] or f"RSMamba_{configs['model_type']}_lr{configs['learning_rate']}_bs{configs['batch_size']}"
        wandb.init(
            project=configs['wandb_project'],
            config=configs,
            name=experiment_name
        )
        print(f"\n✓ W&B 已初始化: {configs['wandb_project']}/{experiment_name}\n")

    model = create_model(configs).to(configs['device'])
    criterion = create_loss_function(configs, configs['device'])

    optimizer = optim.AdamW(
        model.parameters(),
        lr=configs['learning_rate'],
        weight_decay=configs['weight_decay']
    )
    scheduler = _build_scheduler(optimizer, configs, len(train_loader))

    def init_metrics(device):
        # Per-class (average=None) metrics over the 3 labelled classes.
        return (
            Accuracy(task="multiclass", num_classes=3, average=None).to(device),
            F1Score(task="multiclass", num_classes=3, average=None).to(device),
            Precision(task="multiclass", num_classes=3, average=None).to(device),
            Recall(task="multiclass", num_classes=3, average=None).to(device),
            MulticlassJaccardIndex(num_classes=3, average=None).to(device),
        )

    train_metrics = init_metrics(configs['device'])
    val_metrics = init_metrics(configs['device'])

    early_stopping = EarlyStopping(
        patience=configs['early_stopping_patience'], mode='max'
    ) if configs['early_stopping_patience'] > 0 else None

    # Resume training from a checkpoint.
    start_epoch = 0
    best_miou = 0.0
    resume_state = None
    if args.checkpoint:
        print(f"\n从checkpoint恢复: {args.checkpoint}")
        resume_state = torch.load(args.checkpoint, map_location=configs['device'])
        model.load_state_dict(resume_state['model_state_dict'])
        optimizer.load_state_dict(resume_state['optimizer_state_dict'])
        start_epoch = resume_state.get('epoch', 0) + 1
        best_miou = resume_state.get('best_miou', 0.0)

        scheduler_state = resume_state.get('scheduler_state_dict')
        if scheduler is not None and scheduler_state is not None:
            try:
                scheduler.load_state_dict(scheduler_state)
            except (KeyError, ValueError) as err:
                print(f"⚠️  调度器状态恢复失败: {err}")
        elif scheduler is not None and configs['scheduler'] in ('cosine', 'step'):
            # Older checkpoints carry no scheduler state: replay the per-epoch steps
            # so the LR matches the resumed epoch instead of restarting the schedule.
            for _ in range(start_epoch):
                scheduler.step()
        print(f"✓ 恢复到epoch {start_epoch}, 最佳mIoU: {best_miou*100:.2f}%\n")

    # The EMA is created only after any checkpoint has been loaded, so its shadow
    # starts from the resumed weights instead of the random initialisation.
    ema_model = EMA(model, decay=configs['ema_decay']) if configs['ema_decay'] > 0 else None
    if ema_model is not None and resume_state is not None and resume_state.get('ema_shadow'):
        for name, tensor in resume_state['ema_shadow'].items():
            if name in ema_model.shadow:
                ema_model.shadow[name].copy_(tensor)

    # Test-only mode
    if args.test_only:
        print(f"\n{'='*70}")
        print(f"测试模式 - 评估模型性能")
        print(f"{'='*70}\n")

        if not args.checkpoint:
            # Fall back to <output_dir>/best_model.pt.
            best_model_path = output_dir / 'best_model.pt'
            if best_model_path.exists():
                print(f"未指定checkpoint，使用默认路径: {best_model_path}")
                best_state = torch.load(best_model_path, map_location=configs['device'])
                model.load_state_dict(best_state['model_state_dict'])
                best_miou = best_state.get('best_miou', 0.0)
                print(f"✓ 加载最佳模型 (训练mIoU: {best_miou*100:.2f}%)\n")
            else:
                raise ValueError("测试模式需要提供 --checkpoint 或确保 best_model.pt 存在")

        print("在测试集上评估...")
        test_loss, test_miou = validate(
            model, test_loader, criterion, configs, val_metrics,
            use_tta=configs['use_tta']
        )

        print(f"\n{'='*70}")
        print(f"测试结果")
        print(f"{'='*70}")
        print(f"测试集 Loss: {test_loss:.4f}")
        print(f"测试集 mIoU: {test_miou*100:.2f}%")
        if args.use_tta:
            print(f"(使用了测试时增强 TTA)")
        print(f"{'='*70}\n")

        if configs['use_wandb']:
            wandb.finish()
        print("✓ 测试完成")
        return

    # Training loop
    print(f"\n{'='*70}")
    print(f"开始训练")
    print(f"{'='*70}\n")

    for epoch in range(start_epoch, configs['epochs']):
        configs['current_epoch'] = epoch  # read by validate() for its printout / W&B step

        train_loss = train_epoch(
            model, train_loader, criterion, optimizer, scheduler,
            configs, epoch, ema_model, metrics=train_metrics
        )

        # Validate with the EMA weights swapped in (no TTA during training).
        if ema_model is not None:
            ema_model.apply_shadow()
        _val_loss, val_miou = validate(
            model, val_loader, criterion, configs, val_metrics,
            use_tta=False
        )
        if ema_model is not None:
            ema_model.restore()

        print(f"\nEpoch {epoch+1}/{configs['epochs']}: 训练损失={train_loss:.4f}, 验证mIoU={val_miou*100:.2f}%")

        # Per-epoch schedulers ('onecycle' is stepped inside train_epoch).
        if scheduler is not None and configs['scheduler'] != 'onecycle':
            if configs['scheduler'] == 'plateau':
                scheduler.step(val_miou)
            else:
                scheduler.step()

        # Save the best model (with EMA weights when EMA is enabled).
        if val_miou > best_miou:
            best_miou = val_miou
            if ema_model is not None:
                ema_model.apply_shadow()
            torch.save(
                _make_checkpoint(epoch, model, optimizer, scheduler, ema_model, best_miou, configs),
                output_dir / 'best_model.pt'
            )
            if ema_model is not None:
                ema_model.restore()
            print(f"  ✓ 保存最佳模型! mIoU: {best_miou*100:.2f}%\n")

        if early_stopping is not None and early_stopping(val_miou):
            print(f"\n⚠️  Early Stopping触发! 最佳mIoU: {best_miou*100:.2f}%")
            break

        # Periodic resumable checkpoint (raw training weights + EMA shadow).
        if (epoch + 1) % 10 == 0:
            torch.save(
                _make_checkpoint(epoch, model, optimizer, scheduler, ema_model, best_miou, configs),
                output_dir / f'checkpoint_epoch_{epoch+1}.pt'
            )

    # Final evaluation on the test split
    tta_note = "使用TTA" if configs['use_tta'] else "不使用TTA"
    print(f"\n{'='*70}")
    print(f"训练完成! 在测试集上评估 ({tta_note})...")
    print(f"{'='*70}\n")

    best_model_path = output_dir / 'best_model.pt'
    if best_model_path.exists():
        best_state = torch.load(best_model_path, map_location=configs['device'])
        model.load_state_dict(best_state['model_state_dict'])
    else:
        # No epoch beat the starting best_miou (e.g. resumed into a new output_dir):
        # evaluate the current weights instead of crashing on a missing file.
        print(f"⚠️  未找到 {best_model_path}，使用当前模型权重进行测试")
        if ema_model is not None:
            ema_model.apply_shadow()

    test_loss, test_miou = validate(
        model, test_loader, criterion, configs, val_metrics,
        use_tta=configs['use_tta']
    )

    print(f"\n{'='*70}")
    print(f"最终结果")
    print(f"{'='*70}")
    print(f"最佳验证 mIoU: {best_miou*100:.2f}%")
    print(f"测试集 mIoU: {test_miou*100:.2f}%")
    print(f"{'='*70}\n")

    if configs.get('use_wandb', False):
        wandb.log({
            'final/best_val_miou': best_miou * 100,
            'final/test_miou': test_miou * 100,
        })
        wandb.finish()
        print("✓ W&B 日志已完成")

    print(f"✓ 训练完成! 模型保存在: {output_dir}")


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

