#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 20 21:54:02 2024

@author: zhangshuai
"""

import pyart
import numpy as np
from joblib import load

import read_sea_flag
import limit_region
import calc_feature

def mainlobe(radar, sea_flag_radar, clf):
    """Return the gates inside the search region that ``clf`` labels as 1.

    Parameters
    ----------
    radar
        Radar object providing ``fields`` (with 'reflectivity',
        'spectrum_width', 'differential_reflectivity' and
        'cross_correlation_ratio' entries), ``nrays``, ``ngates`` and
        ``get_gate_lat_lon_alt``.
    sea_flag_radar
        Passed unchanged to ``limit_region.limit_region``.
    clf
        Classifier whose ``predict`` receives an ``(n, 6)`` feature array
        and returns one label per row.

    Returns
    -------
    numpy.ndarray or list
        Array with one row per gate predicted as 1 and four columns:
        ray index, gate index, latitude, longitude. An empty list is
        returned when the search region selects no gate.
    """
    # Restrict identification to the candidate region.
    id_region = limit_region.limit_region(radar, sea_flag_radar)

    # (ray, gate) index pairs of the selected gates, in row-major order --
    # the same order a nested ray/gate scan would visit them. Masked
    # entries, if any, count as "not selected".
    region = np.ma.filled(id_region, False)
    candidates = np.argwhere(region[:radar.nrays, :radar.ngates])
    total = len(candidates)
    if not total:
        return []

    fields = radar.fields
    z = fields['reflectivity']['data']
    w = fields['spectrum_width']['data']
    zdr = fields['differential_reflectivity']['data']
    cc = fields['cross_correlation_ratio']['data']

    # One (function, input field) pair per feature column.
    extractors = (
        (calc_feature.calc_rgf, z),
        (calc_feature.calc_rd, z),
        (calc_feature.calc_swr, w),
        (calc_feature.calc_w, w),
        (calc_feature.calc_zdr, zdr),
        (calc_feature.calc_ccd, cc),
    )

    # Compute the features gate by gate.
    feature = np.empty((total, len(extractors)))
    for row, (i, j) in enumerate(candidates.tolist()):
        for col, (extract, field) in enumerate(extractors):
            feature[row, col] = extract(field, i, j)

    # 2 (radar coordinates) + 2 (lat/lon) + 1 (classification result)
    pos = np.empty((total, 5))
    pos[:, :2] = candidates

    # Geolocate every candidate gate.
    # NOTE(review): only the coordinates of sweep 0 are looked up, so ray
    # indices outside sweep 0 would be out of range -- confirm that
    # limit_region never selects rays of other sweeps.
    lat, lon, _ = radar.get_gate_lat_lon_alt(0)
    rays, gates = candidates[:, 0], candidates[:, 1]
    pos[:, 2] = lat[rays, gates]
    pos[:, 3] = lon[rays, gates]

    # Truncate, then min-max normalise. Column 0 is used as-is.
    bounds = {1: (0, 50), 2: (0, 40), 3: (0, 3), 4: (-8, 8), 5: (-.1, 1)}
    for col in (1, 2, 3):
        feature[:, col] = np.clip(feature[:, col], None, bounds[col][1])
    # Column 5 is clipped from below. The previous code passed the column's
    # maximum (1) as the lower bound, which forced every value below 1 up
    # to 1 and made the normalised feature constant.
    # NOTE(review): confirm the classifier was trained with this lower
    # bound (-0.1).
    feature[:, 5] = np.clip(feature[:, 5], bounds[5][0], None)
    for col, (lower, upper) in bounds.items():
        feature[:, col] = (feature[:, col] - lower) / (upper - lower)

    # Classify and keep only the rows labelled 1 (without the label column).
    pos[:, -1] = clf.predict(feature)
    return pos[pos[:, -1] == 1, :4]

def sidelobes(radar, pos, v_thr=1.5, snr_thr=10):
    """Flag the gates around each target that match the side-lobe model.

    For every target in ``pos`` a window the size of ``snr_model`` is
    centred on the target gate. A gate in that window is flagged when all
    three conditions hold:

    * target SNR + model value is at least 1.5,
    * its velocity differs from the target's by at most ``v_thr`` or its
      velocity is missing,
    * its SNR does not exceed target SNR + model value + ``snr_thr``.

    Windows reaching past the first/last ray or the last gate are cropped
    to the data array (the model is cropped identically). Targets whose
    window would start before gate 0 are skipped.

    Parameters
    ----------
    radar
        Object whose ``fields`` holds 2-D 'reflectivity' and 'velocity'
        entries, each with a 'data' array.
    pos
        Sequence of rows whose first two items are the ray and gate index
        of a target; may be empty.
    v_thr
        Largest absolute velocity difference to the target gate.
    snr_thr
        Margin added to the modelled SNR in the upper-bound test.

    Returns
    -------
    numpy.ndarray
        Boolean array shaped like the reflectivity data, True where a gate
        was flagged.

    Notes
    -----
    ``snr_model.npy`` and ``sensitivity_lfm.npy`` are read from the
    current working directory on every call. The window arithmetic
    assumes ``snr_model`` has odd dimensions.
    """
    snr_model = np.load(r'snr_model.npy')
    sens = np.load(r'sensitivity_lfm.npy')
    reflectivity = radar.fields['reflectivity']['data']
    sidelobe_flag = np.zeros(reflectivity.shape, dtype='bool')

    if not len(pos):
        return sidelobe_flag

    n_rays, n_gates = sidelobe_flag.shape
    ant_radius = snr_model.shape[0] // 2
    rg_radius = snr_model.shape[1] // 2

    # Work on plain arrays with NaN in place of masked gates, so that every
    # comparison involving a missing value evaluates to False.
    # NOTE(review): 1.5 is used both as this offset and as the
    # potential-area threshold below; its meaning is not visible here.
    snr = np.ma.filled(reflectivity - sens + 1.5, np.nan)
    v = np.ma.filled(radar.fields['velocity']['data'], np.nan)

    for target in pos:
        x, y = int(target[0]), int(target[1])
        up = x - ant_radius
        left = y - rg_radius

        if left < 0:  # too close to the radar: give up on this target
            continue

        # Clip the window to the data array. Without this a window that
        # pokes past an edge yields a slice smaller than the model (or an
        # empty one for a negative start) and the masks no longer match.
        row_lo = max(up, 0)
        row_hi = min(x + ant_radius + 1, n_rays)
        col_hi = min(y + rg_radius + 1, n_gates)
        window = (slice(row_lo, row_hi), slice(left, col_hi))
        model = snr_model[row_lo - up:row_hi - up, :col_hi - left]

        modelled_snr = snr[x, y] + model
        v_window = v[window]

        potential_area = modelled_snr >= 1.5
        # Velocity coverage is small, so gates without velocity are
        # filtered as well.
        v_criterion = (np.fabs(v_window - v[x, y]) <= v_thr) | np.isnan(v_window)
        snr_criterion = snr[window] <= modelled_snr + snr_thr

        sidelobe_flag[window][potential_area & v_criterion & snr_criterion] = True

    return sidelobe_flag
    
if __name__ == '__main__':
    # Example run: read one SIGMET file, locate the main-lobe gates with the
    # stored classifier, then mask the side-lobe gates in the reflectivity.
    read_path = r'KU3200501190540.RAW4F83'
    radar = pyart.io.read_sigmet(read_path)
    sea_flag_cartesian, x, y = read_sea_flag.read()
    sea_flag_radar = read_sea_flag.cartesian2radar(radar, sea_flag_cartesian, x, y)
    # NOTE: joblib.load unpickles the file; only load model files you trust.
    clf = load(r'rf.joblib')
    pos = mainlobe(radar, sea_flag_radar, clf)

    sidelobe_flag = sidelobes(radar, pos)
    # Mask through the array rather than writing into ``.mask``: the latter
    # fails when the field currently has no mask array (``nomask``).
    reflectivity = radar.fields['reflectivity']['data']
    reflectivity[sidelobe_flag] = np.ma.masked