"""
UNet-style RSMamba for Segmentation
结合UNet的多尺度架构和RSMamba的门控融合机制

改进点:
1. 多尺度编码器 (4个尺度: 56×56, 28×28, 14×14, 7×7)
2. 跳跃连接 (保留高分辨率细节)
3. 渐进式下采样 (stride=2, 而不是一步patch_size=16)
4. RSMamba门控融合 (保留官方优势)
5. 混合卷积+Mamba (局部+全局)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

try:
    from mamba_ssm import Mamba
    HAVE_MAMBA_SSM = True
except ImportError:
    HAVE_MAMBA_SSM = False
    print("⚠️  mamba_ssm库未安装，将使用降级实现")


class PatchMerging(nn.Module):
    """
    Patch-merging layer: 2x spatial downsampling of a token grid.

    Each 2x2 neighbourhood of tokens is concatenated along the channel axis
    (C -> 4*C), layer-normalised and linearly projected to 2*C channels,
    in the manner of the Swin Transformer downsampling step.

    Args:
        dim: Number of input channels C.
        norm_layer: Normalisation layer constructor applied to the 4*C
            concatenated features.
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        Args:
            x: Token sequence of shape (B, H*W, C).
            H, W: Spatial size of the grid the tokens were flattened from.
        Returns:
            Tensor of shape (B, ceil(H/2)*ceil(W/2), 2*C). For even H and W
            this is (B, H/2*W/2, 2*C).
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.reshape(B, H, W, C)

        # An odd height/width would make the four strided slices below differ
        # in size and break the concatenation, so pad one zero row/column at
        # the bottom/right first. Even sizes are left untouched.
        pad_h, pad_w = H % 2, W % 2
        if pad_h or pad_w:
            # F.pad pads from the last dim backwards: (C, C, W, W, H, H).
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))

        # Gather the four tokens of every 2x2 window.
        x0 = x[:, 0::2, 0::2, :]  # top-left      (B, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, :]  # bottom-left   (B, H/2, W/2, C)
        x2 = x[:, 0::2, 1::2, :]  # top-right     (B, H/2, W/2, C)
        x3 = x[:, 1::2, 1::2, :]  # bottom-right  (B, H/2, W/2, C)
        x = torch.cat([x0, x1, x2, x3], -1)  # (B, H/2, W/2, 4*C)
        x = x.view(B, -1, 4 * C)  # (B, H/2*W/2, 4*C)

        x = self.norm(x)
        x = self.reduction(x)  # (B, H/2*W/2, 2*C)

        return x


class PatchExpanding(nn.Module):
    """
    Patch-expanding layer: 2x spatial upsampling of a token grid.

    The channels of every token are split into a 2x2 block of new tokens
    (a pixel-shuffle on the token grid), followed by a normalisation layer.

    Args:
        dim: Number of input channels C.
        dim_scale: With the default ``2`` a bias-free linear layer first
            expands C -> 2*C, so the output has C/2 channels. Any other value
            skips the expansion, and the rearrangement alone yields C/4
            channels.
        norm_layer: Normalisation layer constructor applied to the output
            channels.
    """
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        # Channels entering the 2x2 rearrangement; a quarter of them remain
        # per output token. For dim_scale == 2 this equals dim // 2.
        expanded_dim = 2 * dim if dim_scale == 2 else dim
        self.norm = norm_layer(expanded_dim // 4)

    def forward(self, x, H, W):
        """
        Args:
            x: Token sequence of shape (B, H*W, C).
            H, W: Spatial size of the grid the tokens were flattened from.
        Returns:
            Tensor of shape (B, 4*H*W, C/2) for ``dim_scale == 2``
            (otherwise (B, 4*H*W, C/4)).
        """
        B, L, C = x.shape
        assert L == H * W

        x = self.expand(x)
        # Derive the per-token output width from the channels actually
        # present, so the Identity path works as well as the linear one.
        out_channels = x.shape[-1] // 4

        # Split channels into a (2, 2, out_channels) block per token, then
        # interleave the block rows/columns with the grid rows/columns.
        x = x.reshape(B, H, W, 2, 2, out_channels)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, H, 2, W, 2, c)
        x = x.view(B, H * 2, W * 2, out_channels)
        x = x.view(B, -1, out_channels)
        x = self.norm(x)

        return x


