forked from kilianv/CoastSat_WRL
Supprimer 'functions/data_analysis.py'
parent
c32c0da22b
commit
f9e8468015
@ -1,446 +0,0 @@
|
||||
"""This module contains all the functions needed for data analysis """
|
||||
|
||||
# Initial settings
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib import gridspec
|
||||
import pdb
|
||||
import ee
|
||||
|
||||
# other modules
|
||||
from osgeo import gdal, ogr, osr
|
||||
import scipy.interpolate as interpolate
|
||||
import scipy.stats as sstats
|
||||
|
||||
# image processing modules
|
||||
import skimage.filters as filters
|
||||
import skimage.exposure as exposure
|
||||
import skimage.transform as transform
|
||||
import sklearn.decomposition as decomposition
|
||||
import skimage.measure as measure
|
||||
import skimage.morphology as morphology
|
||||
|
||||
# machine learning modules
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.externals import joblib
|
||||
|
||||
import time
|
||||
|
||||
# import own modules
|
||||
import functions.utils as utils
|
||||
|
||||
def get_tide(dates_sds, dates_tide, tide_level):
|
||||
|
||||
tide = []
|
||||
for i in range(len(dates_sds)):
|
||||
dates_diff = np.abs(np.array([ (dates_sds[i] - _).total_seconds() for _ in dates_tide]))
|
||||
if np.min(dates_diff) <= 1800: # half-an-hour
|
||||
idx_closest = np.argmin(dates_diff)
|
||||
tide.append(tide_level[idx_closest])
|
||||
else:
|
||||
tide.append(np.nan)
|
||||
tide = np.array(tide)
|
||||
|
||||
return tide
|
||||
|
||||
def remove_duplicates(output, satname):
|
||||
" removes duplicates from output structure, keep the one with less cloud cover or best georeferencing "
|
||||
dates = output['dates']
|
||||
dates_str = [_.strftime('%Y%m%d') for _ in dates]
|
||||
dupl = utils.duplicates_dict(dates_str)
|
||||
if dupl:
|
||||
output_nodup = dict([])
|
||||
idx_remove = []
|
||||
if satname == 'L8' or satname == 'L5':
|
||||
for k,v in dupl.items():
|
||||
|
||||
idx1 = v[0]
|
||||
idx2 = v[1]
|
||||
|
||||
c1 = output['metadata']['cloud_cover'][idx1]
|
||||
c2 = output['metadata']['cloud_cover'][idx2]
|
||||
g1 = output['metadata']['acc_georef'][idx1]
|
||||
g2 = output['metadata']['acc_georef'][idx2]
|
||||
|
||||
if c1 < c2 - 0.01:
|
||||
idx_remove.append(idx2)
|
||||
elif g1 < g2 - 0.1:
|
||||
idx_remove.append(idx2)
|
||||
else:
|
||||
idx_remove.append(idx1)
|
||||
|
||||
else:
|
||||
for k,v in dupl.items():
|
||||
|
||||
idx1 = v[0]
|
||||
idx2 = v[1]
|
||||
|
||||
c1 = output['metadata']['cloud_cover'][idx1]
|
||||
c2 = output['metadata']['cloud_cover'][idx2]
|
||||
|
||||
if c1 < c2 - 0.01:
|
||||
idx_remove.append(idx2)
|
||||
else:
|
||||
idx_remove.append(idx1)
|
||||
|
||||
idx_remove = sorted(idx_remove)
|
||||
idx_all = np.linspace(0, len(dates_str)-1, len(dates_str))
|
||||
idx_keep = list(np.where(~np.isin(idx_all,idx_remove))[0])
|
||||
|
||||
output_nodup['dates'] = [output['dates'][k] for k in idx_keep]
|
||||
output_nodup['shorelines'] = [output['shorelines'][k] for k in idx_keep]
|
||||
output_nodup['metadata'] = dict([])
|
||||
for key in list(output['metadata'].keys()):
|
||||
output_nodup['metadata'][key] = [output['metadata'][key][k] for k in idx_keep]
|
||||
print(satname + ' : ' + str(len(idx_remove)) + ' duplicates')
|
||||
return output_nodup
|
||||
|
||||
else:
|
||||
print(satname + ' : ' + 'no duplicates')
|
||||
return output
|
||||
|
||||
|
||||
def merge(output):
|
||||
" merges data from the different satellites "
|
||||
|
||||
# stack all list together under one key
|
||||
output_all = {'dates':[], 'shorelines':[],
|
||||
'metadata':{'filenames':[], 'satname':[], 'cloud_cover':[], 'acc_georef':[]}}
|
||||
for satname in list(output.keys()):
|
||||
output_all['dates'] = output_all['dates'] + output[satname]['dates']
|
||||
output_all['shorelines'] = output_all['shorelines'] + output[satname]['shorelines']
|
||||
for key in list(output[satname]['metadata'].keys()):
|
||||
output_all['metadata'][key] = output_all['metadata'][key] + output[satname]['metadata'][key]
|
||||
|
||||
output_all_sorted = {'dates':[], 'shorelines':[],
|
||||
'metadata':{'filenames':[], 'satname':[], 'cloud_cover':[], 'acc_georef':[]}}
|
||||
# sort the dates
|
||||
idx_sorted = sorted(range(len(output_all['dates'])), key=output_all['dates'].__getitem__)
|
||||
output_all_sorted['dates'] = [output_all['dates'][i] for i in idx_sorted]
|
||||
output_all_sorted['shorelines'] = [output_all['shorelines'][i] for i in idx_sorted]
|
||||
for key in list(output_all['metadata'].keys()):
|
||||
output_all_sorted['metadata'][key] = [output_all['metadata'][key][i] for i in idx_sorted]
|
||||
|
||||
return output_all_sorted
|
||||
|
||||
def create_transects(x0, y0, orientation, chainage_length):
|
||||
" creates shore-normal transects "
|
||||
|
||||
transects = []
|
||||
|
||||
for k in range(len(x0)):
|
||||
|
||||
# orientation of cross-shore profile
|
||||
phi = (90 - orientation[k])*np.pi/180
|
||||
|
||||
# create a vector using the chainage length
|
||||
x = np.linspace(0,chainage_length,chainage_length+1)
|
||||
y = np.zeros(len(x))
|
||||
coords = np.zeros((len(x),2))
|
||||
coords[:,0] = x
|
||||
coords[:,1] = y
|
||||
|
||||
# translate and rotate the vector using the origin and orientation
|
||||
tf = transform.EuclideanTransform(rotation=phi, translation=(x0[k],y0[k]))
|
||||
coords_tf = tf(coords)
|
||||
|
||||
transects.append(coords_tf)
|
||||
|
||||
return transects
|
||||
|
||||
def calculate_chainage(sds, transects, orientation, along_dist):
|
||||
" intersect SDS with transect and compute chainage position "
|
||||
|
||||
chainage_mtx = np.zeros((len(sds),len(transects),6))
|
||||
|
||||
for i in range(len(sds)):
|
||||
|
||||
sl = sds[i]
|
||||
|
||||
for j in range(len(transects)):
|
||||
|
||||
# compute rotation matrix
|
||||
X0 = transects[j][0,0]
|
||||
Y0 = transects[j][0,1]
|
||||
phi = (90 - orientation[j])*np.pi/180
|
||||
Mrot = np.array([[np.cos(phi), np.sin(phi)],[-np.sin(phi), np.cos(phi)]])
|
||||
|
||||
# calculate point to line distance between shoreline points and profile
|
||||
p1 = np.array([X0,Y0])
|
||||
p2 = transects[j][-1,:]
|
||||
p3 = sl
|
||||
d = np.abs(np.cross(p2-p1,p3-p1)/np.linalg.norm(p2-p1))
|
||||
idx_close = utils.find_indices(d, lambda e: e <= along_dist)
|
||||
|
||||
# check if there are SDS points around the profile or not
|
||||
if not idx_close:
|
||||
chainage_mtx[i,j,:] = np.tile(np.nan,(1,6))
|
||||
|
||||
else:
|
||||
# change of base to shore-normal coordinate system
|
||||
xy_close = np.array([sl[idx_close,0],sl[idx_close,1]]) - np.tile(np.array([[X0],[Y0]]), (1,len(sl[idx_close])))
|
||||
xy_rot = np.matmul(Mrot, xy_close)
|
||||
|
||||
# put nan values if the chainage is negative (MAKE SURE TO PICK ORIGIN CORRECTLY)
|
||||
if np.any(xy_rot[0,:] < 0):
|
||||
xy_rot[0,np.where(xy_rot[0,:] < 0)] = np.nan
|
||||
|
||||
# compute mean, median max and std of chainage position
|
||||
n_points = len(xy_rot[0,:])
|
||||
mean_cross = np.nanmean(xy_rot[0,:])
|
||||
median_cross = np.nanmedian(xy_rot[0,:])
|
||||
max_cross = np.nanmax(xy_rot[0,:])
|
||||
min_cross = np.nanmin(xy_rot[0,:])
|
||||
std_cross = np.nanstd(xy_rot[0,:])
|
||||
###################################################
|
||||
if std_cross > 10: # if large std, take the most seaward point
|
||||
mean_cross = max_cross
|
||||
median_cross = max_cross
|
||||
min_cross = max_cross
|
||||
# mean_cross = np.nan
|
||||
# median_cross = np.nan
|
||||
# min_cross = np.nan
|
||||
|
||||
# store the statistics
|
||||
chainage_mtx[i,j,:] = np.array([mean_cross, median_cross, max_cross,
|
||||
min_cross, n_points, std_cross])
|
||||
|
||||
# format into dictionnary
|
||||
chainage = dict([])
|
||||
chainage['mean'] = chainage_mtx[:,:,0]
|
||||
chainage['median'] = chainage_mtx[:,:,1]
|
||||
chainage['max'] = chainage_mtx[:,:,2]
|
||||
chainage['min'] = chainage_mtx[:,:,3]
|
||||
chainage['npoints'] = chainage_mtx[:,:,4]
|
||||
chainage['std'] = chainage_mtx[:,:,5]
|
||||
|
||||
return chainage
|
||||
|
||||
def compare_sds(dates_sds, chain_sds, topo_profiles, mod=0, mindays=5):
|
||||
"""
|
||||
Compare sds with groundtruth data from topographic surveys / argus shorelines
|
||||
|
||||
KV WRL 2018
|
||||
|
||||
Arguments:
|
||||
-----------
|
||||
dates_sds: list
|
||||
list of dates corresponding to each row in chain_sds
|
||||
chain_sds: np.ndarray
|
||||
array with time series of chainage for each transect (each transect is one column)
|
||||
topo_profiles: dict
|
||||
dict containing the dates and chainage of the groundtruth
|
||||
mod: 0 or 1
|
||||
0 for linear interpolation between 2 closest surveys, 1 for only nearest neighbour
|
||||
min_days: int
|
||||
minimum number of days for which the data can be compared
|
||||
|
||||
Returns: -----------
|
||||
stats: dict
|
||||
contains all the statistics of the comparison
|
||||
|
||||
"""
|
||||
|
||||
# create 3 figures
|
||||
fig1 = plt.figure()
|
||||
gs1 = gridspec.GridSpec(chain_sds.shape[1], 1)
|
||||
axfig1 = []
|
||||
fig2 = plt.figure()
|
||||
gs2 = gridspec.GridSpec(2, chain_sds.shape[1])
|
||||
axfig2 = []
|
||||
fig3 = plt.figure()
|
||||
gs3 = gridspec.GridSpec(2,1)
|
||||
axfig3 = []
|
||||
|
||||
dates_sds_num = np.array([_.toordinal() for _ in dates_sds])
|
||||
stats = dict([])
|
||||
data_fin = dict([])
|
||||
|
||||
# for each transect compare and plot the data
|
||||
for i in range(chain_sds.shape[1]):
|
||||
|
||||
pfname = list(topo_profiles.keys())[i]
|
||||
stats[pfname] = dict([])
|
||||
data_fin[pfname] = dict([])
|
||||
|
||||
dates_sur = topo_profiles[pfname]['dates']
|
||||
chain_sur = topo_profiles[pfname]['chainage']
|
||||
|
||||
# convert to datenum
|
||||
dates_sur_num = np.array([_.toordinal() for _ in dates_sur])
|
||||
|
||||
chain_sur_interp = []
|
||||
diff_days = []
|
||||
|
||||
for j, satdate in enumerate(dates_sds_num):
|
||||
|
||||
temp_diff = satdate - dates_sur_num
|
||||
|
||||
if mod==0:
|
||||
# select measurement before and after sat image date and interpolate
|
||||
|
||||
ind_before = np.where(temp_diff == temp_diff[temp_diff > 0][-1])[0]
|
||||
if ind_before == len(temp_diff)-1:
|
||||
chain_sur_interp.append(np.nan)
|
||||
diff_days.append(np.abs(satdate-dates_sur_num[ind_before])[0])
|
||||
continue
|
||||
ind_after = np.where(temp_diff == temp_diff[temp_diff < 0][0])[0]
|
||||
tempx = np.zeros(2)
|
||||
tempx[0] = dates_sur_num[ind_before]
|
||||
tempx[1] = dates_sur_num[ind_after]
|
||||
tempy = np.zeros(2)
|
||||
tempy[0] = chain_sur[ind_before]
|
||||
tempy[1] = chain_sur[ind_after]
|
||||
diff_days.append(np.abs(np.max([satdate-tempx[0], satdate-tempx[1]])))
|
||||
# interpolate
|
||||
f = interpolate.interp1d(tempx, tempy)
|
||||
chain_sur_interp.append(f(satdate))
|
||||
|
||||
elif mod==1:
|
||||
# select the closest measurement
|
||||
|
||||
idx_closest = utils.find_indices(np.abs(temp_diff), lambda e: e == np.min(np.abs(temp_diff)))[0]
|
||||
diff_days.append(np.abs(satdate-dates_sur_num[idx_closest]))
|
||||
if diff_days[j] > mindays:
|
||||
chain_sur_interp.append(np.nan)
|
||||
else:
|
||||
chain_sur_interp.append(chain_sur[idx_closest])
|
||||
|
||||
chain_sur_interp = np.array(chain_sur_interp)
|
||||
|
||||
# remove nan values
|
||||
idx_sur_nan = ~np.isnan(chain_sur_interp)
|
||||
idx_sat_nan = ~np.isnan(chain_sds[:,i])
|
||||
idx_nan = np.logical_and(idx_sur_nan, idx_sat_nan)
|
||||
|
||||
# groundtruth and sds
|
||||
chain_sur_fin = chain_sur_interp[idx_nan]
|
||||
chain_sds_fin = chain_sds[idx_nan,i]
|
||||
dates_fin = [k for (k, v) in zip(dates_sds, idx_nan) if v]
|
||||
|
||||
# calculate statistics
|
||||
slope, intercept, rvalue, pvalue, std_err = sstats.linregress(chain_sur_fin, chain_sds_fin)
|
||||
R2 = rvalue**2
|
||||
correlation = np.corrcoef(chain_sur_fin, chain_sds_fin)[0,1]
|
||||
diff_chain = chain_sur_fin - chain_sds_fin
|
||||
|
||||
rmse = np.sqrt(np.nanmean((diff_chain)**2))
|
||||
mean = np.nanmean(diff_chain)
|
||||
std = np.nanstd(diff_chain)
|
||||
q90 = np.percentile(np.abs(diff_chain), 90)
|
||||
|
||||
# store data
|
||||
stats[pfname]['rmse'] = rmse
|
||||
stats[pfname]['mean'] = mean
|
||||
stats[pfname]['std'] = std
|
||||
stats[pfname]['q90'] = q90
|
||||
stats[pfname]['diffdays'] = diff_days
|
||||
stats[pfname]['corr'] = correlation
|
||||
stats[pfname]['linfit'] = {'slope':slope, 'intercept':intercept, 'R2':R2, 'pvalue':pvalue}
|
||||
|
||||
data_fin[pfname]['dates'] = dates_fin
|
||||
data_fin[pfname]['sds'] = chain_sds_fin
|
||||
data_fin[pfname]['survey'] = chain_sur_fin
|
||||
|
||||
# make time-series plot
|
||||
plt.figure(fig1.number)
|
||||
ax = fig1.add_subplot(gs1[i,0])
|
||||
axfig1.append(ax)
|
||||
plt.plot(dates_sur, chain_sur, '-', color='C1', markersize=2, label='survey data')
|
||||
# plt.plot(dates_fin, chain_sur_fin, 'o', color=[0.3, 0.3, 0.3], markersize=2, label='survey interp')
|
||||
plt.plot(dates_fin, chain_sds_fin, 'o--', color='C0', markersize=4, alpha=1, label='satellite data')
|
||||
strtitle = '%s (correlation = %.2f)' % (pfname, correlation)
|
||||
plt.title(strtitle, fontweight='bold')
|
||||
plt.xlim([dates_sds[0], dates_sds[-1]])
|
||||
plt.ylabel('cross-shore position [m]')
|
||||
plt.legend()
|
||||
|
||||
# make scatter plot
|
||||
plt.figure(fig2.number)
|
||||
fig2.add_subplot(gs2[0,i])
|
||||
plt.axis('equal')
|
||||
plt.plot(chain_sur_fin, chain_sds_fin, 'ko', markersize=4, markerfacecolor='w', alpha=0.7)
|
||||
xmax = np.max([np.nanmax(chain_sds_fin),np.nanmax(chain_sur_fin)])
|
||||
xmin = np.min([np.nanmin(chain_sds_fin),np.nanmin(chain_sur_fin)])
|
||||
ymax = np.max([np.nanmax(chain_sds_fin),np.nanmax(chain_sur_fin)])
|
||||
ymin = np.min([np.nanmin(chain_sds_fin),np.nanmin(chain_sur_fin)])
|
||||
plt.plot([xmin, xmax], [ymin, ymax], 'k--')
|
||||
plt.plot([xmin, xmax], [xmin*slope + intercept, xmax*slope + intercept], 'r:')
|
||||
str_corr = ' y = %.2f x + %.2f\n R2 = %.2f\n n = %d' % (slope, intercept, R2, len(diff_chain))
|
||||
plt.text(xmin, 0.9*ymax, str_corr, bbox=dict(facecolor=[0.7,0.7,0.7], alpha=0.5), horizontalalignment='left')
|
||||
plt.xlabel('chainage survey [m]')
|
||||
plt.ylabel('chainage satellite [m]')
|
||||
plt.title(pfname, fontweight='bold')
|
||||
|
||||
fig2.add_subplot(gs2[1,i])
|
||||
binwidth = 3
|
||||
bins = np.arange(min(diff_chain), max(diff_chain) + binwidth, binwidth)
|
||||
density = plt.hist(diff_chain, bins=bins, density=True, color=[0.8, 0.8, 0.8], edgecolor='k')
|
||||
plt.xlim([-50, 50])
|
||||
plt.xlabel('error [m]')
|
||||
str_stats = ' rmse = %.1f\n mean = %.1f\n std = %.1f\n q90 = %.1f' % (rmse, mean, std, q90)
|
||||
plt.text(15, np.max(density[0])-0.015, str_stats, bbox=dict(facecolor=[0.8,0.8,0.8], alpha=0.3), horizontalalignment='left', fontsize=10)
|
||||
|
||||
fig1.set_size_inches(19.2, 9.28)
|
||||
fig1.set_tight_layout(True)
|
||||
fig2.set_size_inches(19.2, 9.28)
|
||||
fig2.set_tight_layout(True)
|
||||
|
||||
# all transects together
|
||||
chain_sds_all = []
|
||||
chain_sur_all = []
|
||||
for i in range(chain_sds.shape[1]):
|
||||
pfname = list(topo_profiles.keys())[i]
|
||||
chain_sds_all = np.append(chain_sds_all,data_fin[pfname]['sds'])
|
||||
chain_sur_all = np.append(chain_sur_all,data_fin[pfname]['survey'])
|
||||
|
||||
# calculate statistics
|
||||
slope, intercept, rvalue, pvalue, std_err = sstats.linregress(chain_sur_all, chain_sds_all)
|
||||
R2 = rvalue**2
|
||||
correlation = np.corrcoef(chain_sur_all, chain_sds_all)[0,1]
|
||||
diff_chain_all = chain_sur_all - chain_sds_all
|
||||
|
||||
rmse = np.sqrt(np.nanmean((diff_chain_all)**2))
|
||||
mean = np.nanmean(diff_chain_all)
|
||||
std = np.nanstd(diff_chain_all)
|
||||
q90 = np.percentile(np.abs(diff_chain_all), 90)
|
||||
|
||||
stats['all'] = {'rmse':rmse,'mean':mean,'std':std,'q90':q90, 'corr':correlation,
|
||||
'linfit':{'slope':slope, 'intercept':intercept, 'R2':R2, 'pvalue':pvalue}}
|
||||
|
||||
# make plot
|
||||
plt.figure(fig3.number)
|
||||
fig3.add_subplot(gs3[0,0])
|
||||
plt.axis('equal')
|
||||
plt.plot(chain_sur_all, chain_sds_all, 'ko', markersize=4, markerfacecolor='w', alpha=0.7)
|
||||
xmax = np.max([np.nanmax(chain_sds_all),np.nanmax(chain_sur_all)])
|
||||
xmin = np.min([np.nanmin(chain_sds_all),np.nanmin(chain_sur_all)])
|
||||
ymax = np.max([np.nanmax(chain_sds_all),np.nanmax(chain_sur_all)])
|
||||
ymin = np.min([np.nanmin(chain_sds_all),np.nanmin(chain_sur_all)])
|
||||
plt.plot([xmin, xmax], [ymin, ymax], 'k--')
|
||||
plt.plot([xmin, xmax], [xmin*slope + intercept, xmax*slope + intercept], 'r:')
|
||||
str_corr = ' y = %.2f x + %.2f\n R2 = %.2f\n n = %d' % (slope, intercept, R2, len(diff_chain_all))
|
||||
plt.text(xmin, 0.9*ymax, str_corr, bbox=dict(facecolor=[0.7,0.7,0.7], alpha=0.5), horizontalalignment='left')
|
||||
plt.xlabel('chainage survey [m]')
|
||||
plt.ylabel('chainage satellite [m]')
|
||||
plt.title(pfname, fontweight='bold')
|
||||
|
||||
fig3.add_subplot(gs3[1,0])
|
||||
binwidth = 3
|
||||
bins = np.arange(min(diff_chain_all), max(diff_chain_all) + binwidth, binwidth)
|
||||
density = plt.hist(diff_chain_all, bins=bins, density=True, color=[0.8, 0.8, 0.8], edgecolor='k')
|
||||
plt.xlim([-50, 50])
|
||||
plt.xlabel('error [m]')
|
||||
plt.ylabel('pdf')
|
||||
str_stats = ' rmse = %.1f\n mean = %.1f\n std = %.1f\n q90 = %.1f' % (rmse, mean, std, q90)
|
||||
plt.text(15, np.max(density[0])-0.015, str_stats, bbox=dict(facecolor=[0.8,0.8,0.8], alpha=0.3), horizontalalignment='left', fontsize=10)
|
||||
fig3.set_size_inches(9.2, 9.28)
|
||||
fig3.set_tight_layout(True)
|
||||
|
||||
# for i in range(len(axfig1)):
|
||||
# axfig1[i].set_ylim([0,150]) # Narrabeen data
|
||||
# axfig1[i].set_ylim([25,110]) # Tairua data
|
||||
|
||||
return stats
|
Loading…
Reference in New Issue