You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
446 lines
18 KiB
Python
446 lines
18 KiB
Python
"""This module contains all the functions needed for data analysis """
|
|
|
|
# Initial settings
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.patches as mpatches
|
|
from matplotlib import gridspec
|
|
import pdb
|
|
import ee
|
|
|
|
# other modules
|
|
from osgeo import gdal, ogr, osr
|
|
import scipy.interpolate as interpolate
|
|
import scipy.stats as sstats
|
|
|
|
# image processing modules
|
|
import skimage.filters as filters
|
|
import skimage.exposure as exposure
|
|
import skimage.transform as transform
|
|
import sklearn.decomposition as decomposition
|
|
import skimage.measure as measure
|
|
import skimage.morphology as morphology
|
|
|
|
# machine learning modules
|
|
from sklearn.cluster import KMeans
|
|
from sklearn.neural_network import MLPClassifier
|
|
from sklearn.externals import joblib
|
|
|
|
import time
|
|
|
|
# import own modules
|
|
import functions.utils as utils
|
|
|
|
def get_tide(dates_sds, dates_tide, tide_level, max_time_diff=1800):
    """
    Match a tide level to each shoreline (SDS) date.

    For every date in dates_sds the closest date in dates_tide is found; its
    tide level is used when it lies within max_time_diff seconds, otherwise
    NaN is stored.

    Arguments:
    -----------
    dates_sds: list of datetime
        dates for which a tide level is wanted
    dates_tide: list of datetime
        dates of the tide record (must be subtractable from dates_sds)
    tide_level: sequence
        tide level for each date in dates_tide
    max_time_diff: float, optional
        maximum allowed time difference in seconds (default 1800, i.e. half an hour)

    Returns:
    -----------
    tide: np.ndarray
        one value per date in dates_sds (NaN where no tide record is close enough)
    """
    if len(dates_tide) == 0:
        # no tide record at all: nothing can be matched
        return np.full(len(dates_sds), np.nan)

    # offsets of the tide record from its first date, computed once instead of
    # re-subtracting every tide date for every shoreline date
    ref = dates_tide[0]
    tide_sec = np.array([(t - ref).total_seconds() for t in dates_tide])

    tide = []
    for date in dates_sds:
        dates_diff = np.abs((date - ref).total_seconds() - tide_sec)
        idx_closest = np.argmin(dates_diff)
        if dates_diff[idx_closest] <= max_time_diff:
            tide.append(tide_level[idx_closest])
        else:
            tide.append(np.nan)

    return np.array(tide)
|
|
|
|
def remove_duplicates(output, satname):
    """
    Remove images acquired on the same day from an output structure.

    Dates are reduced to 'YYYYMMDD' strings and passed to
    utils.duplicates_dict (presumably mapping each repeated day to the
    indices of its images -- TODO confirm). Only the first two indices of each
    entry are compared, and one image of the pair is dropped:
      - the first is kept if its cloud cover is lower by more than 0.01;
      - otherwise, for 'L8' and 'L5' only, the first is kept if its
        'acc_georef' value is lower by more than 0.1;
      - in every other case the second is kept.

    Arguments:
    -----------
    output: dict
        contains 'dates', 'shorelines' and 'metadata' (dict of per-image
        lists, including 'cloud_cover', and 'acc_georef' for 'L8'/'L5')
    satname: str
        satellite name; also used in the printed summary

    Returns:
    -----------
    dict
        a new dict holding only 'dates', 'shorelines' and 'metadata' for the
        kept images (other top-level keys are not copied), or `output` itself
        when no duplicates are found
    """
    day_strings = [date.strftime('%Y%m%d') for date in output['dates']]
    dupl = utils.duplicates_dict(day_strings)

    if not dupl:
        print(satname + ' : ' + 'no duplicates')
        return output

    use_georef = satname in ('L8', 'L5')
    dropped = []
    for _day, indices in dupl.items():
        first, second = indices[0], indices[1]
        cloud_first = output['metadata']['cloud_cover'][first]
        cloud_second = output['metadata']['cloud_cover'][second]
        if use_georef:
            georef_first = output['metadata']['acc_georef'][first]
            georef_second = output['metadata']['acc_georef'][second]
            keep_first = (cloud_first < cloud_second - 0.01
                          or georef_first < georef_second - 0.1)
        else:
            keep_first = cloud_first < cloud_second - 0.01
        dropped.append(second if keep_first else first)

    dropped_set = set(dropped)
    keep = [k for k in range(len(day_strings)) if k not in dropped_set]

    filtered = {
        'dates': [output['dates'][k] for k in keep],
        'shorelines': [output['shorelines'][k] for k in keep],
        'metadata': {
            field: [output['metadata'][field][k] for k in keep]
            for field in list(output['metadata'].keys())
        },
    }
    print(satname + ' : ' + str(len(dropped)) + ' duplicates')
    return filtered