class RSMambaLayer(nn.Module):
    """
    RS-Mamba layer with gated fusion of three scan orders.

    The normalised sequence is processed by one shared mixer in forward,
    reversed and randomly shuffled order; the three results are restored to
    the original token order and blended with softmax gate weights, then
    added to the layer input.

    Args:
        dim: Channel width D of the tokens.
        d_state, d_conv, expand: Forwarded to the mixer (``Mamba`` when
            ``mamba_ssm`` is importable, otherwise ``SimpleMambaMixer``).
        drop_path: Stochastic-depth rate for the fused branch; ``0`` disables
            it.
    """
    def __init__(self, dim, d_state=16, d_conv=4, expand=2, drop_path=0.):
        super().__init__()
        self.dim = dim

        # Sequence mixer: the real Mamba kernel if available, else the
        # lightweight fallback defined in this file.
        if not HAVE_MAMBA_SSM:
            self.mamba = SimpleMambaMixer(dim, d_state, d_conv, dim * expand)
        else:
            self.mamba = Mamba(
                d_model=dim,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
            )

        self.norm = nn.LayerNorm(dim)

        # Gate: maps the three pooled path descriptors to three weights that
        # sum to one (forward, reverse, shuffle).
        self.gate_layer = nn.Sequential(
            nn.Linear(3 * dim, 3, bias=False),
            nn.Softmax(dim=-1)
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """
        Args:
            x: Token sequence of shape (B, L, D).
        Returns:
            Tensor of shape (B, L, D).

        Note:
            A fresh random permutation is drawn on every call, in eval mode
            too, so outputs vary between calls unless the RNG is seeded.
        """
        shortcut = x
        _, seq_len, _ = x.shape

        normed = self.norm(x)

        # Build the three scan orders and stack them along the batch axis so
        # the mixer runs once on a (3B, L, D) tensor.
        perm = torch.randperm(seq_len, device=x.device)
        stacked = torch.cat(
            [normed, torch.flip(normed, [1]), normed[:, perm]],
            dim=0,
        )

        mixed = self.mamba(stacked)
        if not HAVE_MAMBA_SSM:
            # The fallback mixer returns a 1-tuple.
            mixed = mixed[0]

        fwd, rev, shuf = torch.chunk(mixed, 3, dim=0)

        # Undo the reordering so all three paths are token-aligned again.
        rev = torch.flip(rev, [1])
        shuf = shuf[:, torch.argsort(perm)]

        # Gate weights come from the sequence-averaged features of each path.
        pooled = torch.cat(
            [fwd.mean(dim=1), rev.mean(dim=1), shuf.mean(dim=1)],
            dim=-1,
        )
        weights = self.gate_layer(pooled).unsqueeze(1)  # (B, 1, 3)
        w_fwd, w_rev, w_shuf = weights.split(1, dim=-1)  # each (B, 1, 1)

        fused = w_fwd * fwd + w_rev * rev + w_shuf * shuf

        # Residual connection around the whole layer.
        return shortcut + self.drop_path(fused)


class SimpleMambaMixer(nn.Module):
    """
    Simplified stand-in for a Mamba mixer.

    Projects the input to two ``d_inner``-wide streams, runs a depthwise
    1-D convolution over the sequence on one of them, gates it with the
    other (both through SiLU) and projects back to ``dim``. No state-space
    scan is performed; ``d_state`` is accepted but not used.

    Args:
        dim: Input/output channel width.
        d_state: Unused; kept for signature parity with ``Mamba``.
        d_conv: Kernel size of the depthwise convolution.
        d_inner: Inner width; defaults to ``2 * dim``.
    """
    def __init__(self, dim, d_state=16, d_conv=4, d_inner=None):
        super().__init__()
        self.dim = dim
        self.d_inner = dim * 2 if d_inner is None else d_inner

        width = self.d_inner
        self.in_proj = nn.Linear(dim, 2 * width, bias=False)
        # Padding of d_conv - 1 on both sides; forward() keeps only the first
        # L outputs, so each position sees itself and earlier tokens only.
        self.conv1d = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=width,
            bias=True,
        )
        self.out_proj = nn.Linear(width, dim, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, L, dim).
        Returns:
            A 1-tuple holding a tensor of shape (B, L, dim).
        """
        _, seq_len, _ = x.shape

        # Split the projection into the convolved stream and the gate stream.
        content, gate = torch.chunk(self.in_proj(x), 2, dim=-1)

        # Conv1d expects (B, channels, L); trim the padded tail afterwards.
        content = self.conv1d(content.transpose(1, 2))[:, :, :seq_len]
        content = content.transpose(1, 2)  # back to (B, L, d_inner)

        gated = self.act(content) * self.act(gate)

        return (self.out_proj(gated),)


class DropPath(nn.Module):
    """
    Drop paths (stochastic depth) per sample.

    In training mode each sample of the batch is zeroed with probability
    ``drop_prob``; surviving samples are scaled by ``1 / (1 - drop_prob)``.
    In eval mode, or with ``drop_prob == 0``, the input is returned
    unchanged.

    Args:
        drop_prob: Probability of dropping a sample's path.
    """
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply stochastic depth to ``x`` (first dimension is the batch)."""
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Every path is dropped. Dividing by keep_prob here would give
            # inf * 0 = NaN, so return zeros directly (multiplying keeps the
            # result attached to the autograd graph).
            return x * 0
        # One Bernoulli(keep_prob) draw per sample, broadcast over the rest.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # values in [keep_prob, 1 + keep_prob) -> {0, 1}
        return x.div(keep_prob) * random_tensor


