#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch import nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform, convtrans_with_kaiming_uniform, UpsampleCropBlock

class ConcatFuseBlock(nn.Module):
    """Fuse five side-output maps into a single map.

    The five inputs are concatenated along ``dim=1`` and passed through one
    convolution built by ``conv_with_kaiming_uniform(5, 1, 1, 1, 0)``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): arguments presumably mean
        # (in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
        # i.e. a 1x1 mixing conv over the 5 stacked maps — confirm in make_layers.
        self.conv = conv_with_kaiming_uniform(5, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4, x5):
        """Concatenate the five inputs along ``dim=1`` and apply ``self.conv``.

        All inputs must agree in every dimension except ``dim=1``
        (a requirement of ``torch.cat``).
        """
        # assumes NCHW tensors, so dim=1 is the channel axis — TODO confirm
        stacked = torch.cat((x1, x2, x3, x4, x5), dim=1)
        fused = self.conv(stacked)
        return fused
            
class HED(nn.Module):
    """
    HED head module.

    Builds one side branch per backbone feature map and a fusion block that
    combines all five branch outputs.

    Parameters
    ----------
    in_channels_list : list[int]
        Number of channels of each of the five feature maps that will be fed
        to the head, in order: the map for ``conv1_2_16`` followed by the maps
        for ``upsample2``, ``upsample4``, ``upsample8`` and ``upsample16``.

    Raises
    ------
    TypeError
        If ``in_channels_list`` is ``None`` (or not iterable).
    ValueError
        If ``in_channels_list`` does not contain exactly five entries.
    """

    def __init__(self, in_channels_list=None):
        super().__init__()
        # The default of None is kept for signature compatibility, but the
        # head cannot be built without channel counts: fail with a clear
        # message instead of an opaque tuple-unpacking error.
        if in_channels_list is None:
            raise TypeError(
                "HED requires in_channels_list: a list of 5 channel counts"
            )
        in_channels_list = list(in_channels_list)
        if len(in_channels_list) != 5:
            raise ValueError(
                "HED expects in_channels_list with exactly 5 entries, got "
                f"{len(in_channels_list)}"
            )
        in_conv1_2_16, in_up2, in_up4, in_up8, in_up16 = in_channels_list

        # 3x3 conv, stride 1, padding 1 (nn.Conv2d positional args): keeps
        # the spatial size of its input and maps it to 1 channel.
        self.conv1_2_16 = nn.Conv2d(in_conv1_2_16, 1, 3, 1, 1)
        # Upsample branches.
        # NOTE(review): UpsampleCropBlock args presumably are
        # (in_channels, out_channels, kernel_size, stride, padding) of a
        # transposed conv followed by a crop to the reference size — confirm.
        self.upsample2 = UpsampleCropBlock(in_up2, 1, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_up4, 1, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_up8, 1, 16, 8, 0)
        self.upsample16 = UpsampleCropBlock(in_up16, 1, 32, 16, 0)
        # Concat and fuse the five single-channel branch outputs.
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Arguments
        ---------
        x : list[Tensor]
            ``x[0]`` is passed to every upsampling block as its second
            argument (presumably the input image, used as a size reference
            for cropping — TODO confirm against the backbone). ``x[1]`` ..
            ``x[5]`` are the feature maps, in the order of
            ``in_channels_list``.

        Returns
        -------
        list[Tensor]
            ``[upsample2, upsample4, upsample8, upsample16, concatfuse]``:
            four side outputs followed by the fused output.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])
        upsample2 = self.upsample2(x[2], hw)
        upsample4 = self.upsample4(x[3], hw)
        upsample8 = self.upsample8(x[4], hw)
        upsample16 = self.upsample16(x[5], hw)
        concatfuse = self.concatfuse(
            conv1_2_16, upsample2, upsample4, upsample8, upsample16
        )

        # NOTE(review): conv1_2_16 contributes to the fused output but is not
        # returned as a side output, unlike the other four branches. Kept
        # unchanged because altering the list length may break callers that
        # consume these outputs (e.g. a per-output loss) — confirm intent.
        return [upsample2, upsample4, upsample8, upsample16, concatfuse]

def build_hed():
    """Assemble an HED network on a VGG16 backbone.

    Returns
    -------
    torch.nn.Sequential
        Two named stages, ``backbone`` (``vgg16(pretrained=False, ...)``)
        followed by ``head`` (:class:`HED`). The model's ``name`` attribute
        is set to ``"HED"``.
    """
    # NOTE(review): presumably indices of the VGG16 layers whose activations
    # the backbone returns; head_channels must match those layers' channel
    # counts, in the same order — confirm against the vgg16 backbone.
    feature_layers = [3, 8, 14, 22, 29]
    head_channels = [64, 128, 256, 512, 512]

    stages = OrderedDict()
    stages["backbone"] = vgg16(pretrained=False, return_features=feature_layers)
    stages["head"] = HED(head_channels)

    network = nn.Sequential(stages)
    network.name = "HED"
    return network