# -*- coding: utf-8 -*-
#==========================================================#
# Create a training data
#==========================================================#
# For every scene of a site (satname / sitename below) this script:
#   1. reads the panchromatic ('pan') and multispectral ('ms') GeoTIFFs,
#   2. builds a cloud / invalid-pixel mask and skips scenes that are too cloudy
#      or whose acquisition date was already labelled,
#   3. pansharpens the RGB bands, computes an NDWI and detects water lines,
#   4. classifies sand pixels with sds.classify_sand_unsupervised (Kmeans),
#   5. displays the classification and asks the user for one click:
#        top-left quadrant    -> sand pixels are stored as sand examples  (train_pos)
#        bottom-left quadrant -> sand pixels are stored as swash examples (train_neg)
#        any right quadrant   -> the scene is discarded
#      A fixed water patch and a randomly placed "other" patch (vegetation +
#      buildings) are also stored (train_water / train_other) for kept scenes.

# Initial settings
import os
import numpy as np
import matplotlib.pyplot as plt
import ee
import pdb
import time
import pandas as pd
# other modules
from osgeo import gdal, ogr, osr
import pickle
import matplotlib.cm as cm
from pylab import ginput
# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology
from scipy import ndimage
import random
# machine learning modules
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler, Normalizer
try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib was removed in scikit-learn 0.23; the standalone
    # joblib package (installed as a scikit-learn dependency) replaces it.
    import joblib
# import own modules
import functions.utils as utils
import functions.sds as sds

# some settings
np.seterr(all='ignore') # raise/ignore divisions by 0 and nans
plt.rcParams['axes.grid'] = True
plt.rcParams['figure.max_open_warning'] = 100
ee.Initialize()

# parameters
cloud_thresh = 0.3          # threshold for cloud cover
plot_bool = False           # if you want the plots
prob_high = 100             # upper probability to clip and rescale pixel intensity
min_contour_points = 100    # minimum number of points contained in each water line
output_epsg = 28356         # GDA94 / MGA Zone 56
buffer_size = 10            # radius (in pixels) of disk for buffer (pixel classification)
min_beach_size = 50         # number of pixels in a beach (pixel classification)

# load metadata (timestamps and epsg code) for the collection
satname = 'L8'
#sitename = 'NARRA_all'
sitename = 'NARRA'
#sitename = 'OLDBAR'

# Load metadata
filepath = os.path.join(os.getcwd(), 'data', satname, sitename)
# path to images
file_path_pan = os.path.join(filepath, 'pan')
file_path_ms = os.path.join(filepath, 'ms')
# os.listdir returns entries in arbitrary order. pan and ms files are paired by
# index below, so sort both listings to keep the i-th pan and i-th ms file on
# the same scene (assumes both directories use the same naming scheme so they
# sort identically -- TODO confirm).
file_names_pan = sorted(os.listdir(file_path_pan))
file_names_ms = sorted(os.listdir(file_path_ms))
N = len(file_names_pan)

# Filenames are assumed to start with '<satname>_<sitename>_' followed by a
# 10-character acquisition date string -- TODO confirm naming convention.
# The same slice is used both to detect and to record already-labelled dates.
date_start = len(satname) + 1 + len(sitename) + 1

# initialise some variables
idx_skipped = []
idx_nocloud = []
n_features = 10
# each array starts with a single NaN placeholder row that is dropped after the loop
train_pos = np.nan*np.ones((1,n_features))    # sand examples
train_neg = np.nan*np.ones((1,n_features))    # swash examples
train_other = np.nan*np.ones((1,n_features))  # vegetation + buildings examples
train_water = np.nan*np.ones((1,n_features))  # water examples
# names of the n_features feature columns plus a trailing 'class' label column
columns = ('B','G','R','NIR','SWIR','Pan','WI','VI','BR', 'mWI', 'class')