class ConvRSMambaBlock(nn.Module):
    """
    Hybrid block: convolution + RSMamba + feed-forward network.

    - A depthwise-separable 2-D convolution captures local spatial features.
    - An ``RSMambaLayer`` models context over the flattened token sequence.
    - A 4x-wide MLP with residual connection follows the fused branches.

    Args:
        dim: Channel width C.
        d_state: State size forwarded to the RSMamba layer.
        drop_path: Stochastic-depth rate used in the RSMamba layer and on the
            MLP branch.
    """
    def __init__(self, dim, d_state=16, drop_path=0.):
        super().__init__()
        self.dim = dim

        # Local branch: depthwise 3x3 -> BN -> GELU -> pointwise 1x1 -> BN.
        local_layers = [
            nn.Conv2d(dim, dim, 3, padding=1, groups=dim),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 1),
            nn.BatchNorm2d(dim),
        ]
        self.conv = nn.Sequential(*local_layers)

        # Global branch operating on the 1-D token sequence.
        self.mamba_block = RSMambaLayer(dim, d_state=d_state, drop_path=drop_path)

        # Feed-forward network with a fixed dropout of 0.1.
        hidden_dim = dim * 4
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(0.1),
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x, H, W):
        """
        Args:
            x: Token sequence of shape (B, H*W, C).
            H, W: Spatial size of the grid the tokens were flattened from.
        Returns:
            Tensor of shape (B, H*W, C).
        """
        B, L, C = x.shape
        assert L == H * W

        # Local branch: tokens -> (B, C, H, W) image -> conv -> tokens.
        feature_map = x.permute(0, 2, 1).reshape(B, C, H, W)
        local_feat = self.conv(feature_map)
        local_feat = local_feat.flatten(2).transpose(1, 2)  # (B, H*W, C)

        # Global branch on the sequence.
        global_feat = self.mamba_block(x)

        # Sum the input with both branches.
        # NOTE(review): mamba_block already returns its input plus the mixed
        # features, so the input is effectively added twice here - confirm
        # this is intended before changing (it affects trained weights).
        fused = x + local_feat + global_feat

        # Residual feed-forward network.
        ffn_out = self.mlp(self.norm(fused))
        return fused + self.drop_path(ffn_out)


