# -*- coding: utf-8 -*-
"""
:created on: 2017-01-24
:author: Christian Borger, <christian.borger@student.kit.edu>

This script is an example of how to use the MUSICA MetOp/IASI retrieval
simulator for given model simulations (in this case ECHAM5-wiso).
"""

from netCDF4 import Dataset, num2date
import numpy as np
import os
import pandas as pd
from scipy.interpolate import interp1d
from AVKsimulator import AVKSimulator


def acovdD(heights, adim=0.1, corrlength=5000., corrlengthb1=500., nbl=3):
    """
    Calculate a priori covariance matrix for deltaD.

    The covariance is Gaussian in altitude:

        S[i, j] = adim**2 * exp(-(z_i - z_j)**2 / (2 * L_ij**2))

    L_ij is ``corrlength`` if levels i and j are on the same side of the
    boundary between the lowest ``nbl`` levels and the levels above them.
    Otherwise L_ij is ``corrlengthb1``. A short ``corrlengthb1`` therefore
    decouples the lowest levels from the rest of the profile.

    Parameters
    ----------
    heights : numpy.array
        height levels in m above mean sea level
    adim : float, optional
        a priori standard deviation; the diagonal of the matrix is adim**2
    corrlength : float, optional
        correlation length in m within the lowest ``nbl`` levels and within
        the levels above them
    corrlengthb1 : float, optional
        correlation length in m between the lowest ``nbl`` levels and the
        levels above them
    nbl : int, optional
        number of lowest levels forming the first block

    Returns
    -------
    SadD : numpy.ndarray
        a priori covariance matrix, NxN, N=number of levels
    """
    heights = np.asarray(heights, dtype=float)
    dim = len(heights)

    X, Y = np.meshgrid(heights, heights)

    # Choose the correlation length element-wise: long within a block,
    # short across the boundary between the lowest nbl levels and the rest.
    inlow = np.arange(dim) < nbl
    sameblock = inlow[:, np.newaxis] == inlow[np.newaxis, :]
    length = np.where(sameblock, corrlength, corrlengthb1)

    SadD = adim**2. * np.exp(-(X - Y)**2. / (2 * length**2.))
    return SadD


def calcstates(h2o, delD, aprh2o, aprdelD, avk):
    """
    Calculate smoothed {humidity,delD} proxy state for given H2O and
    delD profiles, H2O and delD a priori profiles and averaging kernel.

    The model profiles and the a priori profiles are both transformed to
    the proxy state. Here h2o is already in ln(ppmv):

        humidity proxy = h2o + ln(delD + 1) / 2
        delD proxy     = ln(delD + 1)

    The model state x is then smoothed with the averaging kernel A around
    the a priori state x_a:

        x_smoothed = A (x - x_a) + x_a

    Parameters
    ----------
    h2o : numpy.array
        water vapour profile in ln(ppmv)
    delD : numpy.array
        deltaD profile (in unity units, not in permil)
    aprh2o : numpy.array
        water vapour a priori profile in ln(ppmv)
    aprdelD : numpy.array
        deltaD a priori profile (in unity units, not in permil)
    avk : numpy.ndarray
        (simulated) averaging kernel matrix in the {hum,delD} proxy state.
        It is multiplied with a vector of length 2N (N = len(h2o)), so it
        presumably has shape (2N, 2N) with the humidity block first.
        TODO confirm against AVKSimulator.

    Returns
    -------
    smvap : numpy.array
        calculated smoothed {humidity,delD} proxy state vector; first half
        contains smoothed humidity proxy values and second half smoothed
        delD proxy values. It is returned as a plain (unmasked) array.

    Notes
    -----
    NOTE(review): NaN entries are wrapped in masked arrays, but the product
    uses np.dot rather than the mask-aware np.ma.dot. The masks therefore
    probably do not exclude NaNs from the sum. Masked elements of the final
    sum are converted to plain values by np.array(list(...flat)). Confirm
    that the values produced for NaN inputs are the intended ones.
    """

    # transform the model state to the proxy state; np.append flattens,
    # so vap = [humidity proxy (N), delD proxy (N)]
    hump = h2o + np.log(delD + 1) / 2.
    delDp = np.log(delD + 1)
    vap = np.append(hump, delDp)

    # same transformation for the a priori state
    aprhump = aprh2o + np.log(aprdelD + 1) / 2.
    aprdelDp = np.log(aprdelD + 1.)
    apriori = np.append(aprhump, aprdelDp)
    # mask NaN entries (see NOTE(review) in the docstring)
    mavap = np.ma.MaskedArray(vap, mask=np.isnan(vap))
    mapr = np.ma.MaskedArray(apriori, mask=np.isnan(apriori))
    maavk = np.ma.MaskedArray(avk, mask=np.isnan(avk))
    # x_smoothed = A (x - x_a) + x_a; ravel() flattens the product
    # (e.g. if it comes back 2-D) so it can be added to the 1-D a priori
    tmpdiff = mavap - mapr
    smvap = np.dot(maavk, tmpdiff).ravel() + mapr
    # convert the masked result into a plain numpy array
    smvap = np.array(list(smvap.flat))

    return smvap


