# -*- coding: utf-8 -*-

#==========================================================#
# Run Neural Network on image to extract sandy pixels
#==========================================================#
#
# For every panchromatic / multispectral image pair found under
# data/<satname>/<sitename>/{pan,ms}, this script:
#   1. reads both rasters with GDAL,
#   2. builds a cloud mask and skips images that are too cloudy,
#   3. pansharpens the first three ms bands and computes 10 per-pixel features,
#   4. classifies every valid pixel with a pre-trained classifier (NN1.pkl),
#   5. plots the image next to the colour-coded classification.

# Initial settings
import os
import numpy as np
import matplotlib.pyplot as plt
import ee
import pdb
import time
import pandas as pd

# other modules
from osgeo import gdal, ogr, osr
import pickle
import matplotlib.cm as cm
from pylab import ginput

# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology
from scipy import ndimage

# machine learning modules
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler, Normalizer
try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib no longer exists in recent scikit-learn
    # releases; joblib is installed alongside scikit-learn as a dependency.
    import joblib

# import own modules
import functions.utils as utils
import functions.sds as sds

# some settings
np.seterr(all='ignore')  # raise/ignore divisions by 0 and nans
plt.rcParams['axes.grid'] = True
plt.rcParams['figure.max_open_warning'] = 100
ee.Initialize()

# parameters
cloud_thresh = 0.3         # threshold for cloud cover
plot_bool = False          # if you want the plots
prob_high = 100            # upper probability to clip and rescale pixel intensity
min_contour_points = 100   # minimum number of points contained in each water line
output_epsg = 28356        # GDA94 / MGA Zone 56
buffer_size = 10           # radius (in pixels) of disk for buffer (pixel classification)
min_beach_size = 50        # number of pixels in a beach (pixel classification)

# load metadata (timestamps and epsg code) for the collection
satname = 'L8'
#sitename = 'NARRA_all'
#sitename = 'NARRA'
#sitename = 'OLDBAR'
sitename = 'OLDBAR_inlet'

# Load metadata
filepath = os.path.join(os.getcwd(), 'data', satname, sitename)

# path to images
file_path_pan = os.path.join(filepath, 'pan')
file_path_ms = os.path.join(filepath, 'ms')
# os.listdir() returns names in arbitrary order, and the pan/ms files are
# paired by list index below, so both listings are sorted to make the pairing
# deterministic.
# NOTE(review): assumes matching pan and ms files sort into the same order
# (i.e. they share a naming convention) - confirm against the download step.
file_names_pan = sorted(os.listdir(file_path_pan))
file_names_ms = sorted(os.listdir(file_path_ms))
N = len(file_names_pan)

# initialise some variables
idx_skipped = []   # indices of images skipped because of cloud cover
idx_nocloud = []   # indices of images that were processed
n_features = 10
train_pos = np.nan*np.ones((1, n_features))
train_neg = np.nan*np.ones((1, n_features))
columns = ('B', 'G', 'R', 'NIR', 'SWIR', 'Pan', 'WI', 'VI', 'BR', 'mWI', 'class')

# pre-trained classifier
# NOTE: joblib.load unpickles the file - only load model files you trust.
clf = joblib.load(os.path.join(os.getcwd(), 'sand_classification', 'NN1.pkl'))

