"""This module contains all the functions needed for data analysis """

# Initial settings
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import gridspec
import pdb
import ee

# other modules
from osgeo import gdal, ogr, osr
import scipy.interpolate as interpolate
import scipy.stats as sstats

# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology

# machine learning modules
from sklearn.cluster import KMeans
from sklearn.neural_network import MLPClassifier
try:
    from sklearn.externals import joblib
except ImportError:
    # newer scikit-learn releases no longer ship sklearn.externals.joblib;
    # fall back to the standalone package that scikit-learn depends on
    import joblib

import time

# import own modules
import functions.utils as utils


def get_tide(dates_sds, dates_tide, tide_level, max_gap_seconds=1800):
    """
    Get the tide level closest in time to each shoreline date

    Arguments:
    -----------
        dates_sds: list
            datetimes at which a tide level is wanted
        dates_tide: list
            datetimes of the tide records
        tide_level: list or np.ndarray
            tide level for each entry of dates_tide
        max_gap_seconds: float
            largest accepted time gap between a shoreline date and its
            closest tide record (default 1800 s, i.e. half an hour)

    Returns:
    -----------
        tide: np.ndarray
            one value per entry of dates_sds, NaN where no tide record lies
            within max_gap_seconds
    """
    # without any tide record every shoreline date is unmatched
    if len(dates_tide) == 0:
        return np.full(len(dates_sds), np.nan)

    tide = []
    for date_sds in dates_sds:
        dates_diff = np.abs(np.array([(date_sds - date_tide).total_seconds()
                                      for date_tide in dates_tide]))
        idx_closest = np.argmin(dates_diff)
        if dates_diff[idx_closest] <= max_gap_seconds:
            tide.append(tide_level[idx_closest])
        else:
            tide.append(np.nan)

    return np.array(tide)


def remove_duplicates(output, satname):
    """
    Remove images acquired on the same day from the output structure

    Among images sharing a date, one is kept: the first one wins against the
    next if its cloud cover is lower by more than 0.01 or, for 'L8' and 'L5'
    only, if its 'acc_georef' value is lower by more than 0.1; otherwise the
    next one wins. The rule is applied in turn to every image of the day, so
    dates with more than two images are reduced to a single image too.

    Arguments:
    -----------
        output: dict
            contains 'dates', 'shorelines' and 'metadata' (a dict of lists
            with at least 'cloud_cover', plus 'acc_georef' for 'L8'/'L5')
        satname: str
            name of the satellite, also used in the printed summary

    Returns:
    -----------
        output_nodup: dict
            same structure without the discarded images; the input dict
            itself is returned when there are no duplicates
    """
    dates_str = [date.strftime('%Y%m%d') for date in output['dates']]
    dupl = utils.duplicates_dict(dates_str)

    if not dupl:
        print(satname + ' : ' + 'no duplicates')
        return output

    cloud_cover = output['metadata']['cloud_cover']
    # georeferencing accuracy is only used as a criterion for Landsat 5 / 8
    use_georef = satname in ('L8', 'L5')
    acc_georef = output['metadata']['acc_georef'] if use_georef else None

    idx_remove = []
    for indices in dupl.values():
        idx_best = indices[0]
        for idx_other in indices[1:]:
            keep_best = cloud_cover[idx_best] < cloud_cover[idx_other] - 0.01
            if not keep_best and use_georef:
                keep_best = acc_georef[idx_best] < acc_georef[idx_other] - 0.1
            if keep_best:
                idx_remove.append(idx_other)
            else:
                idx_remove.append(idx_best)
                idx_best = idx_other
    idx_remove = sorted(idx_remove)

    discarded = set(idx_remove)
    idx_keep = [k for k in range(len(dates_str)) if k not in discarded]

    output_nodup = dict([])
    output_nodup['dates'] = [output['dates'][k] for k in idx_keep]
    output_nodup['shorelines'] = [output['shorelines'][k] for k in idx_keep]
    output_nodup['metadata'] = dict([])
    for key, values in output['metadata'].items():
        output_nodup['metadata'][key] = [values[k] for k in idx_keep]

    print(satname + ' : ' + str(len(idx_remove)) + ' duplicates')
    return output_nodup


