#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 15 12:29:02 2019

@author: zhangshuai
"""

import os
import bz2file
import datetime
import numpy as np
import matplotlib as mpl
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

# Integer code found at the start of each moment header -> moment name used
# as the second element of the (sweep, moment) data keys.
mask2moment = {
    1: 'T',
    2: 'Z',
    3: 'V',
    4: 'W',
    5: 'SQI',
    7: 'ZDR',
    9: 'CC',
    10: 'PhiDP',
    11: 'KDP',
    16: 'SNR',
}

# Decoded value that read() treats as "masked" for each moment and rewrites
# to -999.  Every moment uses -200 except CC and PhiDP.
moment2maskvalue = {
    **dict.fromkeys(
        ('T', 'Z', 'V', 'W', 'SQI', 'ZDR', 'CC', 'PhiDP', 'KDP', 'SNR'),
        -200,
    ),
    'CC': -0.1,
    'PhiDP': -2,
}

# Sweep number (1-based) -> elevation in degrees shown in plot titles.
elevation = dict(enumerate(
    (0.5, 1.5, 2.4, 3.4, 4.3, 5.3, 6.2, 7.5, 8.7, 10, 12, 14, 16.7, 19.5, 90),
    start=1,
))

class RADAR:
    """Plain container for one radar volume scan.

    Attributes
    ----------
    version : processing stage tag (read() stores 0, ray_sort() stores 1)
    time : volume scan start time
    sweepnum : number of sweeps in the volume
    elevation : dict, sweep -> elevation values
    azimuth : dict, sweep -> azimuth values
    data : dict, (sweep, moment) -> data
    """

    def __init__(self, version=None):
        self.version = version
        # Scalars are unknown until a reader fills them in.
        self.time = self.sweepnum = None
        # Each instance gets its own, initially empty, per-sweep dictionaries.
        self.elevation = dict()
        self.azimuth = dict()
        self.data = dict()
        
def _file2buf(filepath):
    path,extension=os.path.splitext(filepath)
    if extension=='.bz2':
        file=bz2file.open(filepath,'rb')
        buf=file.read()
    else:
        file=open(filepath,'rb')
        buf=file.read()
    file.close()
    return buf

def _unpack(buf,dtype,count,offset):
    data=np.frombuffer(buf,dtype,count,offset)
    if count==1:
        return data[0]
    else:
        return data
        
def _interp(azimuth_raw,index_raw,azimuth_interp):
    func=interp1d(azimuth_raw,index_raw,kind='nearest',fill_value='extrapolate')
    index_interp=np.array(func(azimuth_interp),dtype='uint16')
    return index_interp
        
def read(filepath):
    """Decode a raw radar volume file into a RADAR object.

    Layout, as implied by the byte offsets used below:

    * offset 332 / 336: volume time and number of sweeps (uint32 each);
    * ray records start at ``416 + 256 * sweepnum``; each ray has a 64-byte
      header followed by ``raylen`` bytes of moment blocks;
    * each moment block has a 32-byte header (type code, scale, offset,
      bytes per bin, payload length) followed by its payload.

    Bins are converted with ``(raw - offset) / scale``.  Values equal to the
    per-moment entry of ``moment2maskvalue`` are replaced by -999.

    Parameters
    ----------
    filepath : path of the raw file (optionally '.bz2' compressed)

    Returns
    -------
    RADAR : version 0 object; ``elevation``/``azimuth`` map sweep -> float32
        array and ``data`` maps (sweep, moment) -> float32 array.  Moments
        absent from the file are left as empty arrays.

    Raises
    ------
    ValueError : a moment block declares a bin width other than 1 or 2 bytes
    KeyError : a moment type code is not listed in ``mask2moment``
    """
    moments = ('T', 'Z', 'V', 'W', 'SQI', 'ZDR', 'CC', 'PhiDP', 'KDP', 'SNR')

    buf = _file2buf(filepath)
    time = _unpack(buf, 'uint32', 1, 332)
    sweepnum = _unpack(buf, 'uint32', 1, 336)

    elevation = {}
    azimuth = {}
    data = {}
    for sweep in range(1, sweepnum + 1):
        elevation[sweep] = []
        azimuth[sweep] = []
        for moment in moments:
            data[(sweep, moment)] = []

    # Byte offsets are kept as Python ints so the pointer arithmetic cannot
    # wrap around in uint32.
    raypointer = 416 + 256 * int(sweepnum)
    while True:
        raystate = _unpack(buf, 'uint32', 1, raypointer)
        sweepindex = _unpack(buf, 'uint32', 1, raypointer + 16)
        azimuth_ray = _unpack(buf, 'float32', 1, raypointer + 20)
        elevation_ray = _unpack(buf, 'float32', 1, raypointer + 24)
        raylen = int(_unpack(buf, 'uint32', 1, raypointer + 36))
        momentnum = int(_unpack(buf, 'uint32', 1, raypointer + 40))
        azimuth[sweepindex].append(azimuth_ray)
        elevation[sweepindex].append(elevation_ray)

        momentpointer = raypointer + 64
        # A counted loop: a ray announcing zero moments is skipped instead of
        # having whatever bytes follow its header parsed as a moment block.
        for _ in range(momentnum):
            moment = mask2moment[_unpack(buf, 'uint32', 1, momentpointer)]
            scale = _unpack(buf, 'uint32', 1, momentpointer + 4)
            offset = _unpack(buf, 'uint32', 1, momentpointer + 8)
            binlen = _unpack(buf, 'uint16', 1, momentpointer + 12)
            momentlen = int(_unpack(buf, 'uint32', 1, momentpointer + 16))
            if binlen == 1:
                data_ray = _unpack(buf, 'uint8', momentlen, momentpointer + 32)
            elif binlen == 2:
                data_ray = _unpack(buf, 'uint16', momentlen // 2, momentpointer + 32)
            else:
                # Previously this fell through and reused the bins of the
                # preceding moment (or hit an unbound name on the first one).
                raise ValueError(
                    'unsupported bin length {} for moment {} at byte offset {}'
                    .format(binlen, moment, momentpointer))
            data_ray = np.array(data_ray, 'float32')
            data_ray = (data_ray - offset) / scale
            data[(sweepindex, moment)].append(data_ray)
            momentpointer += momentlen + 32

        # The ray whose state field equals 4 is the last one processed.
        if raystate == 4:
            break
        raypointer += raylen + 64

    for sweep in range(1, sweepnum + 1):
        elevation[sweep] = np.array(elevation[sweep], dtype='float32')
        azimuth[sweep] = np.array(azimuth[sweep], dtype='float32')
        for moment in moments:
            key = (sweep, moment)
            data[key] = np.array(data[key], dtype='float32')
            data[key][data[key] == moment2maskvalue[moment]] = -999.

    radar = RADAR(0)  # raw data is 0
    radar.time = time
    radar.sweepnum = sweepnum
    radar.elevation = elevation
    radar.azimuth = azimuth
    radar.data = data
    return radar

def radar2nc(radar, filepath):
    """Write *radar* to a NetCDF4 file inside the directory *filepath*.

    The file is named ``NUIST.<YYYYmmddHHMMSS>.<version>.nc`` where the
    timestamp is ``radar.time`` expressed in UTC.  Version, time and sweep
    count become global attributes.  Each sweep becomes a group named after
    its number, holding ``azimuth``, ``elevation`` and the moments
    Z, V, W, ZDR, CC, PhiDP, SQI and SNR.

    Parameters
    ----------
    radar : RADAR object to serialise
    filepath : existing directory that receives the file

    Returns
    -------
    None
    """
    moments = ('Z', 'V', 'W', 'ZDR', 'CC', 'PhiDP', 'SQI', 'SNR')
    # Same storage settings for every variable.
    storage = dict(zlib=True, complevel=1, fletcher32=True)

    # Timezone-aware conversion replaces the deprecated utcfromtimestamp();
    # strftime yields the same digits the replace() chain used to build.
    stamp = datetime.datetime.fromtimestamp(radar.time, datetime.timezone.utc)
    filename = '{}.{}.{}.{}'.format(
        'NUIST', stamp.strftime('%Y%m%d%H%M%S'), radar.version, 'nc')

    # The context manager closes the dataset even if a write fails, so no
    # half-open file handle is left behind.
    with Dataset(os.path.join(filepath, filename), 'w', format='NETCDF4') as volumegrp:
        volumegrp.version = radar.version
        volumegrp.time = radar.time
        volumegrp.sweepnum = radar.sweepnum
        for sweep in range(1, radar.sweepnum + 1):
            sweepgrp = volumegrp.createGroup(str(sweep))

            # Z fixes the dimensions shared by every moment of the sweep.
            raynum, gatenum = radar.data[(sweep, 'Z')].shape
            sweepgrp.createDimension('raydimension', raynum)
            sweepgrp.createDimension('gatedimension', gatenum)

            azimuthvariable = sweepgrp.createVariable(
                'azimuth', 'f4', ('raydimension',), **storage)
            elevationvariable = sweepgrp.createVariable(
                'elevation', 'f4', ('raydimension',), **storage)
            azimuthvariable[:] = radar.azimuth[sweep]
            elevationvariable[:] = radar.elevation[sweep]
            for moment in moments:
                momentvariable = sweepgrp.createVariable(
                    moment, 'f4', ('raydimension', 'gatedimension',), **storage)
                momentvariable[:, :] = radar.data[(sweep, moment)]
    
def nc2radar(filepath):
    """Load a RADAR object from a NetCDF4 file laid out as radar2nc() writes it.

    Version, time and sweep count come from the global attributes; each
    sweep is read from the group named after its number.

    Parameters
    ----------
    filepath : path of the NetCDF4 file

    Returns
    -------
    RADAR : object whose ``elevation``/``azimuth``/``data`` dictionaries hold
        plain numpy arrays for the moments Z, V, W, ZDR, CC, PhiDP, SQI, SNR
    """
    moments = ('Z', 'V', 'W', 'ZDR', 'CC', 'PhiDP', 'SQI', 'SNR')

    # The context manager closes the file even if a group or variable is
    # missing and the lookup raises.
    with Dataset(filepath, 'r', format='NETCDF4') as volumegrp:
        radar = RADAR(volumegrp.version)
        radar.time = volumegrp.time
        radar.sweepnum = volumegrp.sweepnum
        for sweep in range(1, radar.sweepnum + 1):
            sweepgrp = volumegrp.groups[str(sweep)]

            # np.array(...) avoids keeping the automatically masked arrays.
            radar.elevation[sweep] = np.array(sweepgrp.variables['elevation'][:])
            radar.azimuth[sweep] = np.array(sweepgrp.variables['azimuth'][:])
            for moment in moments:
                radar.data[(sweep, moment)] = np.array(sweepgrp.variables[moment][:])

    return radar

def visualization(radar, info):
    """Draw one moment of one sweep as a polar plot (PPI style).

    Parameters
    ----------
    radar : RADAR object providing ``data``, ``azimuth``, ``time``, ``version``
    info : 10-tuple ``(sweep, moment, norm, cmap, cblabel, resolution,
        rangecircle, maxrange, dpi, figpath)``

        * sweep, moment -- key of the field to plot
        * norm -- ``(vmin, vmax)`` of the colour scale
        * cmap -- matplotlib colormap name
        * cblabel -- colour bar label
        * resolution -- gate spacing; gate *i* (1-based) is drawn at
          ``i * resolution``
        * rangecircle, maxrange -- spacing of the range rings and outer limit
        * dpi -- resolution of the saved figure
        * figpath -- output directory, or None to show the figure instead

    Returns
    -------
    None.  All open matplotlib figures are closed before and after drawing.

    Notes
    -----
    The elevation in the title is looked up in the module-level ``elevation``
    table, not in ``radar.elevation``.
    """
    (sweep, moment, norm, cmap, cblabel, resolution,
     rangecircle, maxrange, dpi, figpath) = info

    # Append the first ray again at the end of every array.
    data = radar.data[(sweep, moment)]
    data = np.vstack((data, data[0]))
    data = np.ma.masked_values(data, -999.)
    azimuth = radar.azimuth[sweep]*np.pi/180
    Range = np.arange(1, len(data[0])+1)*resolution
    Range, azimuth = np.meshgrid(Range, azimuth)
    azimuth = np.vstack((azimuth, azimuth[0]))
    Range = np.vstack((Range, Range[0]))

    # Converted once and reused for both the title and the file name;
    # timezone-aware call replaces the deprecated utcfromtimestamp().
    stamp = datetime.datetime.fromtimestamp(radar.time, datetime.timezone.utc)

    plt.close('all')
    fig, axe = plt.subplots(subplot_kw=dict(polar=True))

    norm = mpl.colors.Normalize(vmin=norm[0], vmax=norm[1])
    cmap = plt.get_cmap(cmap)

    # Angles increase clockwise with zero at the top of the plot.
    axe.set_theta_direction(-1)
    axe.set_theta_zero_location('N')
    axe.set_thetagrids(list(range(0, 360, 30)), visible=False)
    axe.set_rgrids(list(range(rangecircle, maxrange, rangecircle))+[maxrange],
                   visible=True, color='#b0b0b0')
    pcm = axe.pcolormesh(azimuth, Range, data, norm=norm, cmap=cmap)
    cb = fig.colorbar(pcm, ax=axe)
    cb.set_label(cblabel)
    axe.set_rmax(maxrange)
    axe.grid(True)
    axe.set_facecolor('#000000')

    axe.set_title('time: '+stamp.strftime('%Y-%m-%d %H:%M:%S')+' [UTC]\n' +
                  'elevation: '+str(np.around(elevation[sweep], decimals=1))+' [deg]')
    fig.tight_layout()
    if figpath is not None:
        figname = '{}.{}.{}.{}.{}.{}'.format(
            'NUIST', stamp.strftime('%Y%m%d%H%M%S'), sweep, moment,
            radar.version, 'png')
        plt.savefig(os.path.join(figpath, figname), dpi=dpi)
    else:
        plt.show()

    plt.close('all')

def ray_sort(radar, raynum=660):
    """Resample every sweep of *radar* onto a regular azimuth grid, in place.

    For each sweep the target grid is ``raynum`` azimuths evenly spaced over
    [0, 360).  Every target azimuth takes the ray whose measured azimuth is
    nearest, treating azimuth as circular.  ``radar.azimuth`` is replaced by
    the grid, ``radar.elevation`` and the moment arrays are re-indexed, and
    ``radar.version`` is set to 1.

    Parameters
    ----------
    radar : RADAR object, modified in place
    raynum : number of rays per sweep after sorting (default 660)

    Returns
    -------
    None

    Notes
    -----
    Every non-empty array stored under a ``(sweep, moment)`` key is
    re-indexed, so all of them stay aligned with the new azimuth grid.
    Empty arrays (moments absent from the volume) are left untouched.
    """
    for sweep in range(1, radar.sweepnum + 1):
        azimuth = radar.azimuth[sweep]
        index_min = np.argmin(azimuth)
        index_max = np.argmax(azimuth)

        # Relation between 0 and 360: repeat the largest azimuth one turn
        # lower and the smallest azimuth one turn higher, so that targets
        # near either end of the grid can pick the ray across the seam.
        azimuth_raw = np.hstack(
            (azimuth, azimuth[index_max] - 360, azimuth[index_min] + 360))
        index_raw = np.hstack(
            (np.arange(len(azimuth)), index_max, index_min))

        # Built per sweep so that sweeps do not share one array object.
        azimuth_interp = np.linspace(0, 360, raynum + 1, dtype='float32')[:-1]
        index_interp = _interp(azimuth_raw, index_raw, azimuth_interp)

        # Keys are collected first; the dictionary is then updated key by key.
        sweep_keys = [key for key in radar.data if key[0] == sweep]
        for key in sweep_keys:
            if np.size(radar.data[key]) > 0:
                radar.data[key] = radar.data[key][index_interp]

        radar.azimuth[sweep] = azimuth_interp
        radar.elevation[sweep] = radar.elevation[sweep][index_interp]

    radar.version = 1  # after ray sorting is 1
    
if __name__ == '__main__':
    # Example run: decode a raw volume, sort its rays, export it to NetCDF
    # and print the elapsed wall-clock time in seconds.
    import time
    start = time.time()

    # --- read from raw data ---
    filepath = r'/Users/zhangshuai/Desktop/data/NUIST/NUIST.20170524.234622.AR2.bz2'
    radar = read(filepath)

    # --- interactive with nc ---
    # (call radar2nc(radar, filepath) before ray_sort() to export the
    #  unsorted, version 0 volume instead)
    filepath = r'/Users/zhangshuai/Desktop'
    ray_sort(radar)
    radar2nc(radar, filepath)

    # To reload an exported volume:
    #    filepath = r'/Users/zhangshuai/Desktop/NUIST.20170523104945.1.nc'
    #    radar = nc2radar(filepath)

    # --- plot data ---
    # Presets previously used with visualization():
    #
    #    sweep  moment   norm       cmap         cblabel
    #    3      'Z'      (-10,60)   'gist_ncar'  'Z [dBZ]'
    #    1      'V'      (-15,15)   'seismic'    'V [m/s]'
    #    1      'W'      (0,10)     'gist_ncar'  'W [m/s]'
    #    3      'ZDR'    (-2,8)     'gist_ncar'  'ZDR [dB]'
    #    1      'PhiDP'  (0,360)    'gist_ncar'  'PhiDP [deg]'
    #    3      'CC'     (.8,1)     'gist_ncar'  'CC'
    #    1      'SQI'    (0,1)      'gist_ncar'  'SQI'
    #    3      'SNR'    (0,50)     'gist_ncar'  'SNR [dB]'
    #
    # combined with:
    #
    #    resolution = .075
    #    rangecircle = 50
    #    maxrange = 150
    #    dpi = 300
    #    figpath = r'/Users/zhangshuai/Desktop'
    #    info = sweep, moment, norm, cmap, cblabel, resolution, rangecircle, maxrange, dpi, figpath
    #    visualization(radar, info)

    end = time.time()
    print(end - start)