# -*- coding: utf-8 -*-
#==========================================================#
# Run a Neural Network on each satellite image to classify
# sand / swash-whitewater / water pixels, trace the shoreline,
# save one figure per image and assemble them into a GIF
#==========================================================#

# Initial settings
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib import gridspec
from datetime import datetime, timedelta
import pytz
import ee
import pdb
import time
import pandas as pd
# other modules
from osgeo import gdal, ogr, osr
import pickle
import matplotlib.cm as cm
from pylab import ginput

# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology
from scipy import ndimage
import imageio

# machine learning modules
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler, Normalizer
# NOTE(review): sklearn.externals.joblib was removed in scikit-learn 0.23 and
# joblib is not used in this script; replace with `import joblib` (or drop)
# when upgrading scikit-learn.
from sklearn.externals import joblib

# import own modules
import functions.utils as utils
import functions.sds as sds

# some settings
np.seterr(all='ignore')  # raise/ignore divisions by 0 and nans
plt.rcParams['axes.grid'] = True
plt.rcParams['figure.max_open_warning'] = 100
ee.Initialize()

# parameters
cloud_thresh = 0.2        # threshold for cloud cover
plot_bool = False         # if you want the plots
prob_high = 100           # upper probability to clip and rescale pixel intensity
min_contour_points = 100  # minimum number of points contained in each water line
output_epsg = 28356       # GDA94 / MGA Zone 56
buffer_size = 10          # radius (in pixels) of disk for buffer (pixel classification)
min_beach_size = 10       # number of pixels in a beach (pixel classification)

# load metadata (timestamps and epsg code) for the collection
satname = 'L8'
# other available sites: 'NARRA_all', 'NARRA', 'OLDBAR', 'OLDBAR_inlet',
# 'SANDMOTOR', 'TAIRUA', 'DUCK', 'BROULEE'
sitename = 'MURI'

# Load metadata
filepath = os.path.join(os.getcwd(), 'data', satname, sitename)
with open(os.path.join(filepath, sitename + '_timestamps' + '.pkl'), 'rb') as f:
    timestamps = pickle.load(f)
timestamps_sorted = sorted(timestamps)

# span of the timeline drawn above every figure
TIMELINE_START = datetime(2013, 1, 1, tzinfo=pytz.utc)
TIMELINE_END = datetime(2019, 1, 1, tzinfo=pytz.utc)
# NOTE: despite the name this is a duration in seconds (the timeline's x unit)
daysall = (TIMELINE_END - TIMELINE_START).total_seconds()

# path to images
file_path_pan = os.path.join(filepath, 'pan')
file_path_ms = os.path.join(filepath, 'ms')
# os.listdir returns entries in arbitrary order: sort so that the pan and ms
# files at index i describe the same scene and so that index i lines up with
# timestamps_sorted[i].
# NOTE(review): the second point assumes file names embed an ISO date
# (YYYY-MM-DD) right after the '<satname>_<sitename>_' prefix, as the
# plot-naming slice below suggests - confirm against the download script.
file_names_pan = sorted(os.listdir(file_path_pan))
file_names_ms = sorted(os.listdir(file_path_ms))
N = len(file_names_pan)
if len(file_names_ms) != N:
    raise ValueError('found {} pan images but {} ms images'.format(N, len(file_names_ms)))

# output folder for the per-image figures (created if missing)
plot_dir = os.path.join(filepath, 'plots_classif')
os.makedirs(plot_dir, exist_ok=True)
# length of the '<satname>_<sitename>_' prefix of each image file name
prefix_len = len(satname) + 1 + len(sitename) + 1

# initialise some variables
idx_skipped = []   # indices of images skipped (too cloudy or no sand)
idx_nocloud = []   # indices of images passing the cloud threshold
# placeholders for training-data collection (not used by the loop below)
n_features = 10
train_pos = np.nan*np.ones((1, n_features))
train_neg = np.nan*np.ones((1, n_features))
columns = ('B', 'G', 'R', 'NIR', 'SWIR', 'Pan', 'WI', 'VI', 'BR', 'mWI', 'class')

# class colours, in the same order as the layers of im_labels:
# sand, swash/whitewater, water
colours = np.array([[1, 128/255, 0/255], [204/255, 1, 1], [0, 0, 204/255]])


def _read_raster(fn):
    """Read every band of a GDAL-readable raster file.

    Returns (im, georef): im is the bands stacked along a third axis
    (rows, cols, bands); georef is the geotransform as a numpy array.
    Raises IOError if GDAL cannot open the file.
    """
    dataset = gdal.Open(fn, gdal.GA_ReadOnly)
    if dataset is None:
        raise IOError('could not open raster: ' + fn)
    georef = np.array(dataset.GetGeoTransform())
    # GDAL band indices are 1-based
    bands = [dataset.GetRasterBand(k + 1).ReadAsArray()
             for k in range(dataset.RasterCount)]
    dataset = None  # drop the reference so GDAL can release the file
    return np.stack(bands, 2), georef


def _seconds_since_start(date):
    """Return the position of *date* on the timeline (seconds since TIMELINE_START)."""
    return (date - TIMELINE_START).total_seconds()


