# Source file: geetools_VH/sand_runNN.py
# -*- coding: utf-8 -*-
#==========================================================#
# Run Neural Network on image to extract sandy pixels
#==========================================================#
# Initial settings
import os
import numpy as np
import matplotlib.pyplot as plt
import ee
import pdb
import time
import pandas as pd
# other modules
from osgeo import gdal, ogr, osr
import pickle
import matplotlib.cm as cm
from pylab import ginput
# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology
from scipy import ndimage
# machine learning modules
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler, Normalizer
from sklearn.externals import joblib
# import own modules
import functions.utils as utils
import functions.sds as sds
# ----------------------------------------------------------------------
# Global settings, user parameters and one-off initialisation.
# Everything below runs at import time: it configures numpy/matplotlib,
# initialises the Earth Engine client, lists the input images on disk and
# loads the pre-trained classifier used to label the pixels.
# ----------------------------------------------------------------------

# some settings
np.seterr(all='ignore')  # take no action on floating-point errors (divisions by 0, NaNs)
plt.rcParams['axes.grid'] = True
plt.rcParams['figure.max_open_warning'] = 100
ee.Initialize()

# parameters
cloud_thresh = 0.3          # threshold for cloud cover (fraction of masked pixels)
plot_bool = False           # if you want the plots
prob_high = 100             # upper probability to clip and rescale pixel intensity
min_contour_points = 100    # minimum number of points contained in each water line
output_epsg = 28356         # GDA94 / MGA Zone 56
buffer_size = 10            # radius (in pixels) of disk for buffer (pixel classification)
min_beach_size = 50         # number of pixels in a beach (pixel classification)

# satellite and site to process (alternative sites kept for convenience)
satname = 'L8'
#sitename = 'NARRA_all'
#sitename = 'NARRA'
#sitename = 'OLDBAR'
sitename = 'OLDBAR_inlet'

# root folder of the collection: <cwd>/data/<satname>/<sitename>
filepath = os.path.join(os.getcwd(), 'data', satname, sitename)

# path to images (panchromatic and multispectral sub-folders)
file_path_pan = os.path.join(filepath, 'pan')
file_path_ms = os.path.join(filepath, 'ms')

# os.listdir() returns entries in arbitrary order, but the pan and ms files
# are paired by position further down. Sort both listings so the pairing is
# deterministic.
# NOTE(review): assumes matching pan/ms files sort to the same position in
# their respective folders - confirm against the file naming convention.
file_names_pan = sorted(os.listdir(file_path_pan))
file_names_ms = sorted(os.listdir(file_path_ms))
N = len(file_names_pan)

# initialise some variables
idx_skipped = []            # indices of images skipped because of cloud cover
idx_nocloud = []            # indices of images that were processed
n_features = 10             # number of features fed to the classifier per pixel
# NOTE(review): train_pos, train_neg and columns look like leftovers from a
# training script; they are defined here but may be unused - confirm.
train_pos = np.full((1, n_features), np.nan)
train_neg = np.full((1, n_features), np.nan)
columns = ('B','G','R','NIR','SWIR','Pan','WI','VI','BR', 'mWI', 'class')