class UNetRSMambaEncoder(nn.Module):
    """
    UNet-style multi-scale encoder built from ConvRSMamba blocks.

    A stride-4 convolutional patch embedding is followed by one stage per
    entry of ``embed_dims``; every stage but the last is followed by a
    ``PatchMerging`` 2x downsampling.

    Args:
        img_size: Nominal (square) input size, used only to fill
            ``self.resolutions``.
        in_chans: Number of input image channels.
        embed_dims: Channel width of each stage, shallow to deep.
            Defaults to ``[96, 192, 384, 768]``.
        depths: Number of blocks in each stage. Defaults to ``[2, 2, 6, 2]``.
        d_state: State size forwarded to every block.
        drop_path_rate: Maximum stochastic-depth rate; rates increase
            linearly from 0 over all blocks.
    """
    def __init__(
        self,
        img_size=224,
        in_chans=3,
        embed_dims=None,
        depths=None,
        d_state=16,
        drop_path_rate=0.1,
    ):
        super().__init__()
        # None sentinels instead of list defaults, so instances never share
        # (and cannot accidentally mutate) one default list object.
        if embed_dims is None:
            embed_dims = [96, 192, 384, 768]
        if depths is None:
            depths = [2, 2, 6, 2]

        self.num_layers = len(embed_dims)
        self.embed_dims = embed_dims

        # Patch embedding (initial downsampling, stride 4).
        self.patch_embed = nn.Sequential(
            nn.Conv2d(in_chans, embed_dims[0], kernel_size=4, stride=4),
            nn.BatchNorm2d(embed_dims[0]),
        )

        # Nominal side length of each stage for an img_size x img_size input.
        # Informational only: forward() uses the actual feature-map size.
        self.resolutions = [img_size // (4 * (2 ** i)) for i in range(self.num_layers)]

        # Stochastic-depth schedule, one rate per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.layers = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        for i_layer in range(self.num_layers):
            first_block = sum(depths[:i_layer])  # index of this stage's first block in dpr
            layer = nn.ModuleList([
                ConvRSMambaBlock(
                    dim=embed_dims[i_layer],
                    d_state=d_state,
                    drop_path=dpr[first_block + j]
                )
                for j in range(depths[i_layer])
            ])
            self.layers.append(layer)

            # Downsample after every stage except the last; the Identity
            # placeholder keeps the two ModuleLists index-aligned.
            if i_layer < self.num_layers - 1:
                self.downsamples.append(PatchMerging(embed_dims[i_layer]))
            else:
                self.downsamples.append(nn.Identity())

    def forward(self, x):
        """
        Args:
            x: Image batch of shape (B, C, H, W).
        Returns:
            List with one ``(features, H_i, W_i)`` tuple per stage, shallow to
            deep, where ``features`` has shape (B, H_i*W_i, C_i).
        """
        # Patch embedding, e.g. (B, C, 224, 224) -> (B, 96, 56, 56).
        x = self.patch_embed(x)

        # Take the grid size from the feature map itself rather than from
        # img_size, so inputs of other (or non-square) sizes are laid out
        # correctly.
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)

        # Per-stage features, kept for the decoder's skip connections.
        encoder_features = []
        last_layer = self.num_layers - 1

        for i_layer, blocks in enumerate(self.layers):
            for block in blocks:
                x = block(x, H, W)

            encoder_features.append((x, H, W))

            if i_layer < last_layer:
                x = self.downsamples[i_layer](x, H, W)
                # Halve the grid, rounding up (equal to H // 2 for the even
                # sizes a 2x2 merge expects).
                H, W = (H + 1) // 2, (W + 1) // 2

        return encoder_features