|
|
|
|
|
|
def merge(output):
    """
    Merge the outputs of the different satellites into one structure sorted by date.

    Arguments:
    -----------
    output: dict
        one entry per satellite, each with 'dates' and 'shorelines' (sequences
        of the same length) and 'metadata' (dict of per-image sequences aligned
        with 'dates')

    Returns:
    -----------
    output_all_sorted: dict
        'dates', 'shorelines' and 'metadata' with every satellite's entries
        concatenated and sorted chronologically (entries with equal dates keep
        their merge order). 'metadata' always contains 'filenames', 'satname',
        'cloud_cover' and 'acc_georef' (empty lists if no satellite provides
        them), plus any other metadata key found in the input.
    """
    dates = []
    shorelines = []
    metadata = {'filenames': [], 'satname': [], 'cloud_cover': [], 'acc_georef': []}

    # stack all lists together under one key
    for sat_output in output.values():
        dates.extend(sat_output['dates'])
        shorelines.extend(sat_output['shorelines'])
        for key, values in sat_output['metadata'].items():
            # setdefault: accept metadata fields beyond the four defaults
            # instead of failing with a KeyError
            metadata.setdefault(key, []).extend(values)

    # sort the dates (sorted() is stable, so ties keep their merge order)
    idx_sorted = sorted(range(len(dates)), key=dates.__getitem__)

    return {
        'dates': [dates[i] for i in idx_sorted],
        'shorelines': [shorelines[i] for i in idx_sorted],
        'metadata': {key: [values[i] for i in idx_sorted]
                     for key, values in metadata.items()},
    }
|
|
|
|
def create_transects(x0, y0, orientation, chainage_length):
    """
    Build one straight (shore-normal) transect per origin point.

    Each transect starts as chainage_length + 1 evenly spaced points
    (np.linspace from 0 to chainage_length) on the positive x axis. It is then
    rotated by (90 - orientation[k]) degrees (converted to radians) and
    translated to (x0[k], y0[k]) with skimage's EuclideanTransform.
    NOTE(review): orientation looks like a bearing clockwise from north -- confirm.

    Arguments:
    -----------
    x0, y0: sequences of float
        coordinates of the transect origins
    orientation: sequence of float
        orientation of each transect in degrees
    chainage_length: int
        length of each transect (number of points minus one)

    Returns:
    -----------
    list of np.ndarray
        one (chainage_length + 1, 2) array of x, y points per origin; row 0 is the origin
    """
    def _single(k):
        # rotation angle in radians for transect k
        angle = (90 - orientation[k]) * np.pi / 180
        # unrotated transect lying on the positive x axis
        along = np.linspace(0, chainage_length, chainage_length + 1)
        points = np.column_stack((along, np.zeros(len(along))))
        rigid = transform.EuclideanTransform(rotation=angle, translation=(x0[k], y0[k]))
        return rigid(points)

    return [_single(k) for k in range(len(x0))]