def merge(output):
    """
    Merge the data from the different satellites into one date-sorted structure

    Arguments:
    -----------
        output: dict
            one entry per satellite, each a dict with 'dates', 'shorelines'
            and 'metadata' (a dict of lists whose keys must be among
            'filenames', 'satname', 'cloud_cover' and 'acc_georef')

    Returns:
    -----------
        output_all_sorted: dict
            same structure with the lists of all satellites stacked together
            and ordered by increasing date
    """
    metadata_keys = ('filenames', 'satname', 'cloud_cover', 'acc_georef')

    # stack all lists together under one key
    output_all = {'dates': [], 'shorelines': [],
                  'metadata': {key: [] for key in metadata_keys}}
    for sat_output in output.values():
        output_all['dates'].extend(sat_output['dates'])
        output_all['shorelines'].extend(sat_output['shorelines'])
        for key, values in sat_output['metadata'].items():
            output_all['metadata'][key].extend(values)

    # sort everything by date
    dates = output_all['dates']
    idx_sorted = sorted(range(len(dates)), key=dates.__getitem__)

    output_all_sorted = {
        'dates': [dates[i] for i in idx_sorted],
        'shorelines': [output_all['shorelines'][i] for i in idx_sorted],
        'metadata': {key: [values[i] for i in idx_sorted]
                     for key, values in output_all['metadata'].items()},
    }
    return output_all_sorted


def create_transects(x0, y0, orientation, chainage_length):
    """
    Create straight transects from their origin and orientation

    Arguments:
    -----------
        x0, y0: list or np.ndarray
            coordinates of the origin of each transect
        orientation: list or np.ndarray
            orientation of each transect in degrees; the transect is rotated
            by (90 - orientation) degrees from the x axis
        chainage_length: int
            length of the transects; one point is created per unit of length
            (a float is truncated to an integer number of points)

    Returns:
    -----------
        transects: list of np.ndarray
            one (chainage_length + 1, 2) array of coordinates per transect
    """
    # un-rotated vector along the x axis, identical for every transect
    n_points = int(chainage_length) + 1
    coords = np.zeros((n_points, 2))
    coords[:, 0] = np.linspace(0, chainage_length, n_points)

    transects = []
    for k in range(len(x0)):
        # orientation of cross-shore profile
        phi = (90 - orientation[k])*np.pi/180
        # translate and rotate the vector using the origin and orientation
        tf = transform.EuclideanTransform(rotation=phi, translation=(x0[k], y0[k]))
        transects.append(tf(coords))

    return transects


