# -*- coding: utf-8 -*-

import numpy as np
from metpy import constants as const
from metpy import calc as metcalc
from metpy.units import units
from scipy.interpolate import interp1d

# SciPy introduced `simpson` as the new name for `simps` (1.6) and removed the
# old name in 1.14; provide both so the module imports on either side.
try:
    from scipy.integrate import simpson
except ImportError:
    from scipy.integrate import simps as simpson
try:
    from scipy.integrate import simps
except ImportError:
    simps = simpson

"""
Functions used for the calculations in Beirle et al., AMT, 2021.

Equation numbers refer to Beirle et al., 2021.
Physical units are fully accounted for by using the Metpy units module
(which is using the pint module).
All physical quantities have to be defined with a unit, e.g.
T = 300 * units.K
See the Metpy/pint documentations and the sample call in __main__
"""

# Constants

M_AIR = const.dry_air_molecular_weight  # just for brevity

# Volume mixing ratio of O2 in dry air (https://doi.org/10.1029/2004JD005595)
O2_VMR = 0.209392 * units.dimensionless

# Combination of constants shared by the O4 parameterizations
# (O4_par_Gamma and O4_par_RH): O2_VMR**2 / (R * g * M_AIR), reduced to base
# units once here rather than on every call.
C = (O2_VMR**2 / (const.R*const.g*M_AIR)).to_base_units()

# Default lapse rate, expressed as temperature change per altitude
# (negative = cooling with height, the sign convention used by temperature()).
GAMMA0 = -6.5 * units.K/units.km

# Parameters for Eq. 14. They are the default `pol=(b, a)` of O4_par_RH,
# where they enter as the linear expression a + b*RH.
a = 1.7743433442596852
b = 0.11820810212834893


def number_density(T, p, vmr=1):
    """
    Return the number density of a gas from the ideal gas law.

    T and p are temperature and pressure (quantities with units). vmr is the
    volume mixing ratio of the gas of interest: 1 for air as a whole,
    O2_VMR for O2, or the appropriate value for H2O.
    """
    thermal_term = const.R * T
    return vmr * p / thermal_term


def temperature(T0, z0, z, Gamma):
    """
    Return the temperature at altitude z for a linear temperature profile.

    The profile is anchored at T0 at altitude z0 and changes with the
    constant lapse rate Gamma (temperature change per unit altitude, so a
    negative Gamma means cooling with height).
    """
    height_offset = z - z0
    return T0 + height_offset * Gamma


def barometric_pressure(p0, T0, z0, z, Gamma=0, molar_mass=M_AIR):
    """
    Return the pressure at altitude z from the barometric formula.

    p0 and T0 are pressure and temperature at the reference altitude z0.
    Gamma is the constant lapse rate (temperature change per altitude);
    Gamma == 0 selects the isothermal form. molar_mass defaults to that of
    dry air.
    """
    # formulas taken from https://en.wikipedia.org/wiki/Barometric_formula
    height_diff = z - z0

    if Gamma == 0:
        # Isothermal layer: plain exponential decay with altitude.
        decay = -const.g * molar_mass * height_diff / (const.R * T0)
        return p0 * np.exp(decay)

    # Temperature varies linearly with altitude: power-law form.
    temperature_ratio = 1 + Gamma / T0 * height_diff
    power = -const.g * molar_mass / (const.R * Gamma)
    return p0 * temperature_ratio**power


def hydrostatic_column(p, vmr=1):
    """
    Return the vertical column above the pressure level p.

    The column is evaluated as vmr * p / (g * M_AIR) and converted to base
    units. vmr selects the gas: 1 for air as a whole, or e.g. O2_VMR for O2.
    """
    column = vmr / const.g / M_AIR * p
    return column.to_base_units()


def ratio_h(Gamma):
    """
    Return the ratio of the effective heights of O2 and O4.

    Valid for an atmosphere with constant lapse rate Gamma; the ratio is
    2 + R / (g * M_dry_air) * Gamma, i.e. exactly 2 for Gamma == 0.
    """
    height_per_kelvin = const.R / const.g / const.dry_air_molecular_weight
    return 2 + height_per_kelvin * Gamma


def O4_par_Gamma(p, T, Gamma):
    """
    Return the O4 VCD parameterized by surface conditions and lapse rate.

    p and T are surface pressure and temperature; Gamma is the constant
    lapse rate passed on to ratio_h().
    """
    prefactor = C / ratio_h(Gamma)
    return prefactor * p**2 / T


def O4_par_RH(p, T, RH, pol=(b, a)):
    """
    Return the O4 VCD parameterized by surface p, T and relative humidity.

    The effective-height ratio is modelled as a linear function of RH,
    pol[1] + pol[0] * RH; by default the module-level parameters of Eq. 14
    are used (pol = (b, a), i.e. a + b * RH).
    """
    height_ratio = pol[1] + pol[0] * RH
    prefactor = C / height_ratio
    return prefactor * p**2 / T


