# -*- coding: utf-8 -*-
#==========================================================#
# sand_create_train.py
# Create training data for the sand classifier
#==========================================================#
# Initial settings
import os
import numpy as np
import matplotlib.pyplot as plt
import ee
import pdb
import time
import pandas as pd
# other modules
from osgeo import gdal, ogr, osr
import pickle
import matplotlib.cm as cm
from pylab import ginput
# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology
from scipy import ndimage
# machine learning modules
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler, Normalizer
import joblib # sklearn.externals.joblib was removed in recent scikit-learn; use the standalone joblib package
# import own modules
import functions.utils as utils
import functions.sds as sds
# some settings
np.seterr(all='ignore') # ignore divisions by 0 and NaNs (set to 'raise' to debug)
plt.rcParams['axes.grid'] = True
plt.rcParams['figure.max_open_warning'] = 100
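# (the manual labelling loop below opens one figure per image, hence the raised limit)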
ee.Initialize()
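# NOTE: ee.Initialize() assumes Earth Engine credentials are already configured
# (run 'earthengine authenticate' once beforehand)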
# parameters
cloud_thresh = 0.5 # threshold for cloud cover
plot_bool = False # if you want the plots
prob_high = 99.9 # upper probability to clip and rescale pixel intensity
min_contour_points = 100 # minimum number of points contained in each water line
output_epsg = 28356 # GDA94 / MGA Zone 56
buffer_size = 10 # radius (in pixels) of disk for buffer (pixel classification)
min_beach_size = 50 # number of pixels in a beach (pixel classification)
# satellite and site of the image collection
satname = 'L8'
sitename = 'NARRA_all'
#sitename = 'NARRA'
#sitename = 'OLDBAR'
# path to the data folder (used to load/save the training data)
filepath = os.path.join(os.getcwd(), 'data', satname, sitename)
# paths to the images
file_path_pan = os.path.join(os.getcwd(), 'data', satname, sitename, 'pan')
file_path_ms = os.path.join(os.getcwd(), 'data', satname, sitename, 'ms')
file_names_pan = sorted(os.listdir(file_path_pan)) # sort so pan and ms files pair up by index
file_names_ms = sorted(os.listdir(file_path_ms))
N = len(file_names_pan)
# initialise some variables
idx_skipped = []
idx_nocloud = []
n_features = 10
train_pos = np.nan*np.ones((1,n_features))
train_neg = np.nan*np.ones((1,n_features))
columns = ('B','G','R','NIR','SWIR','Pan','WI','VI','BR', 'mWI', 'SAND')
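# train_pos/train_neg start as a single NaN seed row so np.append can grow them;
# the seed row is dropped after the loop. The 10 features are the 5 spectral
# bands, the pan band and 4 normalised-difference indices; 'SAND' is the label.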
#%%
for i in range(N):
    # read pan image
    fn_pan = os.path.join(file_path_pan, file_names_pan[i])
    data = gdal.Open(fn_pan, gdal.GA_ReadOnly)
    georef = np.array(data.GetGeoTransform())
    bands = [data.GetRasterBand(k + 1).ReadAsArray() for k in range(data.RasterCount)]
    im_pan = np.stack(bands, 2)[:,:,0]
    # read ms image
    fn_ms = os.path.join(file_path_ms, file_names_ms[i])
    data = gdal.Open(fn_ms, gdal.GA_ReadOnly)
    bands = [data.GetRasterBand(k + 1).ReadAsArray() for k in range(data.RasterCount)]
    im_ms = np.stack(bands, 2)
    # cloud mask
    im_qa = im_ms[:,:,5]
    cloud_mask = sds.create_cloud_mask(im_qa, satname, plot_bool)
    cloud_mask = transform.resize(cloud_mask, (im_pan.shape[0], im_pan.shape[1]),
                                  order=0, preserve_range=True,
                                  mode='constant').astype('bool_')
    # resize the image using bilinear interpolation (order 1)
    im_ms = transform.resize(im_ms, (im_pan.shape[0], im_pan.shape[1]),
                             order=1, preserve_range=True, mode='constant')
    # check for -inf or nan values and add them to the cloud mask
    im_inf = np.isin(im_ms[:,:,0], -np.inf)
    im_nan = np.isnan(im_ms[:,:,0])
    cloud_mask = np.logical_or(np.logical_or(cloud_mask, im_inf), im_nan)
    # skip if cloud cover is more than the threshold
    cloud_cover = sum(sum(cloud_mask.astype(int)))/(cloud_mask.shape[0]*cloud_mask.shape[1])
    if cloud_cover > cloud_thresh:
        print('skip ' + str(i) + ' - cloudy (' + str(cloud_cover) + ')')
        idx_skipped.append(i)
        continue
    idx_nocloud.append(i)
    # rescale intensities
    im_ms = sds.rescale_image_intensity(im_ms, cloud_mask, prob_high, plot_bool)
    im_pan = sds.rescale_image_intensity(im_pan, cloud_mask, prob_high, plot_bool)
    # pansharpen rgb image
    im_ms_ps = sds.pansharpen(im_ms[:,:,[0,1,2]], im_pan, cloud_mask, plot_bool)
    nrow = im_ms_ps.shape[0]
    ncol = im_ms_ps.shape[1]
    # add down-sized bands for NIR and SWIR (since pansharpening is not possible)
    im_ms_ps = np.append(im_ms_ps, im_ms[:,:,[3,4]], axis=2)
    # calculate NDWI
    im_ndwi = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask, plot_bool)
    # detect edges
    wl_pix = sds.find_wl_contours(im_ndwi, cloud_mask, min_contour_points, plot_bool)
    # classify sand pixels with Kmeans
    im_sand = sds.classify_sand_unsupervised(im_ms_ps, im_pan, cloud_mask, wl_pix, buffer_size, min_beach_size, plot_bool)
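    # im_sand is presumably a boolean mask of sand pixels restricted to a
    # buffer_size-pixel buffer around the shoreline contours, with connected
    # regions smaller than min_beach_size discarded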
    # plot a figure to manually select which images to keep
    im = np.copy(im_ms_ps)
    # paint the detected sand pixels red (band order is B,G,R)
    im[im_sand,0] = 0
    im[im_sand,1] = 0
    im[im_sand,2] = 1
    plt.figure()
    plt.imshow(im[:,:,[2,1,0]])
    plt.axis('image')
    plt.title('Sand classification')
    plt.show()
    mng = plt.get_current_fig_manager()
    mng.window.showMaximized()
    plt.tight_layout()
    plt.draw()
    # click a point to decide how to use the image:
    # - top-left quadrant: keep the classification as positives, then click where the negatives are
    # - bottom-left quadrant: keep the classified pixels as negatives
    # - right half: discard the image
    pt_in = np.array(ginput(n=1, timeout=1000))
    if pt_in[0][0] < im_ms_ps.shape[1]/2:
        im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features))
        im_features[:,:,[0,1,2,3,4]] = im_ms_ps
        im_features[:,:,5] = im_pan
        im_features[:,:,6] = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask, False) # WI (NIR-G)
        im_features[:,:,7] = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,2], cloud_mask, False) # VI (NIR-R)
        im_features[:,:,8] = sds.nd_index(im_ms_ps[:,:,0], im_ms_ps[:,:,2], cloud_mask, False) # BR (B-R)
        im_features[:,:,9] = sds.nd_index(im_ms_ps[:,:,4], im_ms_ps[:,:,1], cloud_mask, False) # mWI (SWIR-G)
#        win = np.ones((3,3))
#        im_features[:,:,9] = ndimage.generic_filter(im_features[:,:,5], np.std, footprint=win)
#        im_features[:,:,10] = ndimage.generic_filter(im_features[:,:,5], np.max, footprint=win) - ndimage.generic_filter(im_features[:,:,5], np.min, footprint=win)
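        # nd_index presumably returns the normalised difference (b1-b2)/(b1+b2)
        # of the two input bands, i.e. the WI, VI, BR and mWI columns defined above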
        if pt_in[0][1] < im_ms_ps.shape[0]/2:
            # positive examples
            vec_pos = im_features[im_sand,:]
            train_pos = np.append(train_pos, vec_pos, axis=0)
            # click where the negative examples are
            pt_neg = np.round(np.array(ginput(n=1, timeout=1000))[0])
            # take a square patch of side sqrt(#sand pixels) so that the number
            # of negatives roughly matches the number of positives
            radius = int(round(np.sqrt(sum(sum(im_sand)))))
            idx_rows = np.linspace(0,radius-1,radius).astype(int) + int(pt_neg[1])
            idx_cols = np.linspace(0,radius-1,radius).astype(int) + int(pt_neg[0])
            xx, yy = np.meshgrid(idx_rows, idx_cols, indexing='ij')
            row_neg = xx.reshape(radius*radius)
            col_neg = yy.reshape(radius*radius)
            im_nosand = np.zeros((nrow,ncol)).astype(bool)
            # use k to avoid shadowing the image index i of the outer loop
            for k in range(len(row_neg)):
                im_nosand[row_neg[k],col_neg[k]] = True
                # paint the negative patch cyan
                im_ms_ps[row_neg[k],col_neg[k],0] = 1
                im_ms_ps[row_neg[k],col_neg[k],1] = 1
                im_ms_ps[row_neg[k],col_neg[k],2] = 0
            plt.imshow(im_ms_ps[:,:,[2,1,0]])
            plt.draw()
            # negative examples
            vec_neg = im_features[im_nosand,:]
            train_neg = np.append(train_neg, vec_neg, axis=0)
        else:
            # bottom-left quadrant: the classified pixels are not sand, use them as negatives
            vec_neg = im_features[im_sand,:]
            train_neg = np.append(train_neg, vec_neg, axis=0)
    else:
        print('skip ' + str(i))
        idx_skipped.append(i)
# format data
train_pos = train_pos[1:,:] # drop the NaN seed row
train_neg = train_neg[1:,:]
n_pos = len(train_pos)
n_neg = len(train_neg)
training_data = np.zeros((n_pos+n_neg, n_features+1))
training_data[:n_pos,:n_features] = train_pos
training_data[n_pos:n_pos+n_neg,:n_features] = train_neg
training_data[:n_pos,n_features] = np.ones((n_pos)) # label: 1 = sand, 0 (from np.zeros) = not sand
df_train = pd.DataFrame(training_data, columns=columns)
df_train.dropna(axis=0, how='any', inplace=True) # drop any rows containing NaNs
sand_train = np.array(df_train)
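# optional sanity check on the class balance (the last column is the label)
print('training set: %d samples, %d labelled as sand' % (len(sand_train), int(sand_train[:,-1].sum())))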
# save data
#with open(os.path.join(filepath, sitename + '_sand_idxskip' + '.pkl'), 'wb') as f:
# pickle.dump(idx_skipped, f)
#with open(os.path.join(filepath, sitename + '_sand_train' + '.pkl'), 'wb') as f:
# pickle.dump(sand_train, f)
#df_train.to_csv('training_data.csv')
#%% Train neural network on data
# load training data
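# (assumes the pickle was written earlier by the commented-out save block above)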
with open(os.path.join(filepath, sitename + '_sand_train' + '.pkl'), 'rb') as f:
    sand_train = pickle.load(f)
n_features = sand_train.shape[1] - 1
X = sand_train[:,0:n_features]
y = sand_train[:,n_features]
# divide in train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
#scaler = StandardScaler()
#scaler.fit(X_train)
#X_train = scaler.transform(X_train)
#X_test = scaler.transform(X_test)
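# NOTE: MLPClassifier is sensitive to feature scaling; the scaler is presumably
# left commented out because the bands were already rescaled by rescale_image_intensity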
# train the NN on the training data and evaluate it on the test data
clf = MLPClassifier()
clf.fit(X_train,y_train)
print('test accuracy: ' + str(clf.score(X_test, y_test)))
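# a minimal sketch of a fuller evaluation, assuming sklearn.metrics is available:
from sklearn.metrics import classification_report
print(classification_report(y_test, clf.predict(X_test)))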
# save NN model
joblib.dump(clf, os.path.join(os.getcwd(), 'sand_classification', 'NN_small.pkl'))
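# to reuse the trained model later (sketch):
# clf = joblib.load(os.path.join(os.getcwd(), 'sand_classification', 'NN_small.pkl'))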