def calculate_chainage(sds, transects, orientation, along_dist, max_std=10):
    """
    Intersect the shorelines with the transects and compute chainage statistics

    For each shoreline and transect, the shoreline points lying within
    along_dist of the (infinite) line through the transect are rotated into
    the transect coordinate system; points with a negative chainage (behind
    the transect origin) are discarded from the statistics.

    Arguments:
    -----------
        sds: list of np.ndarray
            shorelines, each an (n, 2) array of coordinates
        transects: list of np.ndarray
            transects, each an (m, 2) array whose first row is the origin and
            last row the end point
        orientation: list or np.ndarray
            orientation of each transect in degrees (same convention as in
            create_transects)
        along_dist: float
            maximum distance from the transect for a point to be used
        max_std: float
            when the standard deviation of the chainage exceeds this value,
            'mean', 'median' and 'min' are replaced by the maximum chainage
            (default 10)

    Returns:
    -----------
        chainage: dict
            'mean', 'median', 'max', 'min', 'npoints' and 'std', each a
            (len(sds), len(transects)) array; all entries are NaN where no
            shoreline point is close to the transect. 'npoints' counts every
            point close to the transect, including those with a negative
            chainage.
    """
    chainage_mtx = np.full((len(sds), len(transects), 6), np.nan)

    # the transect geometry does not depend on the shoreline: compute it once
    geometry = []
    for j, transect in enumerate(transects):
        origin = np.array([transect[0, 0], transect[0, 1]])
        direction = transect[-1, :] - origin
        phi = (90 - orientation[j])*np.pi/180
        Mrot = np.array([[np.cos(phi), np.sin(phi)], [-np.sin(phi), np.cos(phi)]])
        geometry.append((origin, direction, np.linalg.norm(direction), Mrot))

    for i, sl in enumerate(sds):
        for j, (origin, direction, length, Mrot) in enumerate(geometry):
            # point to line distance between shoreline points and profile
            # (explicit 2D cross product)
            rel = sl - origin
            d = np.abs((direction[0]*rel[:, 1] - direction[1]*rel[:, 0])/length)
            idx_close = np.flatnonzero(d <= along_dist)

            # no shoreline point around the profile: the row stays NaN
            if idx_close.size == 0:
                continue

            # change of base to shore-normal coordinate system
            xy_rot = np.matmul(Mrot, rel[idx_close].T)
            cross = xy_rot[0, :]
            # discard negative chainages (MAKE SURE TO PICK ORIGIN CORRECTLY)
            cross[cross < 0] = np.nan
            n_points = len(cross)

            # every close point is behind the origin: only the count is known
            if np.all(np.isnan(cross)):
                chainage_mtx[i, j, 4] = n_points
                continue

            # compute mean, median, max, min and std of chainage position
            mean_cross = np.nanmean(cross)
            median_cross = np.nanmedian(cross)
            max_cross = np.nanmax(cross)
            min_cross = np.nanmin(cross)
            std_cross = np.nanstd(cross)

            # if large std, take the most seaward point
            if std_cross > max_std:
                mean_cross = max_cross
                median_cross = max_cross
                min_cross = max_cross

            # store the statistics
            chainage_mtx[i, j, :] = [mean_cross, median_cross, max_cross,
                                     min_cross, n_points, std_cross]

    # format into dictionary
    chainage = dict([])
    chainage['mean'] = chainage_mtx[:, :, 0]
    chainage['median'] = chainage_mtx[:, :, 1]
    chainage['max'] = chainage_mtx[:, :, 2]
    chainage['min'] = chainage_mtx[:, :, 3]
    chainage['npoints'] = chainage_mtx[:, :, 4]
    chainage['std'] = chainage_mtx[:, :, 5]

    return chainage


def _survey_interpolated(satdate, dates_sur_num, chain_sur):
    """Interpolate the survey chainage at satdate; return (chainage, days since previous survey)."""
    temp_diff = satdate - dates_sur_num
    idx_before = np.flatnonzero(temp_diff > 0)
    idx_after = np.flatnonzero(temp_diff < 0)

    # no survey strictly before the image: nothing to interpolate from
    if idx_before.size == 0:
        return np.nan, np.nan
    ind_before = idx_before[-1]
    days_since_before = float(temp_diff[ind_before])

    # no survey strictly after the image: extrapolation is not attempted
    if idx_after.size == 0:
        return np.nan, days_since_before
    ind_after = idx_after[0]

    tempx = [dates_sur_num[ind_before], dates_sur_num[ind_after]]
    tempy = [chain_sur[ind_before], chain_sur[ind_after]]
    return float(np.interp(satdate, tempx, tempy)), days_since_before


def _survey_nearest(satdate, dates_sur_num, chain_sur, mindays):
    """Pick the survey closest to satdate; return (chainage or NaN if too far, gap in days)."""
    abs_diff = np.abs(satdate - dates_sur_num)
    if abs_diff.size == 0:
        return np.nan, np.nan
    idx_closest = int(np.argmin(abs_diff))
    gap = abs_diff[idx_closest]
    if gap > mindays:
        return np.nan, gap
    return chain_sur[idx_closest], gap


def _comparison_stats(chain_sur, chain_sds):
    """Compute error and linear-fit statistics between survey and satellite chainages."""
    slope, intercept, rvalue, pvalue, _ = sstats.linregress(chain_sur, chain_sds)
    diff_chain = chain_sur - chain_sds
    return {
        'rmse': np.sqrt(np.nanmean(diff_chain**2)),
        'mean': np.nanmean(diff_chain),
        'std': np.nanstd(diff_chain),
        'q90': np.percentile(np.abs(diff_chain), 90),
        'corr': np.corrcoef(chain_sur, chain_sds)[0, 1],
        'linfit': {'slope': slope, 'intercept': intercept,
                   'R2': rvalue**2, 'pvalue': pvalue},
    }