class WaterVapor(object):
    """
    Various humidity quantities for given temperature and pressure.

    Construct directly from specific humidity, or use ``from_RH`` when the
    relative humidity is known instead. All inputs are expected to be
    quantities with units.

    Attributes set by the constructor:
        T, p, q           -- the inputs, stored unchanged
        w                 -- mixing ratio derived from q
        T_virtual         -- virtual temperature for (T, w)
        RH                -- relative humidity derived from (p, T, q)
        partial_pressure  -- water vapor partial pressure for (p, w)
        rho_dry           -- ideal-gas mass density of dry air at (p, T)
        rho_air           -- rho_dry scaled by T / T_virtual (moist air)
        number_density    -- q * rho_air / M_water, the water vapor
                             amount per volume
    """

    def __init__(self, q, T, p):
        """
        Init if specific humidity is given.

        q is the specific humidity, T the temperature and p the pressure.
        """
        self.T = T
        self.p = p
        self.q = q

        self.w = metcalc.mixing_ratio_from_specific_humidity(q)
        self.T_virtual = metcalc.virtual_temperature(T, self.w)
        self.RH = metcalc.relative_humidity_from_specific_humidity(p, T, q)
        self.partial_pressure = metcalc.vapor_pressure(p, self.w)
        # Ideal gas law for dry air, then corrected for moisture through the
        # virtual temperature.
        self.rho_dry = const.dry_air_molecular_weight*p / (const.R*T)
        self.rho_air = self.rho_dry * T/self.T_virtual
        # Water vapor mass density (q * rho_air) divided by its molar mass.
        self.number_density = q * self.rho_air / const.water_molecular_weight

    @classmethod
    def from_RH(cls, RH, T, p):
        """
        Init if relative humidity is given.

        The mixing ratio is taken as w = RH * w_sat(p, T) and converted to
        specific humidity via q = w / (1 + w).
        """
        saturation_mixing_ratio = metcalc.saturation_mixing_ratio(p, T)
        # Reduce to a plain dimensionless quantity first so that RH given in
        # percent does not leave a percent unit on q.
        w = (RH * saturation_mixing_ratio).to('dimensionless')
        # Algebraically identical to 1 / (1 + 1/w), but well defined for
        # RH == 0 (completely dry air), where the reciprocal form divides by
        # zero.
        q = w / (1 + w)
        return cls(q, T, p)


def calc_column(number_density, z):
    """
    Calculate the vertical column density from a vertical profile of the
    number density using Simpson's rule.

    number_density and z are quantities with units (profile values and the
    altitudes they belong to). The integration is done on the bare
    magnitudes; the product of both units is re-attached afterwards and the
    result is converted to base units.
    """
    # `simpson` replaces `simps`, which was removed from SciPy in 1.14.
    # The sample points are passed by keyword, which every SciPy version
    # providing either name accepts.
    integral = simpson(number_density.m, x=z.m)
    return (integral * (number_density.u * z.u)).to_base_units()


def calculate_columns(z, T, p, RH=None):
    """
    Calculate vertical column densities of O2, O4 and H2O from vertical
    profiles of T, p and (optionally) RH on the altitude grid z.

    If RH is omitted, a relative humidity of zero is used at every level.
    The part of the atmosphere above the last profile level (index -1) is
    added analytically for O2 and O4.

    Returns a dict with the keys "V_O2", "V_O4", "TCWV" and "O2_nd_0"
    (the O2 number density at the first profile level).
    """
    if RH is None:
        RH = np.zeros_like(T) * units.dimensionless

    # Water vapor: profile quantities and total column.
    vapor = WaterVapor.from_RH(RH, T, p)
    water_column = calc_column(vapor.number_density, z)

    # O2: integrate the [O2] profile of the dry-air part, then add the
    # hydrostatic column that lies above the top of the profile.
    dry_pressure = p - vapor.partial_pressure
    o2_profile = number_density(T, dry_pressure, O2_VMR)
    o2_column = calc_column(o2_profile, z)
    o2_above_top = hydrostatic_column(dry_pressure[-1], O2_VMR)
    o2_column += o2_above_top

    # O4: integrate [O2]**2, then add the contribution above the top level.
    o4_column = calc_column(o2_profile**2, z)
    o4_above_top = 1/2 * o2_profile[-1] * o2_above_top
    o4_column += o4_above_top

    return {
        "V_O2": o2_column,
        "V_O4": o4_column,
        "TCWV": water_column,
        "O2_nd_0": o2_profile[0],
    }


if __name__ == "__main__":

    # Sample calls: O4 VCD from surface values, once for humid (50 %) and
    # once for completely dry (0 %) air, printed in molec**2 / cm**5.
    T0 = 280 * units.K
    p0 = 1000 * units.hPa
    for rh_percent in (50, 0):
        RH0 = rh_percent * units.percent
        O4VCD = O4_par_RH(p0, T0, RH0)
        print(O4VCD.to(units.molec**2 / units.cm**5))


