You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
geetools_VH/sand_pixtrain_beach.py

249 lines
9.4 KiB
Python

# -*- coding: utf-8 -*-
#==========================================================#
# Create a training data
#==========================================================#
# Initial settings
import os
import numpy as np
import matplotlib.pyplot as plt
import ee
import pdb
import time
import pandas as pd
# other modules
from osgeo import gdal, ogr, osr
import pickle
import matplotlib.cm as cm
from pylab import ginput
# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology
from scipy import ndimage
import random
# machine learning modules
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler, Normalizer
from sklearn.externals import joblib
# import own modules
import functions.utils as utils
import functions.sds as sds
# some settings
# --- global runtime configuration ---
# ignore numpy floating-point errors (division by zero, overflow,
# underflow and invalid/nan results) instead of warning about them
np.seterr(all='ignore')
plt.rcParams.update({
    'axes.grid': True,                # draw a grid on every axes
    'figure.max_open_warning': 100,   # allow many open figures before warning
})
ee.Initialize()  # initialise the Earth Engine API client
# --- processing parameters ---
cloud_thresh = 0.3        # images with a larger cloud fraction are skipped
plot_bool = False         # True to show the intermediate plots
prob_high = 100           # upper probability used to clip and rescale pixel intensity
min_contour_points = 100  # minimum number of points in each water line
output_epsg = 28356       # GDA94 / MGA Zone 56
buffer_size = 10          # disk radius (pixels) of the buffer used in pixel classification
min_beach_size = 50       # minimum number of pixels in a beach (pixel classification)
# --- site selection ---
satname = 'L8'
#sitename = 'NARRA_all'
sitename = 'NARRA'
#sitename = 'OLDBAR'
# site folder and its pan / ms image sub-folders
filepath = os.path.join(os.getcwd(), 'data', satname, sitename)
file_path_pan = os.path.join(filepath, 'pan')
file_path_ms = os.path.join(filepath, 'ms')
# os.listdir returns entries in arbitrary order, but the pan and ms images
# are paired by list index in the main loop, so both lists are sorted.
# NOTE(review): this assumes a pan file and its ms counterpart sort to the
# same position (e.g. same date-based name prefix) - TODO confirm.
file_names_pan = sorted(os.listdir(file_path_pan))
file_names_ms = sorted(os.listdir(file_path_ms))
if len(file_names_ms) != len(file_names_pan):
    # index pairing is meaningless if the folders do not hold the same number of files
    raise ValueError('pan folder has %d files but ms folder has %d: cannot pair images'
                     % (len(file_names_pan), len(file_names_ms)))
N = len(file_names_pan)
# --- accumulators ---
idx_skipped = []   # indices of images that were skipped (cloudy, duplicate, no water line, discarded)
idx_nocloud = []   # indices of images that passed the cloud-cover test
n_features = 10
# Each training array is seeded with one all-NaN row so np.append(..., axis=0)
# always has a (1, n_features) array to grow; the seed row is removed after the loop.
train_pos = np.full((1, n_features), np.nan)
train_neg = np.full((1, n_features), np.nan)
train_other = np.full((1, n_features), np.nan)
train_water = np.full((1, n_features), np.nan)
# names of the feature columns, plus the class label
columns = ('B','G','R','NIR','SWIR','Pan','WI','VI','BR', 'mWI', 'class')
#%%
def _read_raster(fn):
    """Read every band of the GDAL raster *fn*.

    Returns ``(image, geotransform)``: ``image`` stacks the bands along
    axis 2, giving shape (rows, cols, n_bands), and ``geotransform`` is the
    dataset's GetGeoTransform() as a numpy array.

    Raises IOError if GDAL cannot open the file (gdal.Open returns None).
    """
    dataset = gdal.Open(fn, gdal.GA_ReadOnly)
    if dataset is None:
        raise IOError('GDAL could not open ' + fn)
    # `b` (not `i`) so the band index is never confused with the image index
    bands = [dataset.GetRasterBand(b + 1).ReadAsArray()
             for b in range(dataset.RasterCount)]
    geotransform = np.array(dataset.GetGeoTransform())
    dataset = None  # drop the reference to release the GDAL dataset
    return np.stack(bands, 2), geotransform