|
|
|
|
def calculate_chainage(sds, transects, orientation, along_dist, max_std=10):
    """
    Intersect each shoreline (SDS) with each transect and compute chainage statistics.

    For every (shoreline, transect) pair, the shoreline points lying within
    along_dist of the line through the transect's first and last points are
    rotated into the transect frame (origin at the transect's first point).
    Their coordinate along the transect is the chainage. Negative chainages
    are discarded (set to NaN) and statistics are computed on the rest.

    Arguments:
    -----------
    sds: list of np.ndarray
        shorelines; each array is indexed as sl[:, 0] (x) and sl[:, 1] (y)
    transects: list of np.ndarray
        transects; row 0 is the origin and the last row another point on the
        transect (e.g. as returned by create_transects)
    orientation: sequence of float
        orientation of each transect in degrees, converted to an angle as
        (90 - orientation) * pi / 180 (same convention as create_transects)
    along_dist: float
        maximum distance between a shoreline point and the transect line for
        the point to be used
    max_std: float, optional
        if the standard deviation of the chainages exceeds this value (default
        10), mean, median and min are all replaced by the max chainage

    Returns:
    -----------
    chainage: dict
        'mean', 'median', 'max', 'min', 'npoints' and 'std', each an array of
        shape (len(sds), len(transects)). Pairs with no nearby shoreline point
        are NaN in every field. 'npoints' counts all nearby points, including
        those discarded for a negative chainage.
    """
    n_sds = len(sds)
    n_tr = len(transects)
    # every cell starts as NaN; pairs without nearby points simply stay NaN
    chainage_mtx = np.full((n_sds, n_tr, 6), np.nan)

    # per-transect geometry does not depend on the shoreline: compute it once
    geometry = []
    for j in range(n_tr):
        origin = np.array([transects[j][0, 0], transects[j][0, 1]])
        direction = transects[j][-1, :] - origin
        phi = (90 - orientation[j]) * np.pi / 180
        # rotation into the transect frame: row 0 of the result is the chainage
        Mrot = np.array([[np.cos(phi), np.sin(phi)], [-np.sin(phi), np.cos(phi)]])
        geometry.append((origin, direction, np.linalg.norm(direction), Mrot))

    for i, sl in enumerate(sds):
        for j, (origin, direction, length, Mrot) in enumerate(geometry):
            # point-to-line distance via the explicit 2D cross product
            # (np.cross on 2-element vectors is deprecated since NumPy 2.0)
            rel = sl - origin
            d = np.abs((direction[0] * rel[:, 1] - direction[1] * rel[:, 0]) / length)
            idx_close = utils.find_indices(d, lambda e: e <= along_dist)

            # no SDS points around the profile: leave the NaN row
            if not idx_close:
                continue

            # change of base to shore-normal coordinate system
            xy_close = np.array([sl[idx_close, 0], sl[idx_close, 1]]) - origin.reshape(2, 1)
            cross = np.matmul(Mrot, xy_close)[0, :]

            # discard negative chainages (MAKE SURE TO PICK ORIGIN CORRECTLY)
            cross[cross < 0] = np.nan
            n_points = len(cross)

            if np.all(np.isnan(cross)):
                # every nearby point is behind the origin: statistics are
                # undefined (NaN), only the point count is stored; avoids
                # "All-NaN slice" / "Mean of empty slice" RuntimeWarnings
                chainage_mtx[i, j, 4] = n_points
                continue

            mean_cross = np.nanmean(cross)
            median_cross = np.nanmedian(cross)
            max_cross = np.nanmax(cross)
            min_cross = np.nanmin(cross)
            std_cross = np.nanstd(cross)

            if std_cross > max_std:
                # large spread: use the point furthest along the transect
                # (described as "the most seaward point" in the original code)
                mean_cross = median_cross = min_cross = max_cross

            chainage_mtx[i, j, :] = [mean_cross, median_cross, max_cross,
                                     min_cross, n_points, std_cross]

    # format into dictionary
    keys = ('mean', 'median', 'max', 'min', 'npoints', 'std')
    return {key: chainage_mtx[:, :, k] for k, key in enumerate(keys)}