def _retrieval_levels(srfalt, defheights, delta_alt=30):
    """
    Build the retrieval altitude grid for one model column.

    Parameters
    ----------
    srfalt : float
        lowest model altitude of the column in m, used as ground altitude
    defheights : numpy.array
        default retrieval levels in m, starting at 0 m
    delta_alt : float, optional
        tolerance in m. Columns with ground below -delta_alt are rejected.
        For ground between -delta_alt and 0 m, default levels closer than
        delta_alt to the ground are dropped, so no extra level is inserted.

    Returns
    -------
    newalt : numpy.array or None
        the ground level followed by all default levels above it. The
        ground level is rounded down to 10 m, or set to 0 m if negative.
        None if the column has to be skipped.
    """
    if srfalt < -delta_alt:
        return None

    if defheights.min() > srfalt:
        altv = defheights[defheights > srfalt + delta_alt]
    else:
        altv = defheights[defheights > srfalt]

    if srfalt < 0:
        newalt = np.append(0, altv)
    else:
        newalt = np.append(int(srfalt / 10) * 10, altv)

    # the retrieval cannot handle more levels than its default grid
    if len(newalt) > len(defheights):
        return None
    return newalt


def _interp_profile(zsrc, values, znew):
    """Linearly interpolate (and extrapolate) a profile from zsrc to znew."""
    return interp1d(zsrc, values, fill_value='extrapolate')(znew)


def _level_fields(idx, newalt, smerr, newh2o, newdD, h2o_T2, dD_T2):
    """
    Return the six per-level output values for retrieval level idx.

    The values are: altitude, smoothing error, model H2O [ppmv], model delD,
    smoothed H2O [ppmv] and smoothed delD. If idx is not a valid level
    index, six NaNs are returned. A negative idx would otherwise silently
    select a level counted from the top of the profile.
    """
    if not 0 <= idx < len(newalt):
        return [np.nan] * 6
    return [newalt[idx], smerr[idx], np.exp(newh2o[idx]), newdD[idx],
            np.exp(h2o_T2[idx]), dD_T2[idx]]