# load the pre-trained classifier.
# NOTE(review): joblib.load unpickles the file, so only load model files
# from a trusted source. Also, `joblib` is imported from sklearn.externals
# at the top of this file; that alias was deprecated in scikit-learn 0.21
# and removed in 0.23, so newer environments need `import joblib` instead.
clf = joblib.load(os.path.join(os.getcwd(), 'sand_classification', 'NN1.pkl'))
#%%
# Main loop: for every image of the collection
#   1. read the panchromatic and multispectral rasters,
#   2. build a cloud / invalid-pixel mask and skip images that are too cloudy,
#   3. pansharpen the visible bands and assemble the per-pixel feature vector,
#   4. classify the valid pixels with the pre-trained network `clf`,
#   5. plot the predicted labels next to the RGB image.
for i in range(N):

    # read pan image (only the first band is used)
    fn_pan = os.path.join(file_path_pan, file_names_pan[i])
    data = gdal.Open(fn_pan, gdal.GA_ReadOnly)
    georef = np.array(data.GetGeoTransform())
    im_pan = data.GetRasterBand(1).ReadAsArray()
    nrow = im_pan.shape[0]
    ncol = im_pan.shape[1]

    # read ms image (all bands stacked along the third axis);
    # the band index gets its own name so it does not shadow the image index i
    fn_ms = os.path.join(file_path_ms, file_names_ms[i])
    data = gdal.Open(fn_ms, gdal.GA_ReadOnly)
    bands = [data.GetRasterBand(b + 1).ReadAsArray() for b in range(data.RasterCount)]
    im_ms = np.stack(bands, 2)

    # cloud mask, computed from band index 5 of the ms image
    # (presumably the quality-assessment band - confirm against the download script)
    im_qa = im_ms[:,:,5]
    cloud_mask = sds.create_cloud_mask(im_qa, satname, plot_bool)
    # bring the mask to the pan resolution with nearest-neighbour (order 0)
    cloud_mask = transform.resize(cloud_mask, (nrow, ncol),
                                  order=0, preserve_range=True,
                                  mode='constant').astype('bool_')

    # resize the image using bilinear interpolation (order 1)
    im_ms = transform.resize(im_ms, (nrow, ncol),
                             order=1, preserve_range=True, mode='constant')

    # check if -inf or nan values (in the first band) and add to cloud mask
    im_inf = np.isneginf(im_ms[:,:,0])
    im_nan = np.isnan(im_ms[:,:,0])
    cloud_mask = np.logical_or(np.logical_or(cloud_mask, im_inf), im_nan)

    # skip if cloud cover (fraction of masked pixels) is more than the threshold
    cloud_cover = np.count_nonzero(cloud_mask) / cloud_mask.size
    if cloud_cover > cloud_thresh:
        print('skip ' + str(i) + ' - cloudy (' + str(int(np.round(cloud_cover*100))) + '%)')
        idx_skipped.append(i)
        continue
    idx_nocloud.append(i)

    # pansharpen rgb image
    im_ms_ps = sds.pansharpen(im_ms[:,:,[0,1,2]], im_pan, cloud_mask, plot_bool)
    # add down-sized bands for NIR and SWIR (since pansharpening is not possible)
    im_ms_ps = np.append(im_ms_ps, im_ms[:,:,[3,4]], axis=2)

    # NOTE: an earlier unsupervised path (NDWI -> water-line contours ->
    # K-means via sds.classify_sand_unsupervised) was disabled here; the
    # neural network below is used instead.

    # calculate features: 5 bands, pan, and 4 normalised-difference indices
    im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features))
    im_features[:,:,:5] = im_ms_ps
    im_features[:,:,5] = im_pan
    im_features[:,:,6] = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask, False) # ND(NIR-G)
    im_features[:,:,7] = sds.nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,2], cloud_mask, False) # ND(NIR-R)
    im_features[:,:,8] = sds.nd_index(im_ms_ps[:,:,0], im_ms_ps[:,:,2], cloud_mask, False) # ND(B-R)
    im_features[:,:,9] = sds.nd_index(im_ms_ps[:,:,4], im_ms_ps[:,:,1], cloud_mask, False) # ND(SWIR-G)

    # flatten to (n_pixels, n_features) and remove NaNs and clouds
    vec = im_features.reshape((im_ms_ps.shape[0] * im_ms_ps.shape[1], n_features))
    vec_cloud = cloud_mask.reshape(cloud_mask.shape[0]*cloud_mask.shape[1])
    vec_nan = np.any(np.isnan(vec), axis=1)
    vec_mask = np.logical_or(vec_cloud, vec_nan)
    vec = vec[~vec_mask, :]

    # predict with NN (valid pixels only)
    y = clf.predict(vec)

    # recompose image: put the predictions back at the valid pixel positions
    # NOTE(review): masked pixels keep the initial value 0, so they are
    # indistinguishable from pixels predicted as class 0 and get painted
    # with colours[0] below - confirm this is intended.
    vec_new = np.zeros((cloud_mask.shape[0]*cloud_mask.shape[1]))
    vec_new[~vec_mask] = y
    im_classif = vec_new.reshape((im_ms_ps.shape[0], im_ms_ps.shape[1]))
    # NOTE: removal of small objects (morphology.remove_small_objects with
    # min_size=min_beach_size) is currently disabled.

    # plot NN labels: RGB image (bands reordered to R, G, B), rescaled for display
    im_display = sds.rescale_image_intensity(im_ms_ps[:,:,[2,1,0]], cloud_mask, prob_high, False)
    im = np.copy(im_display)
    # one RGB colour per class label 0..3 (red, yellow, cyan, blue)
    colours = np.array([[1,0,0],[1,1,0],[0,1,1],[0,0,1]])
    for k in range(colours.shape[0]):
        im[im_classif == k, :3] = colours[k]

    plt.figure()
    ax1 = plt.subplot(121)
    plt.imshow(im_display)
    plt.axis('off')
    plt.title('Image')
    # share the axes so that zooming/panning stays in sync between the panels
    ax2 = plt.subplot(122, sharex=ax1, sharey=ax1)
    plt.imshow(im)
    plt.axis('off')
    plt.title('NN')
    # maximise the window when the backend supports it; showMaximized() is a
    # Qt window method, so other backends are left at their default size
    mng = plt.get_current_fig_manager()
    window = getattr(mng, 'window', None)
    if hasattr(window, 'showMaximized'):
        window.showMaximized()
    plt.tight_layout()
    plt.draw()