# -*- coding: utf-8 -*-
"""
:author: Christian Borger <christian.borger@student.kit.edu>
:date: 2017-01-24
:copyright: Karlsruhe Institute of Technology,
            The Institute of Meteorology and Climate
            Research (IMK) - Atmospheric Trace Gases and
            Remote Sensing (ASF)

Class simulating averaging kernels for given temperature, pressure and
humidity profiles.
This script was written for and tested with Python 2.7; porting it to
Python 3.5 should not require much effort.
"""

import numpy as np
import os
import scipy.constants as sc


class AVKSimulatorError(Exception):
    """Error raised while simulating averaging kernels.

    Used, for instance, when the surface altitude lies below the lowest
    altitude supported by the regularisation data.
    """


class AVKSimulator(object):
    """
    Simulator of H2O / HDO averaging kernels

    Jacobians are derived from a simple layered emission model (see
    ``calcjacobian``), combined with regularisation matrices read from
    ``invdir`` and turned into averaging kernels (see ``run``).

    Parameters
    ----------
    invdir : str
        Path to folder containing regularisation matrices. It must contain
        one sub-folder ``inv_XXXX`` (XXXX = surface altitude in m, from 0 to
        6000 in steps of 10) holding the files ``correl.inp`` and
        ``kovar.inp``.
    """

    # altitude grid (in m) of the regularisation sub-folders in ``invdir``
    _ALT_STEP = 10
    _ALT_MAX = 6000

    def __init__(self, invdir):

        # wavenumbers (cm-1) at which the Planck functions are evaluated
        self.wn_h2o = 1250
        self.wn_hdo = 1250

        # exponents scaling the absorption cross sections (see calcjacobian)
        self.x_strength = np.arange(1, 19+1)
        self.nstr = len(self.x_strength)

        # directory for regularisation
        self.invdir = invdir

        # read regularisation matrices for every altitude of the grid;
        # ``range`` (instead of ``xrange``) works in Python 2 and 3
        self.invdict = {}
        for i in range(0, self._ALT_MAX + self._ALT_STEP, self._ALT_STEP):
            invfol = 'inv_{:04d}'.format(i)
            corfn = os.path.join(self.invdir, invfol, 'correl.inp')
            correldata = np.loadtxt(corfn, skiprows=4)

            kovarfn = os.path.join(self.invdir, invfol, 'kovar.inp')
            kovardata = np.loadtxt(kovarfn, skiprows=2)
            self.invdict[i] = {'correl': correldata, 'kovar': kovardata}

    def planck(self, nu, T):
        """
        Calculate blackbody radiance for given temperature and wavenumber

        Parameters
        ----------
        nu : float, numpy.array
            wavenumber in cm-1
        T : float, numpy.array
            temperature in K (arrays are evaluated element-wise)

        Returns
        -------
        L : float, numpy.array
            blackbody radiance in mW m-2 sr-1 (cm-1)-1
        """

        # first and second radiation constants in units matching nu in cm-1
        c1 = 1.191042 * 10**-5
        c2 = 1.4387752

        L = (c1*nu**3.) / (np.exp(c2*nu/T)-1)

        return L

    def calcrad(self, tsrf, temp, dim):
        """
        Calculate blackbody radiation for surface skin temperature and
        air temperature with and without perturbations

        Parameters
        ----------
        tsrf : float
            surface skin temperature in K
        temp : numpy.array
            temperature profile in K
        dim : int
            number of vertical levels

        Returns
        -------
        raddict : dict of dict
            blackbody radiances for 'H2O' and 'HDO', each with keys
            'srf' (surface), 'srf_T' (surface + 1 K), 'atm' (per level) and
            'atm_T' (shape (dim, dim); row i has level i raised by 1 K)
        """

        bbradh2o_srf = self.planck(self.wn_h2o, tsrf)
        bbradhdo_srf = self.planck(self.wn_hdo, tsrf)
        bbradh2o_srf_T = self.planck(self.wn_h2o, tsrf+1)
        bbradhdo_srf_T = self.planck(self.wn_hdo, tsrf+1)

        # one copy of the profile per row, row i perturbed by +1 K at level i
        tmatrix = np.tile(temp, (dim, 1)) + np.eye(dim)

        bbradh2o_atm = self.planck(self.wn_h2o, temp)
        bbradhdo_atm = self.planck(self.wn_hdo, temp)
        bbradh2o_atm_T = self.planck(self.wn_h2o, tmatrix)
        bbradhdo_atm_T = self.planck(self.wn_hdo, tmatrix)

        raddict = {'H2O': {'srf': bbradh2o_srf, 'srf_T': bbradh2o_srf_T,
                           'atm': bbradh2o_atm, 'atm_T': bbradh2o_atm_T},
                   'HDO': {'srf': bbradhdo_srf, 'srf_T': bbradhdo_srf_T,
                           'atm': bbradhdo_atm, 'atm_T': bbradhdo_atm_T}
                   }

        return raddict

    def partialcol(self, pres, temp, alt, state_wv, dim):
        """
        Calculate partial columns of H2O and HDO with and without
        perturbations

        Parameters
        ----------
        pres : numpy.array
            pressure profile in hPa
        temp : numpy.array
            temperature profile in K
        alt : numpy.array
            altitudes of profile levels in m
        state_wv : dict of numpy.array
            maps gas name ('H2O', 'HDO') to its volume mixing ratio profile
            in ln(ppmv)
        dim : int
            number of vertical levels

        Returns
        -------
        pcdict : dict of dict
            per gas: 'normal' (layer partial columns, length dim-1) and
            'pkm' (shape (dim, dim-1); row i with ln(vmr) at level i
            increased by 1)
        """

        pcdict = {}

        # ``items`` (instead of ``iteritems``) works in Python 2 and 3
        for gas, state in state_wv.items():

            # one copy of the state per row, row i perturbed at level i
            state_pkm = np.tile(state, (dim, 1)) + np.eye(dim)

            # calculate principal partial column
            pc = pres * np.exp(state) * 28.96 \
                / (28.96*10**6. + 18*np.exp(state)) \
                / (temp*sc.k*10**4) * 10**6

            # calculate partial column with deviations
            pc_pkm = pres * np.exp(state_pkm) * 28.96 \
                / (28.96*10**6. + 18*np.exp(state_pkm)) \
                / (temp*sc.k*10**4) * 10**6

            # integrate level values over each layer (trapezoidal rule)
            pcstate = 0.5 * (pc[1:] + pc[:-1]) * np.diff(alt)
            pcstate_pkm = 0.5 * (pc_pkm[:, 1:] + pc_pkm[:, :-1]) * np.diff(alt)

            pcdict[gas] = {'normal': pcstate, 'pkm': pcstate_pkm}

        return pcdict

    def calcjacobian(self, strength, sza, pcdict, raddict, emi):
        """
        Calculate jacobians for averaging kernel

        Parameters
        ----------
        strength : int / float
            strength of absorption crosssections (exponent); cross sections
            are scaled by 2**(strength-1)
        sza : float
            satellite zenith angle in degrees
        pcdict : dict of dict
            dict containing partial columns of 'H2O' and 'HDO' including
            perturbated columns (see partialcol)
        raddict : dict of dict
            blackbody radiations corresponding to 'H2O' and 'HDO'
            (see calcrad)
        emi : float
            surface emissivity (unitless)

        Returns
        -------
        jacdict : dict of dict
            per cross-section variant ('H2O', 'H2Os1', 'H2Os2', 'HDO'):
            'T' (surface temperature), 'pkm' (humidity per level) and
            'Tatm' (atmospheric temperature per level) signal differences
        """

        # layer-mean Planck radiances, unperturbed and per-level perturbed
        mrad = {}
        mrad_T = {}
        for gas in ('H2O', 'HDO'):
            bbatm = raddict[gas]['atm']
            bbatm_T = raddict[gas]['atm_T']
            mrad[gas] = 0.5 * (bbatm[:-1] + bbatm[1:])
            mrad_T[gas] = 0.5 * (bbatm_T[:, 1:] + bbatm_T[:, :-1])

        fac = 1./np.cos(np.deg2rad(sza))
        pot = 2.**(strength-1)
        xsecdict = {'H2O': 1e-31, 'H2Os1': 2e-30, 'H2Os2': 9e-29, 'HDO': 1e-31}

        jacdict = {}
        for xsecvar, xs in xsecdict.items():

            # all 'H2O*' variants share the H2O columns and radiances
            gas = 'H2O' if 'H2O' in xsecvar else 'HDO'
            pc = pcdict[gas]['normal']
            pc_pkm = pcdict[gas]['pkm']
            bbrad_srf = raddict[gas]['srf']
            bbrad_srf_T = raddict[gas]['srf_T']
            mradatm = mrad[gas]
            mradatm_T = mrad_T[gas]

            xsec = xs * pot * fac

            # calculate optical thickness (accumulated from the top down)
            dkappa = pc * xsec
            dkappa_pkm = pc_pkm * xsec
            tau = dkappa[::-1].cumsum()[::-1]
            tau_pkm = dkappa_pkm[:, ::-1].cumsum(axis=1)[:, ::-1]

            # atmospheric radiations: surface term (SueG1), layer term (SueG2)
            SueG1 = emi * bbrad_srf * np.exp(-tau[0])
            SueG1_T = emi * bbrad_srf_T * np.exp(-tau[0])
            SueG1_pkm = emi * bbrad_srf * np.exp(-tau_pkm[:, 0])
            SueG2 = np.sum(dkappa * mradatm * np.exp(-tau))
            SueG2_pkm = (dkappa_pkm * mradatm * np.exp(-tau_pkm)).sum(axis=1)
            SueG2_Tatm = (dkappa * mradatm_T * np.exp(-tau)).sum(axis=1)

            # spectral signatures
            SueGtot = SueG1 + SueG2
            SueGtot_T = SueG1_T + SueG2
            SueGtot_pkm = SueG1_pkm + SueG2_pkm
            SueGtot_Tatm = SueG1 + SueG2_Tatm

            # jacobians as finite differences to the unperturbed signal
            jacdict[xsecvar] = {'T': SueGtot_T - SueGtot,
                                'pkm': SueGtot_pkm - SueGtot,
                                'Tatm': SueGtot_Tatm - SueGtot}

        return jacdict

    def createjac(self, raddict, pcdict, sza, emi, dim):
        """
        Create jacobian matrix from jacobians for averaging kernel

        Parameters
        ----------
        raddict : dict of dict
            blackbody radiations corresponding to 'H2O' and 'HDO'
        pcdict : dict of dict
            dict containing partial columns of 'H2O' and 'HDO' including
            perturbated columns
        sza : float or None
            satellite zenith angle in degrees; None means 25 degrees
        emi : float
            surface emissivity (unitless)
        dim : int
            number of vertical levels of profiles

        Returns
        -------
        sig_comp : numpy.ndarray
            jacobian matrix with size (4*nstrength, 3*dim+1); columns are
            H2O humidity (dim), HDO humidity (dim), atmospheric temperature
            (dim) and surface temperature (1)
        """

        # calculate jacobians for temperature and humidity
        sig_comp = np.zeros((4*self.nstr, 3*dim+1))

        if sza is None:
            sza = 25.

        # (row group, jacobian key, humidity columns) of each variant
        layout = [(0, 'H2O', slice(0, dim)),
                  (1, 'H2Os1', slice(0, dim)),
                  (2, 'H2Os2', slice(0, dim)),
                  (3, 'HDO', slice(dim, 2*dim))]

        for ss, strength in enumerate(self.x_strength):

            jacdict = self.calcjacobian(strength, sza, pcdict, raddict, emi)

            for group, key, humcols in layout:
                row = ss + group*self.nstr
                jac = jacdict[key]
                sig_comp[row, humcols] = jac['pkm']       # humidity
                sig_comp[row, 2*dim:3*dim] = jac['Tatm']  # atm temperature
                sig_comp[row, 3*dim] = jac['T']           # srf temperature

        return sig_comp

    def setupreg(self, alt, dim):
        """
        Set up regularisation matrix for averaging kernel

        Parameters
        ----------
        alt : numpy.array
            altitudes of profile levels in m
        dim : int
            number of vertical levels

        Returns
        -------
        Sam1 : numpy.ndarray
            inv(Sa) for atmospheric humidity, shape (2*dim, 2*dim)
        Sam1_T : numpy.matrix
            inv(Sa) for atmospheric temperature, shape (dim, dim)

        Raises
        ------
        AVKSimulatorError
            if the surface altitude is below -30 m or above the altitudes
            covered by the loaded regularisation data
        """

        step = self._ALT_STEP
        # round the surface altitude down to the regularisation grid; floor
        # division keeps an int key in Python 3 (``/`` would give e.g. 15.0)
        minalt = (int(alt.min()) // step) * step
        if -30 <= minalt < 0:
            # slightly negative surface altitudes use the 0 m data
            minalt = 0
        elif minalt < 0:
            raise AVKSimulatorError("Surface altitude too low (<-30m)")
        if minalt not in self.invdict:
            raise AVKSimulatorError(
                "No regularisation data for surface altitude "
                "{:d} m".format(minalt))

        # copying into a preallocated array enforces the expected shape
        correl = np.zeros((2*dim, 2*dim))
        correl[:, :] = self.invdict[minalt]['correl']

        kovar = np.zeros((2*dim, 2*dim))
        kovardata = self.invdict[minalt]['kovar']
        kovar[:dim, :dim] = kovardata
        kovar[dim:2*dim, dim:2*dim] = kovardata

        Sam1 = kovar + correl  # inv(Sa) for atmospheric humidity

        # Calculation of inv(Sa) for atmospheric temperatures:
        # diagonal term 1/0.25**2 per level (1/1**2 at the lowest level) ...
        std_T = np.ones(dim) * (1./0.25)
        abs_T = np.eye(dim) * (1./(0.25**2.))
        abs_T[0, 0] = 1./(1.**2)

        # ... plus a first-difference smoothing term; correlation lengths
        # are in the same unit as alt (m)
        Tcorrkm = np.ones(dim) * 5000
        Tcorrkm[0] = 10

        tempvec = 1./np.sqrt(std_T[:-1]*std_T[1:]) * \
            (1-np.exp(-np.diff(alt)**2. / (2.*Tcorrkm[:-1]*Tcorrkm[1:])))

        Oper_T = np.zeros((dim, dim))
        for k, val in enumerate(tempvec):
            Oper_T[k, k] = val
            Oper_T[k, k+1] = -val

        # np.asmatrix replaces np.mat, which was removed in NumPy 2.0
        Oper_T = np.asmatrix(Oper_T)
        smo_T = Oper_T.T * Oper_T
        Sam1_T = abs_T + smo_T  # inv(Sa) for atmospheric temperature

        return Sam1, Sam1_T

    def makeTraf(self, dim):
        """
        Prepare matrix for transformation to {humidity, deltaD} states

        Parameters
        ----------
        dim : int
            dimension / number of levels

        Returns
        -------
        Traf : numpy.matrix
            transformation matrix, shape (2*dim, 2*dim)
        invTraf : numpy.matrix
            inverse of transformation matrix
        """

        # every block is a scalar multiple of the identity
        A = 0.5
        B = 0.5
        C = -1.
        D = 1
        Traf = np.zeros((2*dim, 2*dim))
        Traf[:dim, :dim] = np.eye(dim) * A
        Traf[:dim, dim:2*dim] = np.eye(dim) * B
        Traf[dim:2*dim, :dim] = np.eye(dim) * C
        Traf[dim:2*dim, dim:2*dim] = np.eye(dim) * D
        Traf = np.asmatrix(Traf)

        # the inverse therefore follows from the scalar 2x2 inverse of
        # [[A, B], [C, D]] (valid for any non-singular choice of constants)
        det = A*D - B*C
        newA = D / det
        newB = -B / det
        newC = -C / det
        newD = A / det
        invTraf = np.asmatrix(np.zeros((2*dim, 2*dim)))
        invTraf[:dim, :dim] = np.eye(dim)*newA
        invTraf[:dim, dim:2*dim] = np.eye(dim)*newB
        invTraf[dim:2*dim, :dim] = np.eye(dim)*newC
        invTraf[dim:2*dim, dim:2*dim] = np.eye(dim)*newD

        return Traf, invTraf

    def createkernel(self, Sam1, Sam1_T, sig_comp, dim):
        """
        Create kernel using jacobian matrix and regularisation matrix

        Parameters
        ----------
        Sam1 : numpy.ndarray
            inv(Sa) for atmospheric humidity
        Sam1_T : numpy.ndarray
            inv(Sa) for atmospheric temperature
        sig_comp : numpy.ndarray
            jacobian matrix with size (4*nstrength, 3*dim+1)
        dim : int
            number of levels of profiles

        Returns
        -------
        kern_T1 : numpy.matrix
            simulated AVK for MUSICA product type 1
        kern_T2 : numpy.matrix
            simulated AVK for MUSICA product type 2
        """

        sig_comp = np.asmatrix(sig_comp)

        # setup of inv(Sa)
        Sam1_comp = np.zeros((3*dim+1, 3*dim+1))
        Sam1_comp[:2*dim, :2*dim] = Sam1

        # ref constraint for atm temperature
        Sam1_comp[2*dim:3*dim, 2*dim:3*dim] = Sam1_T
        # ref: (practically) no constraint for surface temperature
        Sam1_comp[3*dim, 3*dim] = 1e-12

        # inv(Se) is taken as identity / noise_value**2
        noise_value = 0.5

        # Calculation according to Eq. (4):
        # AVK = (K^T inv(Se) K + inv(Sa))^-1 K^T inv(Se) K
        kern_mat2 = sig_comp.T * sig_comp * 1./noise_value**2.
        A = kern_mat2 + np.asmatrix(Sam1_comp)
        kern_mat1 = np.linalg.pinv(A, rcond=1e-12)
        kern_theo_all = kern_mat1 * kern_mat2

        # Transformation to {H2O,delD} proxy state -> Type 1 avk
        Traf, invTraf = self.makeTraf(dim)
        kern_T1 = Traf * kern_theo_all[:2*dim, :2*dim] * invTraf

        # A posteriori correction -> Type 2 avk
        # Matrix C, Eq. (14) from Schneider et al., 2012)
        CORR_retr = np.zeros((2*dim, 2*dim))
        CORR_retr[:dim, :dim] = kern_T1[dim:2*dim, dim:2*dim]
        CORR_retr[dim:2*dim, :dim] = -kern_T1[dim:2*dim, :dim]
        CORR_retr[dim:2*dim, dim:2*dim] = np.eye(dim)

        kern_T2 = np.asmatrix(CORR_retr) * np.asmatrix(kern_T1)

        return kern_T1, kern_T2

    def run(self, pres, temp, alt, state_h2o, tsrf, emi, sza=None):
        """
        Run all necessary steps for calculating averaging kernel
        for given data

        Parameters
        ----------
        pres : numpy.array
            pressure profile in hPa
        temp : numpy.array
            temperature profile in K
        alt : numpy.array
            altitude corresponding to profiles in m
        state_h2o : numpy.array
            profile of water vapor volume mixing ratio in log(ppmv); it is
            used for both the H2O and the HDO state
        tsrf : float
            skin surface temperature in K
        emi : float
            surface emissivity between [0; 1]
        sza : float
            satellite zenith angle in degrees; if not given 25 degrees
            is assumed

        Returns
        -------
        kern_T1 : numpy.matrix
            simulated averaging kernel for MUSICA product type 1
        kern_T2 : numpy.matrix
            simulated averaging kernel for MUSICA product type 2

        Raises
        ------
        AVKSimulatorError
            if no regularisation data exist for the surface altitude
        """

        dim = len(alt)
        state_wv = {'H2O': state_h2o, 'HDO': state_h2o}
        pcdict = self.partialcol(pres, temp, alt, state_wv, dim)
        raddict = self.calcrad(tsrf, temp, dim)
        sig_comp = self.createjac(raddict, pcdict, sza, emi, dim)
        Sam1, Sam1_T = self.setupreg(alt, dim)
        kern_T1, kern_T2 = self.createkernel(Sam1, Sam1_T, sig_comp, dim)

        return kern_T1, kern_T2


if __name__ == "__main__":

    # example profiles: temperature in K, altitude in m, H2O in ppmv
    temp = np.array([298.4, 296.9, 295.2, 293.3, 291.0, 288.3, 285.2, 281.7,
                     277.9, 273.7, 269.2, 264.1, 258.4, 252.2, 245.4, 237.9,
                     228.5, 220.6, 210.5, 200.7, 202.2, 214.7, 223.2, 229.4,
                     237.0, 249.2, 259.7, 259.6])

    alt = np.array([0.00, 0.39, 0.83, 1.30, 1.82, 2.37, 2.95, 3.57, 4.22, 4.90,
                    5.62, 6.38, 7.18, 8.01, 8.88, 9.78, 10.92, 12.00, 13.66,
                    15.96, 18.31, 22.10, 26.22, 30.69, 36.34, 42.37, 48.55,
                    55.60]) * 1000

    H2O = np.array([1.068e+04, 9.366e+03, 8.076e+03, 6.899e+03, 5.747e+03,
                    4.718e+03, 3.839e+03, 3.077e+03, 2.380e+03, 1.780e+03,
                    1.279e+03, 8.746e+02, 5.418e+02, 2.941e+02, 1.339e+02,
                    6.172e+01, 3.503e+01, 2.057e+01, 9.152e+00, 4.268e+00,
                    4.020e+00, 4.588e+00, 4.957e+00, 5.323e+00, 5.725e+00,
                    6.004e+00, 6.182e+00, 5.995e+00])
    # number of levels follows the example profiles
    dim = len(alt)
    sza = 25.
    tsrf = 300.
    # log-linearly spaced pressure profile in hPa (about 990 hPa to 5 hPa)
    pres = np.exp(np.linspace(9.2, 3.9, dim)) / 10.
    emi = 0.95
    state_h2o = np.log(H2O)

    invdir = os.path.join('.', 'regularisation')
    AVK = AVKSimulator(invdir)
    kern_T1, kern_T2 = AVK.run(pres, temp, alt, state_h2o, tsrf, emi, sza)
    # single-argument print() works in both Python 2 and Python 3
    print(kern_T1)
    print(kern_T2)