#%%
for i in range(N):

    # read pan image (only the first band is kept)
    fn_pan = os.path.join(file_path_pan, file_names_pan[i])
    data = gdal.Open(fn_pan, gdal.GA_ReadOnly)
    georef = np.array(data.GetGeoTransform())
    # band index is named `b` so it cannot shadow/clobber the image index `i`
    bands = [data.GetRasterBand(b + 1).ReadAsArray() for b in range(data.RasterCount)]
    im_pan = np.stack(bands, 2)[:, :, 0]
    nrow = im_pan.shape[0]
    ncol = im_pan.shape[1]

    # read ms image (all bands stacked along the third axis)
    fn_ms = os.path.join(file_path_ms, file_names_ms[i])
    data = gdal.Open(fn_ms, gdal.GA_ReadOnly)
    bands = [data.GetRasterBand(b + 1).ReadAsArray() for b in range(data.RasterCount)]
    im_ms = np.stack(bands, 2)

    # cloud mask, built from band index 5 and resized to the pan grid with
    # nearest-neighbour interpolation (order 0) so that it stays binary
    im_qa = im_ms[:, :, 5]
    cloud_mask = sds.create_cloud_mask(im_qa, satname, plot_bool)
    cloud_mask = transform.resize(cloud_mask, (im_pan.shape[0], im_pan.shape[1]),
                                  order=0, preserve_range=True,
                                  mode='constant').astype('bool_')

    # resize the image using bilinear interpolation (order 1)
    im_ms = transform.resize(im_ms, (im_pan.shape[0], im_pan.shape[1]),
                             order=1, preserve_range=True, mode='constant')

    # check if -inf or nan values (in the first band) and add to cloud mask
    im_inf = np.isneginf(im_ms[:, :, 0])
    im_nan = np.isnan(im_ms[:, :, 0])
    cloud_mask = np.logical_or(np.logical_or(cloud_mask, im_inf), im_nan)

    # skip if cloud cover (fraction of masked pixels) is more than the threshold
    cloud_cover = cloud_mask.sum() / cloud_mask.size
    if cloud_cover > cloud_thresh:
        print('skip ' + str(i) + ' - cloudy (' + str(np.round(cloud_cover*100).astype(int)) + '%)')
        idx_skipped.append(i)
        continue
    idx_nocloud.append(i)

    # pansharpen rgb image
    im_ms_ps = sds.pansharpen(im_ms[:, :, [0, 1, 2]], im_pan, cloud_mask, plot_bool)
    # add down-sized bands for NIR and SWIR (since pansharpening is not possible)
    im_ms_ps = np.append(im_ms_ps, im_ms[:, :, [3, 4]], axis=2)

    # alternative (unsupervised) workflow, currently disabled:
    # # calculate NDWI
    # im_ndwi = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask, plot_bool)
    # # detect edges
    # wl_pix = sds.find_wl_contours(im_ndwi, cloud_mask, min_contour_points, plot_bool)
    # # classify sand pixels with Kmeans
    # im_sand = sds.classify_sand_unsupervised(im_ms_ps, im_pan, cloud_mask, wl_pix, buffer_size, min_beach_size, plot_bool)

    # calculate features: 5 ms bands, pan band and 4 normalised-difference indices
    im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features))
    im_features[:, :, [0, 1, 2, 3, 4]] = im_ms_ps
    im_features[:, :, 5] = im_pan
    im_features[:, :, 6] = sds.nd_index(im_ms_ps[:, :, 3], im_ms_ps[:, :, 1], cloud_mask, False)  # ND(NIR-G)
    im_features[:, :, 7] = sds.nd_index(im_ms_ps[:, :, 3], im_ms_ps[:, :, 2], cloud_mask, False)  # ND(NIR-R)
    im_features[:, :, 8] = sds.nd_index(im_ms_ps[:, :, 0], im_ms_ps[:, :, 2], cloud_mask, False)  # ND(B-R)
    im_features[:, :, 9] = sds.nd_index(im_ms_ps[:, :, 4], im_ms_ps[:, :, 1], cloud_mask, False)  # ND(SWIR-G)

    # remove NaNs and clouds: flatten to (n_pixels, n_features) and keep only
    # the rows that are neither cloud-masked nor contain a NaN feature
    vec = im_features.reshape((im_ms_ps.shape[0] * im_ms_ps.shape[1], n_features))
    vec_cloud = cloud_mask.reshape(cloud_mask.shape[0]*cloud_mask.shape[1])
    vec_nan = np.any(np.isnan(vec), axis=1)
    vec_mask = np.logical_or(vec_cloud, vec_nan)
    vec = vec[~vec_mask, :]

    # predict with NN
    y = clf.predict(vec)

    # recompose image
    # NOTE(review): masked pixels keep the initial value 0 and are therefore
    # drawn with the class-0 colour below - confirm this is intended.
    vec_new = np.zeros((cloud_mask.shape[0]*cloud_mask.shape[1]))
    vec_new[~vec_mask] = y
    im_classif = vec_new.reshape((im_ms_ps.shape[0], im_ms_ps.shape[1]))
#    im_classif = morphology.remove_small_objects(im_classif, min_size=min_beach_size, connectivity=2)

    # plot NN labels: paint each class (0..3) with its own colour on an RGB copy
    im_display = sds.rescale_image_intensity(im_ms_ps[:, :, [2, 1, 0]], cloud_mask, prob_high, False)
    im = np.copy(im_display)
    colours = np.array([[1, 0, 0], [1, 1, 0], [0, 1, 1], [0, 0, 1]])
    for k in range(4):
        im[im_classif == k] = colours[k]

    plt.figure()
    ax1 = plt.subplot(121)
    plt.imshow(im_display)
    plt.axis('off')
    plt.title('Image')
    ax2 = plt.subplot(122, sharex=ax1, sharey=ax1)
    plt.imshow(im)
    plt.axis('off')
    plt.title('NN')
    mng = plt.get_current_fig_manager()
    try:
        mng.window.showMaximized()
    except AttributeError:
        # showMaximized() only exists on Qt windows; with any other matplotlib
        # backend the figure is simply left at its default size.
        pass
    plt.tight_layout()
    plt.draw()