def main(srcfn, outfn):
    """
    Simulate MUSICA MetOp/IASI retrieval products for an ECHAM5-wiso file.

    Each column of each time step is processed as follows:

    1. The model profiles are interpolated to the retrieval levels.
    2. The type 1 and type 2 averaging kernels are simulated with
       AVKSimulator.
    3. The model {H2O,delD} state is smoothed with the type 2 kernel.

    Results are written as ';'-separated rows for two levels:

    - "best": the level with minimal delD smoothing error;
    - "5km": the 4.90 km default level. Its columns are NaN if the ground
      is too high for that level to exist.

    Parameters
    ----------
    srcfn : str
        path of the ECHAM5-wiso netCDF input file
    outfn : str
        path of the output text file (overwritten)

    Columns whose processing raises an exception are skipped. The number of
    skipped columns and the last error are printed at the end.
    """
    defheights = np.array([0.00, 0.39, 0.83, 1.30, 1.82, 2.37, 2.95,
                           3.57, 4.22, 4.90, 5.62, 6.38, 7.18, 8.01,
                           8.88, 9.78, 10.92, 12.00, 13.66, 15.96,
                           18.31, 22.10, 26.22, 30.69, 36.34, 42.37,
                           48.55, 55.60]) * 1000
    # defheights[idx5km] == 4900 m is the level reported as "5km". Truncated
    # profiles only lose levels at the bottom, so counted from the top it
    # always sits at the same position.
    idx5km = 9
    nabove5km = len(defheights) - idx5km

    print('Read ECHAM data: %s' % srcfn)
    with Dataset(srcfn, 'r') as srcnc:
        alt = srcnc.variables['height'][:]
        lat = srcnc.variables['lat'][:]
        lon = srcnc.variables['lon'][:]
        qvmr = srcnc.variables['WVMR'][:]
        wisodD = srcnc.variables['deltaD'][:]
        timevar = srcnc.variables['datatime']
        dates = num2date(timevar[:], timevar.units)
        pres = srcnc.variables['pressure'][:]
        temp = srcnc.variables['temperature'][:]
        emissivity = srcnc.variables['emissivity'][:]
        sktemp = srcnc.variables['skin_temp'][:]
        cldflag = srcnc.variables['cloudflag'][:]
    stateh2o = np.log(qvmr)

    # The a priori table is the same for every column, so extract it once.
    print('Load a priori profiles')
    aprdf = pd.read_csv('apri.dat', sep=' ', index_col=0)
    aprz = aprdf.index.values
    aprlnh2o = np.log(aprdf['H2O[ppmv]'].values)
    aprdelDall = aprdf['delD[unity]'].values

    # load class of AVK simulator; path to regularisation folder
    invdir = os.path.join('.', 'regularisation')
    avksim = AVKSimulator(invdir)
    sza = 25.0

    header = ['lon', 'lat', 'GNDalt', 'SKINtemp', 'cld', 'DOFS_T1',
              'DOFS_T2', 'altbest', 'Sensbest', 'MH2Obest', 'MdelDbest',
              'smMH2Obest', 'smMdelDbest', 'alt5km', 'Sens5km', 'MH2O5km',
              'MdelD5km', 'smMH2O5km', 'smMdelD5km', ' ']

    nfailed = 0
    lasterr = None

    with open(outfn, 'w') as f:
        f.write(';'.join(header))
        f.write('\n')
        for i, dt in enumerate(dates):
            print(dt)

            for k, dlon in enumerate(lon):
                print(k)

                for j, dlat in enumerate(lat):
                    try:
                        # the cloud flag is only reported; cloudy columns
                        # are deliberately not skipped
                        cld = cldflag[i, j, k]
                        emi = emissivity[i, j, k]
                        zcol = alt[i, :, j, k]

                        newalt = _retrieval_levels(zcol.min(), defheights)
                        if newalt is None:
                            continue
                        odim = len(newalt)

                        # interpolate profiles to the retrieval levels
                        # (pressure is interpolated in ln(p))
                        newtsrf = sktemp[i, j, k]
                        newh2o = _interp_profile(zcol, stateh2o[i, :, j, k],
                                                 newalt)
                        newdD = _interp_profile(zcol, wisodD[i, :, j, k],
                                                newalt)
                        newpres = np.exp(_interp_profile(
                            zcol, np.log(pres[i, :, j, k]), newalt))
                        newtemp = _interp_profile(zcol, temp[i, :, j, k],
                                                  newalt)

                        # run retrieval simulator
                        kern_T1, kern_T2 = avksim.run(newpres, newtemp,
                                                      newalt, newh2o,
                                                      newtsrf, emi, sza)
                        dofst1 = np.trace(kern_T1[:odim, :odim])
                        dofst2 = np.trace(kern_T2[:odim, :odim])

                        # a priori profiles on the retrieval levels
                        aprh2o = np.interp(newalt, aprz, aprlnh2o)
                        aprdelD = np.interp(newalt, aprz, aprdelDall)

                        # smoothed {H2O,delD} states, back-transformed from
                        # the proxy state: type 2 H2O & type 2 dD
                        smoothedstate = calcstates(newh2o, newdD, aprh2o,
                                                   aprdelD, kern_T2)
                        h2o_T2 = (smoothedstate[:odim]
                                  - smoothedstate[odim:] / 2.)
                        dD_T2 = np.exp(smoothedstate[odim:]) - 1

                        # delD smoothing error: sqrt(diag((I-A) S_a (I-A)^T))
                        # using the delD block of the type 2 kernel
                        SadD = acovdD(newalt)
                        tmpmat = (np.eye(odim)
                                  - kern_T2[odim:2 * odim, odim:2 * odim])
                        smerr = np.sqrt(np.diag(
                            tmpmat.dot(SadD).dot(tmpmat.T)))

                        # write results to output file
                        profiles = (newalt, smerr, newh2o, newdD,
                                    h2o_T2, dD_T2)
                        data = [dlon, dlat, newalt.min(), newtsrf, cld,
                                dofst1, dofst2]
                        data += _level_fields(smerr.argmin(), *profiles)
                        data += _level_fields(odim - nabove5km, *profiles)
                        np.savetxt(f, np.array(data), fmt='%11.6e',
                                   newline=';')
                        f.write('\n')

                    except Exception as e:
                        # best effort: skip the column, but keep track so
                        # failures do not go unnoticed
                        nfailed += 1
                        lasterr = e

    if nfailed:
        print('Skipped %d columns because of errors (last error: %r)'
              % (nfailed, lasterr))


if __name__ == "__main__":
    # Example run: read the ECHAM5-wiso sample file from the current
    # directory and write the simulated products next to it.
    main(os.path.join('.', 'example_ECHAM5wiso_20140212hh00.nc'),
         os.path.join('.', 'example_output.dat'))