class UNetRSMambaDecoder(nn.Module):
    """
    UNet-style decoder with skip connections.

    Starting from the deepest encoder feature, each stage upsamples by 2x,
    concatenates the encoder feature of matching resolution, projects the
    result to the stage width and refines it with ConvRSMamba blocks.

    Args:
        embed_dims: Channel widths from deep to shallow.
            Defaults to ``[768, 384, 192, 96]``.
        depths: Block counts aligned with ``embed_dims``; stage ``i`` uses
            ``depths[i + 1]`` blocks (``depths[0]`` only offsets the
            drop-path schedule). Defaults to ``[2, 2, 2, 2]``.
        d_state: State size forwarded to every block.
        drop_path_rate: Maximum stochastic-depth rate of the linear schedule.
    """
    def __init__(
        self,
        embed_dims=None,
        depths=None,
        d_state=16,
        drop_path_rate=0.1,
    ):
        super().__init__()
        # None sentinels instead of list defaults, so instances never share
        # (and cannot accidentally mutate) one default list object.
        if embed_dims is None:
            embed_dims = [768, 384, 192, 96]
        if depths is None:
            depths = [2, 2, 2, 2]

        self.num_layers = len(embed_dims)
        self.embed_dims = embed_dims

        # Stochastic-depth schedule, one rate per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Decoder stages, deep to shallow.
        self.decoder_stages = nn.ModuleList()

        for i in range(self.num_layers - 1):
            # With the default widths:
            #   i=0: 768 -> 384, concatenated with the 384-wide encoder stage
            #   i=1: 384 -> 192, concatenated with the 192-wide encoder stage
            #   i=2: 192 -> 96,  concatenated with the 96-wide encoder stage
            current_dim = embed_dims[i]      # width entering this stage (deeper)
            next_dim = embed_dims[i + 1]     # width leaving this stage (shallower)
            encoder_dim = next_dim           # width of the matching encoder stage

            stage = nn.ModuleDict({
                # Upsampling: current_dim -> current_dim // 2 channels.
                'upsample': PatchExpanding(current_dim),

                # Skip fusion: (current_dim // 2 + encoder_dim) -> next_dim.
                'skip_conv': nn.Linear(current_dim // 2 + encoder_dim, next_dim),

                # Refinement blocks of this stage.
                'blocks': nn.ModuleList([
                    ConvRSMambaBlock(
                        dim=next_dim,
                        d_state=d_state,
                        drop_path=dpr[sum(depths[:i+1]) + j]
                    )
                    for j in range(depths[i + 1])
                ])
            })

            self.decoder_stages.append(stage)

    def forward(self, encoder_features):
        """
        Args:
            encoder_features: List of ``(x, H, W)`` tuples from the encoder,
                shallow to deep, e.g.
                ``[(f0, 56, 56), (f1, 28, 28), (f2, 14, 14), (f3, 7, 7)]``.
        Returns:
            x: Final decoded features of shape (B, H*W, C).
            H, W: Resolution of those features (that of the shallowest
                encoder stage, e.g. 56x56).
        """
        # Start from the deepest stage, e.g. (B, 49, 768).
        x, H, W = encoder_features[-1]

        # Decode stage by stage, e.g. 7 -> 14 -> 28 -> 56.
        for i, stage in enumerate(self.decoder_stages):
            # Upsample: channels halve, resolution doubles.
            x = stage['upsample'](x, H, W)
            H, W = H * 2, W * 2

            # Matching encoder feature, counted from the deep end.
            enc_feat, enc_H, enc_W = encoder_features[-(i + 2)]
            assert H == enc_H and W == enc_W, f"分辨率不匹配: {H}x{W} vs {enc_H}x{enc_W}"

            # Concatenate along channels, then project to the stage width.
            x = torch.cat([x, enc_feat], dim=-1)
            x = stage['skip_conv'](x)

            for block in stage['blocks']:
                x = block(x, H, W)

        # No further upsampling here; the main model handles the rest.
        return x, H, W


class UNetRSMamba(nn.Module):
    """
    UNet-style RSMamba for segmentation.

    Combines:
    - a UNet multi-scale encoder/decoder with skip connections,
    - RSMamba gated fusion for sequence-level context,
    - convolutions for local feature extraction.

    Constructing the model prints a short configuration banner.

    Args:
        img_size: Nominal input size passed to the encoder.
        in_channels: Number of input image channels.
        num_classes: Number of output classes (logit channels).
        embed_dims: Stage widths, shallow to deep.
            Defaults to ``[96, 192, 384, 768]``.
        depths: Blocks per stage. Defaults to ``[2, 2, 6, 2]``.
        d_state: State size forwarded to every block.
        drop_path_rate: Maximum stochastic-depth rate.
    """
    def __init__(
        self,
        img_size=224,
        in_channels=3,
        num_classes=3,
        embed_dims=None,
        depths=None,
        d_state=16,
        drop_path_rate=0.1,
    ):
        super().__init__()
        # None sentinels instead of list defaults, so instances never share
        # (and cannot accidentally mutate) one default list object.
        if embed_dims is None:
            embed_dims = [96, 192, 384, 768]
        if depths is None:
            depths = [2, 2, 6, 2]

        print("\n" + "="*60)
        print("🚀 UNet-RSMamba 初始化")
        print(f"   输入: {in_channels}通道, {img_size}×{img_size}")
        print(f"   输出: {num_classes}类")
        print(f"   多尺度: {embed_dims}")
        print(f"   深度: {depths}")
        print(f"   特性: 多尺度 + 跳跃连接 + 门控RSMamba")
        print("="*60 + "\n")

        self.num_classes = num_classes

        # Encoder
        self.encoder = UNetRSMambaEncoder(
            img_size=img_size,
            in_chans=in_channels,
            embed_dims=embed_dims,
            depths=depths,
            d_state=d_state,
            drop_path_rate=drop_path_rate,
        )

        # Decoder works deep-to-shallow, hence the reversed lists.
        self.decoder = UNetRSMambaDecoder(
            embed_dims=embed_dims[::-1],
            depths=depths[::-1],
            d_state=d_state,
            drop_path_rate=drop_path_rate,
        )

        # Segmentation head: two stride-2 transposed convolutions undo the
        # stride-4 patch embedding (e.g. 56 -> 112 -> 224), followed by a
        # 1x1 classifier.
        self.final_expand = nn.Sequential(
            nn.ConvTranspose2d(embed_dims[0], embed_dims[0] // 2,
                              kernel_size=2, stride=2),
            nn.BatchNorm2d(embed_dims[0] // 2),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dims[0] // 2, embed_dims[0] // 4,
                              kernel_size=2, stride=2),
            nn.BatchNorm2d(embed_dims[0] // 4),
            nn.GELU(),
            nn.Conv2d(embed_dims[0] // 4, num_classes, kernel_size=1),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d modules (used with ``apply``)."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # mamba_ssm tags specially initialised biases (the dt_proj bias)
            # with ``_no_reinit``; zeroing them would discard that init.
            if m.bias is not None and not getattr(m.bias, "_no_reinit", False):
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Args:
            x: Image batch of shape (B, C, H, W).
        Returns:
            Logits of shape (B, num_classes, H, W).
        """
        B, _, H_in, W_in = x.shape

        # Encode, then decode with skip connections.
        encoder_features = self.encoder(x)
        x, H, W = self.decoder(encoder_features)

        # Token sequence -> 2-D feature map.
        x = x.transpose(1, 2).reshape(B, -1, H, W)

        # Upsample 4x and classify.
        output = self.final_expand(x)

        # The head restores exactly 4x the decoder resolution; resize only
        # when that does not coincide with the input size (e.g. a side not
        # divisible by 4), so the documented output shape always holds.
        if output.shape[-2:] != (H_in, W_in):
            output = F.interpolate(
                output, size=(H_in, W_in), mode='bilinear', align_corners=False
            )

        return output


# Backward-compatibility alias for the model class.
UNetRSMambaSegmentation = UNetRSMamba


if __name__ == '__main__':
    # Smoke test: build the model for 2-channel input (VV + VH), run one
    # forward pass on random data and report shapes and parameter counts.
    net = UNetRSMamba(
        img_size=224,
        in_channels=2,
        num_classes=3,
        embed_dims=[96, 192, 384, 768],
        depths=[2, 2, 6, 2],
    )

    dummy = torch.randn(2, 2, 224, 224)
    logits = net(dummy)
    print(f"\n输入: {dummy.shape}")
    print(f"输出: {logits.shape}")

    # Count all parameters and the trainable subset in a single pass.
    total_params = 0
    trainable_params = 0
    for param in net.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"\n总参数: {total_params / 1e6:.2f}M")
    print(f"可训练参数: {trainable_params / 1e6:.2f}M")