def _read_bands(fn):
    """Open raster `fn` read-only with GDAL and return (dataset, stacked bands).

    GDAL band numbers are 1-based, so band k+1 is read for k = 0..RasterCount-1
    and the 2-D arrays are stacked along axis 2.
    """
    dataset = gdal.Open(fn, gdal.GA_ReadOnly)
    bands = [dataset.GetRasterBand(k + 1).ReadAsArray()
             for k in range(dataset.RasterCount)]
    return dataset, np.stack(bands, 2)


def _rect_mask(top_left, lenrow, lencol, shape):
    """Return a boolean array of `shape` that is True on the lenrow x lencol
    rectangle whose top-left pixel is (row, col) = top_left."""
    rows = np.arange(lenrow) + int(top_left[0])
    cols = np.arange(lencol) + int(top_left[1])
    rr, cc = np.meshgrid(rows, cols, indexing='ij')
    mask = np.zeros(shape, dtype=bool)
    mask[rr, cc] = True  # IndexError if the rectangle leaves the image
    return mask


#%%
date_acquired_ts = []
fig = plt.figure()
for i in range(N):
    # close the figure displayed for the previous scene
    plt.close(fig)

    # read pan image (single band)
    fn_pan = os.path.join(file_path_pan, file_names_pan[i])
    data, im_pan = _read_bands(fn_pan)
    georef = np.array(data.GetGeoTransform())
    im_pan = im_pan[:,:,0]
    nrow = im_pan.shape[0]
    ncol = im_pan.shape[1]

    # read ms image
    fn_ms = os.path.join(file_path_ms, file_names_ms[i])
    data, im_ms = _read_bands(fn_ms)

    # cloud mask, built from band index 5 of the ms stack (presumably the QA band)
    im_qa = im_ms[:,:,5]
    cloud_mask = sds.create_cloud_mask(im_qa, satname, plot_bool)
    # nearest-neighbour (order 0) upsampling to the pan grid keeps the mask binary
    cloud_mask = transform.resize(cloud_mask, (nrow, ncol), order=0, preserve_range=True,
                                  mode='constant').astype('bool_')
    # resize the image using bilinear interpolation (order 1)
    im_ms = transform.resize(im_ms, (nrow, ncol), order=1, preserve_range=True,
                             mode='constant')
    # check if -inf or nan values and add to cloud mask
    im_inf = np.isin(im_ms[:,:,0], -np.inf)
    im_nan = np.isnan(im_ms[:,:,0])
    cloud_mask = np.logical_or(np.logical_or(cloud_mask, im_inf), im_nan)

    # skip if cloud cover is more than the threshold
    cloud_cover = np.mean(cloud_mask)  # fraction of masked pixels
    if cloud_cover > cloud_thresh:
        print('skip ' + str(i) + ' - cloudy (' + str(cloud_cover) + ')')
        idx_skipped.append(i)
        continue
    idx_nocloud.append(i)

    # skip scenes whose acquisition date has already been labelled
    date_acquired = file_names_pan[i][date_start:date_start + 10]
    if date_acquired in date_acquired_ts:
        idx_skipped.append(i)
        continue

    # pansharpen rgb image
    im_ms_ps = sds.pansharpen(im_ms[:,:,[0,1,2]], im_pan, cloud_mask, plot_bool)
    # add down-sized bands for NIR and SWIR (since pansharpening is not possible)
    im_ms_ps = np.append(im_ms_ps, im_ms[:,:,[3,4]], axis=2)

    # calculate NDWI
    im_ndwi = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask, plot_bool)
    # detect edges
    wl_pix = sds.find_wl_contours(im_ndwi, cloud_mask, min_contour_points, plot_bool)
    if not wl_pix:
        idx_skipped.append(i)
        continue

    # classify sand pixels with Kmeans
    im_sandKmeans = sds.classify_sand_unsupervised(im_ms_ps, im_pan, cloud_mask, wl_pix,
                                                   buffer_size, min_beach_size, plot_bool)

    # build the display image (channel order [2,1,0] of im_ms_ps) and paint the
    # sand pixels blue so the user can judge the classification
    im = np.copy(sds.rescale_image_intensity(im_ms_ps[:,:,[2,1,0]], cloud_mask,
                                             prob_high, False))
    im[im_sandKmeans, 0:3] = (0, 0, 1)

    # water pixels: fixed 20x20 patch, painted cyan
    pt_water = np.array([140, 70])
    im_water = _rect_mask(pt_water, 20, 20, (nrow, ncol))
    im[im_water, 0:3] = (0, 1, 1)

    # other pixels (vegetation + buildings): 40x15 patch at a random row, painted yellow
    pt_other = np.array([random.randint(150,235), 7])
    im_other = _rect_mask(pt_other, 40, 15, (nrow, ncol))
    im[im_other, 0:3] = (1, 1, 0)

    # plot image
    fig = plt.figure()
    plt.imshow(im)
    plt.axis('image')
    plt.title('Sand classification')
    plt.draw()
    mng = plt.get_current_fig_manager()
    try:
        mng.window.showMaximized()  # presumably Qt-specific
    except AttributeError:
        pass  # backend without showMaximized: keep the default window size
    plt.tight_layout()
    plt.draw()

    # click a point
    # top-left quadrant: sand
    # bottom-left quadrant: swash
    # any right quadrant (or no click before the timeout): discard image
    # ginput returns (x, y) = (column, row) pairs
    pt_in = np.array(ginput(n=1, timeout=1000))
    if len(pt_in) == 0 or pt_in[0][0] >= im_ms_ps.shape[1]/2:
        print('skip ' + str(i))
        idx_skipped.append(i)
        continue

    # calculate features (one n_features-long vector per pixel)
    im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features))
    im_features[:,:,[0,1,2,3,4]] = im_ms_ps
    im_features[:,:,5] = im_pan
    im_features[:,:,6] = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask, False) # (NIR-G)
    im_features[:,:,7] = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,2], cloud_mask, False) # (NIR-R)
    im_features[:,:,8] = sds.nd_index(im_ms_ps[:,:,0], im_ms_ps[:,:,2], cloud_mask, False) # (B-R)
    im_features[:,:,9] = sds.nd_index(im_ms_ps[:,:,4], im_ms_ps[:,:,1], cloud_mask, False) # (SWIR-G)
