#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch import nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import conv_with_kaiming_uniform,convtrans_with_kaiming_uniform, UpsampleCropBlock

class ConcatFuseBlock(nn.Module):
    """
    Fuse four 16-channel feature maps into a single-channel output.

    The four inputs are stacked along the channel axis (``dim=1``) and the
    resulting 4*16-channel tensor is passed through a 1x1 convolution that
    produces one output channel.
    """

    def __init__(self):
        super().__init__()
        # 4 inputs x 16 channels in, 1 channel out; the remaining positional
        # args are presumably (kernel=1, stride=1, padding=0) — confirm
        # against make_layers.conv_with_kaiming_uniform.
        self.conv = conv_with_kaiming_uniform(4 * 16, 1, 1, 1, 0)

    def forward(self, x1, x2, x3, x4):
        # Channel-wise concatenation, then the 1x1 fusion conv.
        return self.conv(torch.cat((x1, x2, x3, x4), dim=1))
            
class DRIU(nn.Module):
    """
    DRIU head module.

    Projects the first backbone feature map to 16 channels with a 3x3
    convolution. It brings the remaining three feature maps to 16 channels
    with the ``upsample2``/``upsample4``/``upsample8`` ``UpsampleCropBlock``
    layers, and fuses the four 16-channel maps with ``ConcatFuseBlock``.

    Parameters
    ----------
    in_channels_list : list, optional
        Number of channels of each of the four feature maps returned by the
        backbone, in the order they appear after the first element of the
        input to :py:meth:`forward`. When ``None`` (default), the VGG16
        configuration ``[64, 128, 256, 512]`` used by ``build_driu`` is
        assumed.

    Raises
    ------
    ValueError
        If ``in_channels_list`` does not contain exactly four entries.
    """

    # Channel counts used by build_driu for its VGG16 backbone.
    DEFAULT_IN_CHANNELS = (64, 128, 256, 512)

    def __init__(self, in_channels_list=None):
        super().__init__()
        if in_channels_list is None:
            # Previously a None default crashed while unpacking; fall back to
            # the same configuration build_driu passes explicitly.
            in_channels_list = self.DEFAULT_IN_CHANNELS
        channels = tuple(in_channels_list)
        if len(channels) != 4:
            raise ValueError(
                "DRIU expects in_channels_list with exactly 4 entries "
                "(one per backbone feature map), got {}".format(len(channels))
            )
        in_conv_1_2_16, in_upsample2, in_upsample_4, in_upsample_8 = channels

        # 3x3 conv, stride 1, padding 1 -> 16 channels, same spatial size.
        self.conv1_2_16 = nn.Conv2d(in_conv_1_2_16, 16, 3, 1, 1)
        # Upsample layers; positional args presumably
        # (in, out, kernel, stride, padding) — confirm in make_layers.
        self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
        self.upsample4 = UpsampleCropBlock(in_upsample_4, 16, 8, 4, 0)
        self.upsample8 = UpsampleCropBlock(in_upsample_8, 16, 16, 8, 0)

        # Concat and Fuse
        self.concatfuse = ConcatFuseBlock()

    def forward(self, x):
        """
        Parameters
        ----------
        x : list
                list of tensors as returned from the backbone network.
                First element: height and width of input image.
                Remaining elements: feature maps for each feature level.

        Returns
        -------
        out : :py:class:`torch.Tensor`
                output of ``ConcatFuseBlock`` applied to the four 16-channel
                side outputs.
        """
        hw = x[0]
        conv1_2_16 = self.conv1_2_16(x[1])  # conv1_2_16
        upsample2 = self.upsample2(x[2], hw)  # side-multi2-up
        upsample4 = self.upsample4(x[3], hw)  # side-multi3-up
        upsample8 = self.upsample8(x[4], hw)  # side-multi4-up
        out = self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)
        return out

def build_driu():
    """
    Assemble the DRIU model from a VGG16 backbone and a DRIU head.

    The backbone is created without pretrained weights and asked to return
    the features at indices 3, 8, 14 and 22. The head is built with channel
    counts 64, 128, 256 and 512 for its four feature inputs.

    Returns
    -------
    model : :py:class:`torch.nn.Module`
        ``nn.Sequential`` with submodules ``"backbone"`` followed by
        ``"head"``, whose ``name`` attribute is set to ``"DRIU"``.
    """
    stages = OrderedDict()
    stages["backbone"] = vgg16(pretrained=False, return_features=[3, 8, 14, 22])
    stages["head"] = DRIU([64, 128, 256, 512])

    driu = nn.Sequential(stages)
    driu.name = "DRIU"
    return driu