def _plot_scatter(ax, chain_sur, chain_sds, linfit, title):
    """Draw satellite vs survey chainage with the 1:1 line and the linear fit."""
    slope = linfit['slope']
    intercept = linfit['intercept']
    ax.axis('equal')
    ax.plot(chain_sur, chain_sds, 'ko', markersize=4, markerfacecolor='w', alpha=0.7)
    vmax = np.max([np.nanmax(chain_sds), np.nanmax(chain_sur)])
    vmin = np.min([np.nanmin(chain_sds), np.nanmin(chain_sur)])
    ax.plot([vmin, vmax], [vmin, vmax], 'k--')
    ax.plot([vmin, vmax], [vmin*slope + intercept, vmax*slope + intercept], 'r:')
    str_corr = ' y = %.2f x + %.2f\n R2 = %.2f\n n = %d' % (slope, intercept,
                                                           linfit['R2'], len(chain_sur))
    ax.text(vmin, 0.9*vmax, str_corr, bbox=dict(facecolor=[0.7, 0.7, 0.7], alpha=0.5),
            horizontalalignment='left')
    ax.set_xlabel('chainage survey [m]')
    ax.set_ylabel('chainage satellite [m]')
    ax.set_title(title, fontweight='bold')


def _plot_error_hist(ax, diff_chain, summary, ylabel=None, binwidth=3):
    """Draw the normalised histogram of the chainage errors with the summary statistics."""
    bins = np.arange(min(diff_chain), max(diff_chain) + binwidth, binwidth)
    density = ax.hist(diff_chain, bins=bins, density=True, color=[0.8, 0.8, 0.8],
                      edgecolor='k')
    ax.set_xlim([-50, 50])
    ax.set_xlabel('error [m]')
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    str_stats = ' rmse = %.1f\n mean = %.1f\n std = %.1f\n q90 = %.1f' % (
        summary['rmse'], summary['mean'], summary['std'], summary['q90'])
    ax.text(15, np.max(density[0])-0.015, str_stats,
            bbox=dict(facecolor=[0.8, 0.8, 0.8], alpha=0.3),
            horizontalalignment='left', fontsize=10)