#    win = np.ones((3,3))
#    im_features[:,:,9] = ndimage.generic_filter(im_features[:,:,5], np.std, footprint=win)
#    im_features[:,:,10] = ndimage.generic_filter(im_features[:,:,5], np.max, footprint=win) - ndimage.generic_filter(im_features[:,:,5], np.min, footprint=win)

    # fill training data
    train_other = np.append(train_other, im_features[im_other,:], axis=0)
    train_water = np.append(train_water, im_features[im_water,:], axis=0)
    if pt_in[0][1] < im_ms_ps.shape[0]/2:
        # top half clicked: sand examples
        train_pos = np.append(train_pos, im_features[im_sandKmeans,:], axis=0)
    else:
        # bottom half clicked: swash examples
        train_neg = np.append(train_neg, im_features[im_sandKmeans,:], axis=0)

    # record the date with the same slice used for the duplicate check above
    date_acquired_ts.append(date_acquired)

# format data: drop the NaN placeholder row
train_pos = train_pos[1:,:]
train_neg = train_neg[1:,:]
train_water = train_water[1:,:]
train_other = train_other[1:,:]

# save data
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_sand.pkl'), 'wb') as f:
#    pickle.dump(train_pos, f)
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_swash.pkl'), 'wb') as f:
#    pickle.dump(train_neg, f)
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_water.pkl'), 'wb') as f:
#    pickle.dump(train_water, f)
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_other.pkl'), 'wb') as f:
#    pickle.dump(train_other, f)