def _plot_timeline(ax, date):
    """Draw the TIMELINE_START..TIMELINE_END year axis on *ax*, marking *date* in red."""
    ax.plot([0, daysall], [0, 0], 'k-')
    ax.plot([0, daysall], [0, 0], 'ko')  # large markers at both ends
    for year in range(TIMELINE_START.year, TIMELINE_END.year + 1):
        x = _seconds_since_start(datetime(year, 1, 1, tzinfo=pytz.utc))
        if TIMELINE_START.year < year < TIMELINE_END.year:
            ax.plot(x, 0, 'ko', markersize=3)  # small tick for inner years
        ax.text(x, 0.05, str(year))
    ax.plot(_seconds_since_start(date), 0, 'ro')
    ax.axis('off')


#%%
for i in range(N):
    # read pan image (first band only); its geotransform is kept in georef
    fn_pan = os.path.join(file_path_pan, file_names_pan[i])
    im_pan_bands, georef = _read_raster(fn_pan)
    im_pan = im_pan_bands[:, :, 0]
    nrow, ncol = im_pan.shape[0], im_pan.shape[1]

    # read ms image
    fn_ms = os.path.join(file_path_ms, file_names_ms[i])
    im_ms, _ = _read_raster(fn_ms)

    # cloud mask from band index 5, upsampled (nearest neighbour) to the pan grid
    im_qa = im_ms[:, :, 5]
    cloud_mask = sds.create_cloud_mask(im_qa, satname, plot_bool)
    cloud_mask = transform.resize(cloud_mask, (nrow, ncol), order=0,
                                  preserve_range=True, mode='constant').astype('bool_')
    # resize the image using bilinear interpolation (order 1)
    im_ms = transform.resize(im_ms, (nrow, ncol), order=1,
                             preserve_range=True, mode='constant')
    # pixels that are -inf or nan in the first band are masked like clouds
    im_inf = np.isneginf(im_ms[:, :, 0])
    im_nan = np.isnan(im_ms[:, :, 0])
    cloud_mask = np.logical_or(np.logical_or(cloud_mask, im_inf), im_nan)

    # skip if cloud cover is more than the threshold
    cloud_cover = np.mean(cloud_mask)
    if cloud_cover > cloud_thresh:
        print('skip ' + str(i) + ' - cloudy (' + str(int(np.round(cloud_cover*100))) + '%)')
        idx_skipped.append(i)
        continue
    idx_nocloud.append(i)

    # pansharpen rgb image
    im_ms_ps = sds.pansharpen(im_ms[:, :, [0, 1, 2]], im_pan, cloud_mask, plot_bool)
    # add down-sized bands for NIR and SWIR (since pansharpening is not possible)
    im_ms_ps = np.append(im_ms_ps, im_ms[:, :, [3, 4]], axis=2)

    im_classif, im_labels = sds.classify_image_NN(im_ms_ps, im_pan, cloud_mask,
                                                  min_beach_size, plot_bool)
    # if there are no sand pixels, skip the image
    # (maybe later change the detection method with old method)
    if not np.any(im_labels[:, :, 0]):
        print('skip ' + str(i) + ' - no sand')
        idx_skipped.append(i)
        continue

    contours_wi, contours_mwi = sds.find_wl_contours2(im_ms_ps, im_labels, cloud_mask,
                                                      buffer_size, False)

    # RGB display image, and a copy with each class painted in its colour
    im_display = sds.rescale_image_intensity(im_ms_ps[:, :, [2, 1, 0]], cloud_mask,
                                             prob_high, False)
    im = np.copy(im_display)
    for k in range(im_labels.shape[2]):
        im[im_labels[:, :, k], :3] = colours[k]

    # figure: timeline on top, RGB image (left) and classified image (right)
    date_im = timestamps_sorted[i].strftime('%d %b %Y')
    fig = plt.figure()
    gs = gridspec.GridSpec(2, 2, height_ratios=[1, 20])
    ax1 = fig.add_subplot(gs[0, :])
    _plot_timeline(ax1, timestamps_sorted[i])

    ax2 = fig.add_subplot(gs[1, 0])
    ax2.imshow(im_display)
    ax2.axis('off')
    ax2.set_title(date_im, fontsize=17, fontweight='bold')

    ax3 = fig.add_subplot(gs[1, 1])
    ax3.imshow(im)
    for contour in contours_mwi:
        ax3.plot(contour[:, 1], contour[:, 0], linewidth=2, color='k', linestyle='--')
    ax3.axis('off')
    handles = [
        mpatches.Patch(color=colours[0], label='sand'),
        mpatches.Patch(color=colours[1], label='swash/whitewater'),
        mpatches.Patch(color=colours[2], label='water'),
        # dashed, matching how the shoreline contours are drawn
        mlines.Line2D([], [], color='k', linestyle='--', label='shoreline'),
    ]
    ax3.legend(handles=handles, bbox_to_anchor=(0.95, 0.2))

    fig.set_size_inches(17.99, 7.55)
    fig.set_tight_layout(True)
    # figure named after the 10 characters following the file-name prefix
    fig.savefig(os.path.join(plot_dir, file_names_pan[i][prefix_len:prefix_len + 10] + '.jpg'),
                dpi=300)
    plt.close(fig)

# create gif from the saved figures, in file-name (chronological) order
filenames = sorted(fn for fn in os.listdir(plot_dir) if fn.endswith('.jpg'))
with imageio.get_writer(sitename + '.gif', mode='I', duration=0.4) as writer:
    for filename in filenames:
        image = imageio.imread(os.path.join(plot_dir, filename))
        writer.append_data(image)