def compare_sds(dates_sds, chain_sds, topo_profiles, mod=0, mindays=5):
    """
    Compare sds with groundtruth data from topographic surveys / argus shorelines

    Three figures are created: the time-series of each transect, a scatter
    plot and error histogram per transect, and a scatter plot and error
    histogram for all transects together.

    KV WRL 2018

    Arguments:
    -----------
        dates_sds: list
            list of dates corresponding to each row in chain_sds
        chain_sds: np.ndarray
            array with time series of chainage for each transect (each transect is one column)
        topo_profiles: dict
            dict containing the dates and chainage of the groundtruth; the
            i-th key is compared with the i-th column of chain_sds
        mod: 0 or 1
            0 for linear interpolation between 2 closest surveys (survey
            dates are expected in increasing order), 1 for only nearest
            neighbour
        mindays: int
            maximum number of days between an image and the nearest survey
            for the data to be compared (only used when mod is 1)

    Returns:
    -----------
        stats: dict
            contains all the statistics of the comparison, one entry per
            transect plus 'all' for the transects together. The per-transect
            'diffdays' is the number of days since the previous survey when
            mod is 0 (NaN if there is none) and the gap to the nearest survey
            when mod is 1.

    Raises:
    -----------
        ValueError
            if mod is neither 0 nor 1
    """
    if mod not in (0, 1):
        raise ValueError('mod must be 0 (linear interpolation) or 1 (nearest neighbour)')

    n_transects = chain_sds.shape[1]
    pfnames = list(topo_profiles.keys())

    # create 3 figures
    fig1 = plt.figure()
    gs1 = gridspec.GridSpec(n_transects, 1)
    fig2 = plt.figure()
    gs2 = gridspec.GridSpec(2, n_transects)
    fig3 = plt.figure()
    gs3 = gridspec.GridSpec(2, 1)

    dates_sds_num = np.array([date.toordinal() for date in dates_sds])
    stats = dict([])
    data_fin = dict([])

    # for each transect compare and plot the data
    for i in range(n_transects):
        pfname = pfnames[i]
        dates_sur = topo_profiles[pfname]['dates']
        chain_sur = topo_profiles[pfname]['chainage']
        # convert to datenum
        dates_sur_num = np.array([date.toordinal() for date in dates_sur])

        # groundtruth chainage at the date of each satellite image
        chain_sur_interp = []
        diff_days = []
        for satdate in dates_sds_num:
            if mod == 0:
                chain_value, gap = _survey_interpolated(satdate, dates_sur_num, chain_sur)
            else:
                chain_value, gap = _survey_nearest(satdate, dates_sur_num, chain_sur, mindays)
            chain_sur_interp.append(chain_value)
            diff_days.append(gap)
        chain_sur_interp = np.array(chain_sur_interp, dtype=float)

        # keep only the dates where both groundtruth and sds are available
        idx_valid = np.logical_and(~np.isnan(chain_sur_interp), ~np.isnan(chain_sds[:, i]))
        chain_sur_fin = chain_sur_interp[idx_valid]
        chain_sds_fin = chain_sds[idx_valid, i]
        dates_fin = [date for (date, valid) in zip(dates_sds, idx_valid) if valid]

        # calculate statistics
        summary = _comparison_stats(chain_sur_fin, chain_sds_fin)
        diff_chain = chain_sur_fin - chain_sds_fin

        # store data
        stats[pfname] = {'rmse': summary['rmse'], 'mean': summary['mean'],
                         'std': summary['std'], 'q90': summary['q90'],
                         'diffdays': diff_days, 'corr': summary['corr'],
                         'linfit': summary['linfit']}
        data_fin[pfname] = {'dates': dates_fin, 'sds': chain_sds_fin,
                            'survey': chain_sur_fin}

        # make time-series plot
        plt.figure(fig1.number)
        ax = fig1.add_subplot(gs1[i, 0])
        ax.plot(dates_sur, chain_sur, '-', color='C1', markersize=2, label='survey data')
        ax.plot(dates_fin, chain_sds_fin, 'o--', color='C0', markersize=4, alpha=1,
                label='satellite data')
        strtitle = '%s (correlation = %.2f)' % (pfname, summary['corr'])
        ax.set_title(strtitle, fontweight='bold')
        ax.set_xlim([dates_sds[0], dates_sds[-1]])
        ax.set_ylabel('cross-shore position [m]')
        ax.legend()

        # make scatter plot and error histogram
        plt.figure(fig2.number)
        _plot_scatter(fig2.add_subplot(gs2[0, i]), chain_sur_fin, chain_sds_fin,
                      summary['linfit'], pfname)
        _plot_error_hist(fig2.add_subplot(gs2[1, i]), diff_chain, summary)

    fig1.set_size_inches(19.2, 9.28)
    fig1.set_tight_layout(True)
    fig2.set_size_inches(19.2, 9.28)
    fig2.set_tight_layout(True)

    # all transects together
    used_names = pfnames[:n_transects]
    chain_sds_all = np.concatenate([np.empty(0)] + [data_fin[name]['sds'] for name in used_names])
    chain_sur_all = np.concatenate([np.empty(0)] + [data_fin[name]['survey'] for name in used_names])

    # calculate statistics
    summary_all = _comparison_stats(chain_sur_all, chain_sds_all)
    diff_chain_all = chain_sur_all - chain_sds_all
    stats['all'] = {'rmse': summary_all['rmse'], 'mean': summary_all['mean'],
                    'std': summary_all['std'], 'q90': summary_all['q90'],
                    'corr': summary_all['corr'], 'linfit': summary_all['linfit']}

    # make plot (titled for the pooled data rather than for the last transect)
    plt.figure(fig3.number)
    _plot_scatter(fig3.add_subplot(gs3[0, 0]), chain_sur_all, chain_sds_all,
                  summary_all['linfit'], 'all transects')
    _plot_error_hist(fig3.add_subplot(gs3[1, 0]), diff_chain_all, summary_all, ylabel='pdf')
    fig3.set_size_inches(9.2, 9.28)
    fig3.set_tight_layout(True)

    return stats