|
|
|
|
def compare_sds(dates_sds, chain_sds, topo_profiles, mod=0, mindays=5):
    """
    Compare sds with groundtruth data from topographic surveys / argus shorelines

    KV WRL 2018

    For every transect (column of chain_sds), survey chainage is matched to
    each satellite date, and error statistics of (survey - satellite) are
    computed on the dates where both values exist. Dates are compared at day
    resolution (date.toordinal()). Three figures are drawn:
      - a time series per transect;
      - a scatter plot and error histogram per transect;
      - a scatter plot and error histogram for all transects pooled together.

    Arguments:
    -----------
    dates_sds: list
        list of dates corresponding to each row in chain_sds
    chain_sds: np.ndarray
        array with time series of chainage for each transect (each transect is one column)
    topo_profiles: dict
        one entry per transect, in the same order as the columns of chain_sds;
        each entry holds 'dates' and 'chainage' of the groundtruth
    mod: 0 or 1
        0: linear interpolation between the 2 surveys bracketing the satellite
           date. A survey on the same day is used directly. Dates outside the
           survey period are not compared.
        1: nearest survey only
    mindays: int
        only used with mod=1: maximum number of days between a satellite date
        and the nearest survey for the pair to be compared

    Returns:
    -----------
    stats: dict
        one entry per transect name plus 'all' (pooled data). Each entry holds
        rmse, mean, std and q90 of the errors, 'corr' and 'linfit'. Per-transect
        entries also hold 'diffdays' (day gap used for each satellite date).

    Raises:
    -----------
    ValueError
        if mod is not 0 or 1
    """
    if mod not in (0, 1):
        raise ValueError('mod must be 0 (interpolation) or 1 (nearest neighbour), got %r' % (mod,))

    n_transects = chain_sds.shape[1]
    profile_names = list(topo_profiles.keys())

    # create 3 figures: time series (fig1), per-transect scatter/histogram
    # (fig2) and scatter/histogram of all transects together (fig3)
    fig1 = plt.figure()
    gs1 = gridspec.GridSpec(n_transects, 1)
    fig2 = plt.figure()
    gs2 = gridspec.GridSpec(2, n_transects)
    fig3 = plt.figure()
    gs3 = gridspec.GridSpec(2, 1)

    dates_sds_num = np.array([date.toordinal() for date in dates_sds])
    stats = dict([])
    sur_parts = []
    sds_parts = []

    # for each transect compare and plot the data
    for i in range(n_transects):
        pfname = profile_names[i]
        dates_sur = topo_profiles[pfname]['dates']
        chain_sur = topo_profiles[pfname]['chainage']
        dates_sur_num = np.array([date.toordinal() for date in dates_sur])

        # survey chainage (or NaN) and day gap for every satellite date
        matched = [_match_survey(satdate, dates_sur_num, chain_sur, mod, mindays)
                   for satdate in dates_sds_num]
        chain_sur_interp = np.array([value for value, gap in matched], dtype=float)
        diff_days = [gap for value, gap in matched]

        # keep only the dates where both survey and satellite values exist
        idx_valid = np.logical_and(~np.isnan(chain_sur_interp), ~np.isnan(chain_sds[:, i]))
        chain_sur_fin = chain_sur_interp[idx_valid]
        chain_sds_fin = chain_sds[idx_valid, i]
        dates_fin = [date for date, keep in zip(dates_sds, idx_valid) if keep]

        st, diff_chain = _comparison_stats(chain_sur_fin, chain_sds_fin)
        st['diffdays'] = diff_days
        stats[pfname] = st
        sur_parts.append(chain_sur_fin)
        sds_parts.append(chain_sds_fin)

        # make time-series plot
        plt.figure(fig1.number)
        fig1.add_subplot(gs1[i, 0])
        plt.plot(dates_sur, chain_sur, '-', color='C1', markersize=2, label='survey data')
        plt.plot(dates_fin, chain_sds_fin, 'o--', color='C0', markersize=4, alpha=1, label='satellite data')
        plt.title('%s (correlation = %.2f)' % (pfname, st['corr']), fontweight='bold')
        plt.xlim([dates_sds[0], dates_sds[-1]])
        plt.ylabel('cross-shore position [m]')
        plt.legend()

        # make scatter plot and error histogram
        _plot_scatter_hist(fig2, gs2[0, i], gs2[1, i], chain_sur_fin, chain_sds_fin,
                           diff_chain, st, pfname)

    fig1.set_size_inches(19.2, 9.28)
    fig1.set_tight_layout(True)
    fig2.set_size_inches(19.2, 9.28)
    fig2.set_tight_layout(True)

    # all transects together
    chain_sur_all = np.concatenate(sur_parts)
    chain_sds_all = np.concatenate(sds_parts)
    stats['all'], diff_chain_all = _comparison_stats(chain_sur_all, chain_sds_all)

    # pooled plot (previously titled with the last transect's name by mistake)
    _plot_scatter_hist(fig3, gs3[0, 0], gs3[1, 0], chain_sur_all, chain_sds_all,
                       diff_chain_all, stats['all'], 'all transects', hist_ylabel='pdf')
    fig3.set_size_inches(9.2, 9.28)
    fig3.set_tight_layout(True)

    return stats


