import torch
import torch.nn as nn

from . import denoiserutils

class UNet(nn.Module):
    """U-Net style encoder / bottleneck / decoder built from 3x3 convolutions.

    Every encoder level is a stack of ``Conv2d(3x3, padding=1)`` layers followed
    by ``MaxPool2d(2)``. Every decoder level first upsamples its input by a
    factor of two (nearest neighbour), merges the matching encoder tensor via a
    skip connection, and then applies its own convolution stack. A last 3x3
    convolution maps the highest-resolution decoder output to
    ``output_channels``.

    Decoder level ``idx`` consumes ``enc_vars[len(dec) - idx - 1]``, where
    ``enc_vars[0]`` is the network input and ``enc_vars[k]`` is the output of
    encoder level ``k - 1``. The encoder and decoder are therefore expected to
    have the same number of levels, and the input's spatial size should be
    divisible by ``2 ** len(encoder_features)`` so that upsampled and skip
    tensors line up.

    The module is built on the default (CPU) device; move it with
    ``.to(device)`` / ``.cuda()`` like any other ``nn.Module``.

    Args:
        input_channels: Number of channels of the input tensor.
        output_channels: Number of channels produced by the final convolution.
        encoder_features: One list of conv output widths per encoder level.
            Defaults to ``[[64, 64], [128], [256], [512], [512]]``.
        bottleneck_features: Conv output widths of the bottleneck block.
            Defaults to ``[512]``.
        decoder_features: One list of conv output widths per decoder level.
            Defaults to ``[[512, 512], [256, 256], [128, 128], [64, 64], [64, 64]]``.
        initializer: In-place initializer applied to every ``Conv2d`` weight.

    Attributes:
        dropout: Dropout setting read by ``initNetwork`` (``False`` by default).
        residual_skip: When true (the default) skip tensors are added onto the
            leading channels of the upsampled tensor; otherwise they are
            concatenated along the channel dimension.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        encoder_features=None,
        bottleneck_features=None,
        decoder_features=None,
        initializer=nn.init.xavier_uniform_,
    ):
        super().__init__()

        # Resolve defaults here instead of in the signature so that instances
        # never share (and cannot accidentally mutate) one default list object.
        if encoder_features is None:
            encoder_features = [[64, 64], [128], [256], [512], [512]]
        if bottleneck_features is None:
            bottleneck_features = [512]
        if decoder_features is None:
            decoder_features = [[512, 512], [256, 256], [128, 128], [64, 64], [64, 64]]

        self.output_channels = output_channels
        self.input_channels = input_channels
        self.encoder_features = encoder_features
        self.bottleneck_features = bottleneck_features
        self.decoder_features = decoder_features
        self.dropout = False
        self.residual_skip = True
        self.initNetwork(initializer)

    def initNetwork(self, initializer):
        """Create all sub-modules and initialise the convolution parameters.

        Registers ``self.enc``, ``self.bottleneck``, ``self.dec`` and
        ``self.final`` (in that order), then applies ``initializer`` to every
        ``Conv2d`` weight and zeroes every ``Conv2d`` bias.

        Args:
            initializer: Callable invoked as ``initializer(weight_tensor)``; it
                is expected to modify the tensor in place.
        """

        def make_conv_block(in_features, features):
            """Return a flat layer list with one conv (+ dropout) + ReLU per entry of ``features``."""
            layers = []
            channels = in_features
            for width in features:
                layers.append(nn.Conv2d(channels, width, 3, padding=1))
                # NOTE: __init__ sets ``self.dropout = False`` (not None), so a
                # Dropout2d layer with p == 0 is always inserted. It does not
                # alter activations, but it occupies a slot in the Sequential;
                # the layout is kept as-is so module indices (and therefore
                # state_dict keys) stay unchanged. float() avoids handing a
                # bool to Dropout2d as the probability.
                # NOTE(review): if dropout was meant to be fully disabled,
                # changing this check would shift the state_dict keys.
                if self.dropout is not None:
                    layers.append(nn.Dropout2d(p=float(self.dropout), inplace=True))
                layers.append(nn.ReLU(inplace=True))
                channels = width
            return layers

        prev_features = self.input_channels

        # Encoder: conv stack followed by 2x2 max pooling per level.
        enc = []
        for enc_f in self.encoder_features:
            enc.append(nn.Sequential(*make_conv_block(prev_features, enc_f), nn.MaxPool2d(2)))
            prev_features = enc_f[-1]
        self.enc = nn.ModuleList(enc)

        # Bottleneck: conv stack at the lowest resolution.
        self.bottleneck = nn.Sequential(*make_conv_block(prev_features, self.bottleneck_features))
        prev_features = self.bottleneck_features[-1]

        # Decoder: every level merges one skip tensor. All but the last level
        # take the output of an encoder level; the last one takes the network
        # input itself.
        num_levels = len(self.decoder_features)
        dec = []
        for idx, dec_f in enumerate(self.decoder_features):
            if idx < num_levels - 1:
                skip_features = self.encoder_features[num_levels - idx - 2][-1]
            else:
                skip_features = self.input_channels

            if self.residual_skip:
                # The skip tensor is added onto the first ``skip_features``
                # channels, so the decoder input must be at least that wide.
                assert prev_features >= skip_features, (
                    "decoder level %d has %d input channels but its skip connection has %d"
                    % (idx, prev_features, skip_features)
                )
                in_features = prev_features
            else:
                # The skip tensor is concatenated along the channel dimension.
                in_features = prev_features + skip_features

            dec.append(nn.Sequential(*make_conv_block(in_features, dec_f)))
            prev_features = dec_f[-1]
        self.dec = nn.ModuleList(dec)

        # Final "mixing" convolution producing ``output_channels`` channels.
        self.final = nn.Conv2d(self.decoder_features[-1][-1], self.output_channels, 3, padding=1)

        # Initialise weights of every convolution, including ``self.final``.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                initializer(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, prev):
        """Run the network.

        Args:
            prev: Input tensor; indexed as a 4-D tensor with channels in
                dimension 1 (presumably ``(N, input_channels, H, W)``).

        Returns:
            The result of ``denoiserutils.object_from_dict`` applied to a dict
            with ``'color'`` (output of the final convolution) and
            ``'decoder_mips'`` (list of all decoder outputs, highest resolution
            first).
        """
        # Run encoder, remembering the input and every level's output for the skips.
        enc_vars = [prev]
        for block in self.enc:
            prev = block(prev)
            enc_vars.append(prev)

        # Run bottleneck
        prev = self.bottleneck(prev)

        # Run decoder
        decoder_mips = []
        num_levels = len(self.dec)
        for idx, block in enumerate(self.dec):
            # Upscale result from previous step
            prev = nn.functional.interpolate(prev, scale_factor=2, mode='nearest', align_corners=None)
            skip = enc_vars[num_levels - idx - 1]
            if self.residual_skip:
                # Residual skip connection: added in place onto the freshly
                # upsampled tensor, so no caller-visible tensor is modified.
                prev[:, 0:skip.shape[1], :, :] += skip
                merged = prev
            else:
                # Concatenate skip connection
                merged = torch.cat((prev, skip), dim=1)
            prev = block(merged)
            decoder_mips.append(prev)

        # Output list is ordered highest-resolution level first.
        decoder_mips.reverse()

        # Run final composition
        color = self.final(decoder_mips[0])

        # Return output color & all decoder levels
        return denoiserutils.object_from_dict({
            'color' : color,
            'decoder_mips' : decoder_mips
        })