def _rect_mask(shape, row0, col0, lenrow, lencol):
    """Return a boolean mask of *shape*, True on the lenrow x lencol rectangle
    whose top-left pixel is (row0, col0); parts outside the image are clipped."""
    mask = np.zeros(shape, dtype=bool)
    mask[row0:row0 + lenrow, col0:col0 + lencol] = True
    return mask


# Image filenames are assumed to look like '<satname>_<sitename>_<date>...',
# with a 10-character date right after the two prefixes (inferred from the
# slicing used here - TODO confirm). The SAME slice is used to detect
# duplicate dates and to record accepted ones, so de-duplication works for
# any sitename, not only for one with a fixed prefix length.
date_offset = len(satname) + 1 + len(sitename) + 1

date_acquired_ts = []  # dates of the images already added to the training set
fig = None
for i in range(N):
    if fig is not None:
        plt.close(fig)  # close the previous image's figure
    # read the pan image (only its first band is used)
    fn_pan = os.path.join(file_path_pan, file_names_pan[i])
    im_pan_bands, georef = _read_raster(fn_pan)
    im_pan = im_pan_bands[:, :, 0]
    nrow, ncol = im_pan.shape
    # read the ms image
    fn_ms = os.path.join(file_path_ms, file_names_ms[i])
    im_ms, _ = _read_raster(fn_ms)
    # cloud mask from band index 5 of the ms image, upsampled to pan size
    # with nearest-neighbour interpolation (order 0)
    im_qa = im_ms[:, :, 5]
    cloud_mask = sds.create_cloud_mask(im_qa, satname, plot_bool)
    cloud_mask = transform.resize(cloud_mask, (nrow, ncol), order=0,
                                  preserve_range=True,
                                  mode='constant').astype('bool_')
    # resize the ms image to pan size using bilinear interpolation (order 1)
    im_ms = transform.resize(im_ms, (nrow, ncol), order=1,
                             preserve_range=True, mode='constant')
    # pixels that are -inf or nan in band 0 are treated as cloud too
    im_inf = np.isneginf(im_ms[:, :, 0])
    im_nan = np.isnan(im_ms[:, :, 0])
    cloud_mask = cloud_mask | im_inf | im_nan
    # skip if the cloud fraction exceeds the threshold
    # (the mean of a boolean mask is the fraction of True pixels, as a float)
    cloud_cover = np.mean(cloud_mask)
    if cloud_cover > cloud_thresh:
        print('skip ' + str(i) + ' - cloudy (' + str(cloud_cover) + ')')
        idx_skipped.append(i)
        continue
    idx_nocloud.append(i)
    # skip images whose acquisition date has already been used
    date_str = file_names_pan[i][date_offset:date_offset + 10]
    if date_str in date_acquired_ts:
        idx_skipped.append(i)
        continue
    # pansharpen the first three bands with the pan band
    im_ms_ps = sds.pansharpen(im_ms[:, :, [0, 1, 2]], im_pan, cloud_mask, plot_bool)
    # append the resized NIR and SWIR bands (they cannot be pansharpened)
    im_ms_ps = np.append(im_ms_ps, im_ms[:, :, [3, 4]], axis=2)
    # water index from NIR (3) and G (1)
    im_ndwi = sds.nd_index(im_ms_ps[:, :, 3], im_ms_ps[:, :, 1], cloud_mask, plot_bool)
    # detect water line contours; skip the image if none are found
    wl_pix = sds.find_wl_contours(im_ndwi, cloud_mask, min_contour_points, plot_bool)
    if not wl_pix:
        idx_skipped.append(i)
        continue
    # classify sand pixels with k-means
    im_sandKmeans = sds.classify_sand_unsupervised(im_ms_ps, im_pan, cloud_mask, wl_pix,
                                                   buffer_size, min_beach_size, plot_bool)
    # RGB preview (bands 2,1,0) with the k-means sand pixels painted blue
    im = np.copy(sds.rescale_image_intensity(im_ms_ps[:, :, [2, 1, 0]], cloud_mask,
                                             prob_high, False))
    im[im_sandKmeans, 0] = 0
    im[im_sandKmeans, 1] = 0
    im[im_sandKmeans, 2] = 1
    # water sample: fixed 20 x 20 pixel square at a hard-coded position, painted cyan
    pt_water = (140, 70)
    im_water = _rect_mask((nrow, ncol), pt_water[0], pt_water[1], 20, 20)
    im[im_water] = [0, 1, 1]
    # "other" sample (vegetation + buildings): 40 x 15 rectangle starting at a
    # random row and a fixed column, painted yellow
    pt_other = (random.randint(150, 235), 7)
    im_other = _rect_mask((nrow, ncol), pt_other[0], pt_other[1], 40, 15)
    im[im_other] = [1, 1, 0]
    # show the preview so the user can label the image
    fig = plt.figure()
    plt.imshow(im)
    plt.axis('image')
    plt.title('Sand classification')
    plt.draw()
    mng = plt.get_current_fig_manager()
    try:
        # showMaximized is a Qt window method; other backends may lack it
        mng.window.showMaximized()
    except AttributeError:
        pass  # keep the backend's default window size
    plt.tight_layout()
    plt.draw()
    # click a point:
    #   top-left quadrant    -> k-means pixels are sand
    #   bottom-left quadrant -> k-means pixels are swash
    #   any right quadrant   -> discard the image
    pt_in = np.array(ginput(n=1, timeout=1000))
    # no click before the timeout (empty result) is treated like a discard
    if len(pt_in) == 0 or not (pt_in[0][0] < im_ms_ps.shape[1] / 2):
        print('skip ' + str(i))
        idx_skipped.append(i)
        continue
    # per-pixel feature stack, in the order given by `columns`
    im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features))
    im_features[:, :, 0:5] = im_ms_ps
    im_features[:, :, 5] = im_pan
    im_features[:, :, 6] = sds.nd_index(im_ms_ps[:, :, 3], im_ms_ps[:, :, 1], cloud_mask, False)  # (NIR-G)
    im_features[:, :, 7] = sds.nd_index(im_ms_ps[:, :, 3], im_ms_ps[:, :, 2], cloud_mask, False)  # (NIR-R)
    im_features[:, :, 8] = sds.nd_index(im_ms_ps[:, :, 0], im_ms_ps[:, :, 2], cloud_mask, False)  # (B-R)
    im_features[:, :, 9] = sds.nd_index(im_ms_ps[:, :, 4], im_ms_ps[:, :, 1], cloud_mask, False)  # (SWIR-G)
    # add the water and "other" samples to the training data
    train_other = np.append(train_other, im_features[im_other, :], axis=0)
    train_water = np.append(train_water, im_features[im_water, :], axis=0)
    if pt_in[0][1] < im_ms_ps.shape[0] / 2:
        # top half: the k-means pixels are sand examples
        train_pos = np.append(train_pos, im_features[im_sandKmeans, :], axis=0)
    else:
        # bottom half: the k-means pixels are swash examples
        train_neg = np.append(train_neg, im_features[im_sandKmeans, :], axis=0)
    # record the date with the same slice used by the duplicate check above
    date_acquired_ts.append(date_str)
# Remove the all-NaN seed row at index 0 of every training array,
# keeping only the rows that were appended from real images.
train_pos, train_neg, train_water, train_other = [
    seeded[1:, :]
    for seeded in (train_pos, train_neg, train_water, train_other)
]
# Saving the training arrays is currently disabled; uncomment to write one
# pickle per class into ./sand_classification/.
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_sand.pkl'), 'wb') as f:
#    pickle.dump(train_pos, f)
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_swash.pkl'), 'wb') as f:
#    pickle.dump(train_neg, f)
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_water.pkl'), 'wb') as f:
#    pickle.dump(train_water, f)
#with open(os.path.join(os.getcwd(), 'sand_classification', sitename + '_other.pkl'), 'wb') as f:
#    pickle.dump(train_other, f)