def _match_survey(satdate, dates_sur_num, chain_sur, mod, mindays):
    """Return (survey chainage or NaN, day gap) for one satellite date (day ordinal)."""
    if len(dates_sur_num) == 0:
        return np.nan, np.nan

    # positive: survey before the satellite date, negative: survey after
    temp_diff = satdate - dates_sur_num

    if mod == 1:
        # select the closest measurement (first one on ties)
        idx_closest = np.argmin(np.abs(temp_diff))
        gap = np.abs(temp_diff[idx_closest])
        if gap > mindays:
            return np.nan, gap
        return chain_sur[idx_closest], gap

    # mod == 0: a survey on the same day is used as is
    exact = np.flatnonzero(temp_diff == 0)
    if exact.size:
        return chain_sur[exact[0]], 0

    before = np.flatnonzero(temp_diff > 0)
    after = np.flatnonzero(temp_diff < 0)
    if before.size == 0 or after.size == 0:
        # satellite date outside the survey period: nothing to interpolate
        return np.nan, np.min(np.abs(temp_diff))

    # closest survey on each side (does not require sorted survey dates)
    ib = before[np.argmin(temp_diff[before])]
    ia = after[np.argmax(temp_diff[after])]
    value = np.interp(satdate, [dates_sur_num[ib], dates_sur_num[ia]],
                      [chain_sur[ib], chain_sur[ia]])
    # NOTE(review): as in the original code, the recorded gap is the number of
    # days since the preceding survey, not the larger of the two gaps
    return value, temp_diff[ib]


def _comparison_stats(chain_sur, chain_sds):
    """Return (statistics dict, survey - satellite differences) for paired chainages."""
    slope, intercept, rvalue, pvalue, _std_err = sstats.linregress(chain_sur, chain_sds)
    diff_chain = chain_sur - chain_sds
    st = {
        'rmse': np.sqrt(np.nanmean(diff_chain**2)),
        'mean': np.nanmean(diff_chain),
        'std': np.nanstd(diff_chain),
        'q90': np.percentile(np.abs(diff_chain), 90),
        'corr': np.corrcoef(chain_sur, chain_sds)[0, 1],
        'linfit': {'slope': slope, 'intercept': intercept, 'R2': rvalue**2, 'pvalue': pvalue},
    }
    return st, diff_chain


def _plot_scatter_hist(fig, spec_scatter, spec_hist, chain_sur, chain_sds,
                       diff_chain, st, title, hist_ylabel=None):
    """Draw a survey-vs-satellite scatter plot and an error histogram on fig."""
    plt.figure(fig.number)
    fig.add_subplot(spec_scatter)
    plt.axis('equal')
    plt.plot(chain_sur, chain_sds, 'ko', markersize=4, markerfacecolor='w', alpha=0.7)
    # same range on both axes so the dashed line is the 1:1 line
    vmax = np.max([np.nanmax(chain_sds), np.nanmax(chain_sur)])
    vmin = np.min([np.nanmin(chain_sds), np.nanmin(chain_sur)])
    plt.plot([vmin, vmax], [vmin, vmax], 'k--')
    fit = st['linfit']
    plt.plot([vmin, vmax],
             [vmin*fit['slope'] + fit['intercept'], vmax*fit['slope'] + fit['intercept']], 'r:')
    str_corr = ' y = %.2f x + %.2f\n R2 = %.2f\n n = %d' % (fit['slope'], fit['intercept'],
                                                          fit['R2'], len(diff_chain))
    plt.text(vmin, 0.9*vmax, str_corr, bbox=dict(facecolor=[0.7,0.7,0.7], alpha=0.5),
             horizontalalignment='left')
    plt.xlabel('chainage survey [m]')
    plt.ylabel('chainage satellite [m]')
    plt.title(title, fontweight='bold')

    fig.add_subplot(spec_hist)
    binwidth = 3
    bins = np.arange(np.min(diff_chain), np.max(diff_chain) + binwidth, binwidth)
    density = plt.hist(diff_chain, bins=bins, density=True, color=[0.8, 0.8, 0.8], edgecolor='k')
    plt.xlim([-50, 50])
    plt.xlabel('error [m]')
    if hist_ylabel is not None:
        plt.ylabel(hist_ylabel)
    str_stats = ' rmse = %.1f\n mean = %.1f\n std = %.1f\n q90 = %.1f' % (st['rmse'], st['mean'],
                                                                          st['std'], st['q90'])
    plt.text(15, np.max(density[0])-0.015, str_stats, bbox=dict(facecolor=[0.8,0.8,0.8], alpha=0.3),
             horizontalalignment='left', fontsize=10)