"""
This module contains functions to label satellite images, use the labels to
train a pixel-wise classifier and evaluate the classifier
Author: Kilian Vos, Water Research Laboratory, University of New South Wales
"""
# load modules
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.widgets import LassoSelector
from matplotlib import path
import pickle
import pdb
import warnings
warnings.filterwarnings("ignore")
# image processing modules
from skimage.segmentation import flood
from skimage import morphology
from pylab import ginput
from sklearn.metrics import confusion_matrix
np.set_printoptions(precision=2)
# CoastSat modules
from coastsat import SDS_preprocess, SDS_shoreline, SDS_tools
class SelectFromImage(object):
"""
Class used to draw the lassos on the images with two methods:
- onselect: save the pixels inside the selection
- disconnect: stop drawing lassos on the image
"""
# initialize lasso selection class
def __init__(self, ax, implot, color=[1,1,1]):
self.canvas = ax.figure.canvas
self.implot = implot
self.array = implot.get_array()
xv, yv = np.meshgrid(np.arange(self.array.shape[1]),np.arange(self.array.shape[0]))
self.pix = np.vstack( (xv.flatten(), yv.flatten()) ).T
self.ind = []
self.im_bool = np.zeros((self.array.shape[0], self.array.shape[1]))
self.color = color
self.lasso = LassoSelector(ax, onselect=self.onselect)
def onselect(self, verts):
# find pixels contained in the lasso
p = path.Path(verts)
self.ind = p.contains_points(self.pix, radius=1)
# color selected pixels
array_list = []
for k in range(self.array.shape[2]):
array2d = self.array[:,:,k]
lin = np.arange(array2d.size)
new_array2d = array2d.flatten()
new_array2d[lin[self.ind]] = self.color[k]
array_list.append(new_array2d.reshape(array2d.shape))
self.array = np.stack(array_list,axis=2)
self.implot.set_data(self.array)
self.canvas.draw_idle()
# update boolean image with selected pixels
vec_bool = self.im_bool.flatten()
vec_bool[lin[self.ind]] = 1
self.im_bool = vec_bool.reshape(self.im_bool.shape)
def disconnect(self):
self.lasso.disconnect_events()
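
# Minimal usage sketch for SelectFromImage (illustrative only): attach the
# selector to a matplotlib axes showing an RGB image, draw lassos, then read
# the boolean mask of the selected pixels. The dummy image is a placeholder.
#
#   fig, ax = plt.subplots()
#   implot = ax.imshow(np.zeros((100, 100, 3)))        # any RGB image
#   selector = SelectFromImage(ax, implot, color=[0, 0, 1])
#   plt.waitforbuttonpress()                            # draw lassos, then press a key
#   selector.disconnect()
#   mask = selector.im_bool.astype(bool)                # True where pixels were lassoed
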
def label_images(metadata,settings):
"""
Load satellite images and interactively label different classes (hard-coded)
KV WRL 2019
Arguments:
-----------
metadata: dict
contains all the information about the satellite images that were downloaded
settings: dict with the following keys
'cloud_thresh': float
value between 0 and 1 indicating the maximum cloud fraction in
the cropped image that is accepted
'cloud_mask_issue': boolean
True if there is an issue with the cloud mask and sand pixels
are erroneously being masked on the images
'labels': dict
list of label names (key) and label numbers (value) for each class
'flood_fill': boolean
True to use the flood_fill functionality when labelling sand pixels
'tolerance': float
tolerance value for flood fill when labelling the sand pixels
'filepath_train': str
directory in which to save the labelled data
'inputs': dict
input parameters (sitename, filepath, polygon, dates, sat_list)
Returns:
-----------
Stores the labelled data in the specified directory
"""
filepath_train = settings['filepath_train']
# initialize figure
fig,ax = plt.subplots(1,1,figsize=[17,10], tight_layout=True,sharex=True,
sharey=True)
mng = plt.get_current_fig_manager()
mng.window.showMaximized()
# loop through satellites
for satname in metadata.keys():
filepath = SDS_tools.get_filepath(settings['inputs'],satname)
filenames = metadata[satname]['filenames']
# loop through images
for i in range(len(filenames)):
# image filename
fn = SDS_tools.get_filenames(filenames[i],filepath, satname)
# read and preprocess image
im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(fn, satname, settings['cloud_mask_issue'])
# calculate cloud cover
cloud_cover = np.divide(sum(sum(cloud_mask.astype(int))),
(cloud_mask.shape[0]*cloud_mask.shape[1]))
# skip image if cloud cover is above threshold
if cloud_cover > settings['cloud_thresh'] or cloud_cover == 1:
continue
            # compute the RGB, NDVI and NDWI images
im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9)
im_NDVI = SDS_tools.nd_index(im_ms[:,:,3], im_ms[:,:,2], cloud_mask)
im_NDWI = SDS_tools.nd_index(im_ms[:,:,3], im_ms[:,:,1], cloud_mask)
# initialise labels
im_viz = im_RGB.copy()
im_labels = np.zeros([im_RGB.shape[0],im_RGB.shape[1]])
# show RGB image
ax.axis('off')
ax.imshow(im_RGB)
implot = ax.imshow(im_viz, alpha=0.6)
filename = filenames[i][:filenames[i].find('.')][:-4]
ax.set_title(filename)
##############################################################
# select image to label
##############################################################
# set a key event to accept/reject the detections (see https://stackoverflow.com/a/15033071)
            # the key press is stored in a mutable dict so that the nested 'press' callback can update it and the value can be read after the event
key_event = {}
def press(event):
# store what key was pressed in the dictionary
key_event['pressed'] = event.key
# let the user press a key, right arrow to keep the image, left arrow to skip it
# to break the loop the user can press 'escape'
while True:
btn_keep = ax.text(1.1, 0.9, 'keep ⇨', size=12, ha="right", va="top",
transform=ax.transAxes,
bbox=dict(boxstyle="square", ec='k',fc='w'))
btn_skip = ax.text(-0.1, 0.9, '⇦ skip', size=12, ha="left", va="top",
transform=ax.transAxes,
bbox=dict(boxstyle="square", ec='k',fc='w'))
btn_esc = ax.text(0.5, 0, '<esc> to quit', size=12, ha="center", va="top",
transform=ax.transAxes,
bbox=dict(boxstyle="square", ec='k',fc='w'))
fig.canvas.draw_idle()
fig.canvas.mpl_connect('key_press_event', press)
plt.waitforbuttonpress()
# after button is pressed, remove the buttons
btn_skip.remove()
btn_keep.remove()
btn_esc.remove()
# keep/skip image according to the pressed key, 'escape' to break the loop
if key_event.get('pressed') == 'right':
skip_image = False
break
elif key_event.get('pressed') == 'left':
skip_image = True
break
elif key_event.get('pressed') == 'escape':
plt.close()
raise StopIteration('User cancelled labelling images')
else:
plt.waitforbuttonpress()
            # if the user decided to skip, show the next image
if skip_image:
ax.clear()
continue
# otherwise label this image
else:
##############################################################
# digitize sandy pixels
##############################################################
ax.set_title('Click on SAND pixels (flood fill activated, tolerance = %.2f)\nwhen finished press <Enter>'%settings['tolerance'])
                # create an erase button; clicking it deletes the last selection
btn_erase = ax.text(im_ms.shape[1], 0, 'Erase', size=20, ha='right', va='top',
bbox=dict(boxstyle="square", ec='k',fc='w'))
fig.canvas.draw_idle()
color_sand = settings['colors']['sand']
sand_pixels = []
while 1:
seed = ginput(n=1, timeout=0, show_clicks=True)
# if empty break the loop and go to next label
if len(seed) == 0:
break
else:
# round to pixel location
seed = np.round(seed[0]).astype(int)
# if user clicks on erase, delete the last selection
if seed[0] > 0.95*im_ms.shape[1] and seed[1] < 0.05*im_ms.shape[0]:
if len(sand_pixels) > 0:
im_labels[sand_pixels[-1]] = 0
for k in range(im_viz.shape[2]):
im_viz[sand_pixels[-1],k] = im_RGB[sand_pixels[-1],k]
implot.set_data(im_viz)
fig.canvas.draw_idle()
del sand_pixels[-1]
# otherwise label the selected sand pixels
else:
# flood fill the NDVI and the NDWI
fill_NDVI = flood(im_NDVI, (seed[1],seed[0]), tolerance=settings['tolerance'])
fill_NDWI = flood(im_NDWI, (seed[1],seed[0]), tolerance=settings['tolerance'])
# compute the intersection of the two masks
fill_sand = np.logical_and(fill_NDVI, fill_NDWI)
im_labels[fill_sand] = settings['labels']['sand']
sand_pixels.append(fill_sand)
# show the labelled pixels
for k in range(im_viz.shape[2]):
im_viz[im_labels==settings['labels']['sand'],k] = color_sand[k]
implot.set_data(im_viz)
fig.canvas.draw_idle()
##############################################################
# digitize white-water pixels
##############################################################
color_ww = settings['colors']['white-water']
ax.set_title('Click on individual WHITE-WATER pixels (no flood fill)\nwhen finished press <Enter>')
fig.canvas.draw_idle()
ww_pixels = []
while 1:
seed = ginput(n=1, timeout=0, show_clicks=True)
# if empty break the loop and go to next label
if len(seed) == 0:
break
else:
# round to pixel location
seed = np.round(seed[0]).astype(int)
# if user clicks on erase, delete the last labelled pixels
if seed[0] > 0.95*im_ms.shape[1] and seed[1] < 0.05*im_ms.shape[0]:
if len(ww_pixels) > 0:
im_labels[ww_pixels[-1][1],ww_pixels[-1][0]] = 0
for k in range(im_viz.shape[2]):
im_viz[ww_pixels[-1][1],ww_pixels[-1][0],k] = im_RGB[ww_pixels[-1][1],ww_pixels[-1][0],k]
implot.set_data(im_viz)
fig.canvas.draw_idle()
del ww_pixels[-1]
else:
im_labels[seed[1],seed[0]] = settings['labels']['white-water']
for k in range(im_viz.shape[2]):
im_viz[seed[1],seed[0],k] = color_ww[k]
implot.set_data(im_viz)
fig.canvas.draw_idle()
ww_pixels.append(seed)
im_sand_ww = im_viz.copy()
btn_erase.set(text='<Esc> to Erase', fontsize=12)
##############################################################
# digitize water pixels (with lassos)
##############################################################
color_water = settings['colors']['water']
ax.set_title('Click and hold to draw lassos and select WATER pixels\nwhen finished press <Enter>')
fig.canvas.draw_idle()
selector_water = SelectFromImage(ax, implot, color_water)
key_event = {}
while True:
fig.canvas.draw_idle()
fig.canvas.mpl_connect('key_press_event', press)
plt.waitforbuttonpress()
if key_event.get('pressed') == 'enter':
selector_water.disconnect()
break
elif key_event.get('pressed') == 'escape':
selector_water.array = im_sand_ww
implot.set_data(selector_water.array)
fig.canvas.draw_idle()
selector_water.implot = implot
selector_water.im_bool = np.zeros((selector_water.array.shape[0], selector_water.array.shape[1]))
selector_water.ind=[]
# update im_viz and im_labels
im_viz = selector_water.array
selector_water.im_bool = selector_water.im_bool.astype(bool)
im_labels[selector_water.im_bool] = settings['labels']['water']
im_sand_ww_water = im_viz.copy()
##############################################################
# digitize land pixels (with lassos)
##############################################################
color_land = settings['colors']['other land features']
ax.set_title('Click and hold to draw lassos and select OTHER LAND pixels\nwhen finished press <Enter>')
fig.canvas.draw_idle()
selector_land = SelectFromImage(ax, implot, color_land)
key_event = {}
while True:
fig.canvas.draw_idle()
fig.canvas.mpl_connect('key_press_event', press)
plt.waitforbuttonpress()
if key_event.get('pressed') == 'enter':
selector_land.disconnect()
break
elif key_event.get('pressed') == 'escape':
selector_land.array = im_sand_ww_water
implot.set_data(selector_land.array)
fig.canvas.draw_idle()
selector_land.implot = implot
selector_land.im_bool = np.zeros((selector_land.array.shape[0], selector_land.array.shape[1]))
selector_land.ind=[]
# update im_viz and im_labels
im_viz = selector_land.array
selector_land.im_bool = selector_land.im_bool.astype(bool)
im_labels[selector_land.im_bool] = settings['labels']['other land features']
# save labelled image
ax.set_title(filename)
fig.canvas.draw_idle()
fp = os.path.join(filepath_train,settings['inputs']['sitename'])
if not os.path.exists(fp):
os.makedirs(fp)
fig.savefig(os.path.join(fp,filename+'.jpg'), dpi=150)
ax.clear()
# save labels and features
features = dict([])
for key in settings['labels'].keys():
im_bool = im_labels == settings['labels'][key]
features[key] = SDS_shoreline.calculate_features(im_ms, cloud_mask, im_bool)
training_data = {'labels':im_labels, 'features':features, 'label_ids':settings['labels']}
with open(os.path.join(fp, filename + '.pkl'), 'wb') as f:
pickle.dump(training_data,f)
# close figure when finished
plt.close(fig)
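
# Minimal sketch of how label_images can be called (illustrative only): the
# 'metadata' and 'inputs' dicts are assumed to come from the CoastSat download
# step, and the numeric values and colours below are placeholder choices, not
# prescribed settings.
#
#   settings = {'inputs': inputs,
#               'cloud_thresh': 0.9,
#               'cloud_mask_issue': False,
#               'labels': {'sand': 1, 'white-water': 2, 'water': 3,
#                          'other land features': 0},
#               'colors': {'sand': [1, 0.65, 0], 'white-water': [1, 0, 1],
#                          'water': [0.1, 0.1, 0.7],
#                          'other land features': [0.8, 0.8, 0.1]},
#               'tolerance': 0.01,
#               'filepath_train': os.path.join(os.getcwd(), 'training_data')}
#   label_images(metadata, settings)
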
def load_labels(train_sites, settings):
"""
Load the labelled data from the different training sites
KV WRL 2019
Arguments:
-----------
train_sites: list of str
sites to be loaded
settings: dict with the following keys
'labels': dict
list of label names (key) and label numbers (value) for each class
'filepath_train': str
directory in which to save the labelled data
Returns:
-----------
features: dict
contains the features for each labelled pixel
"""
filepath_train = settings['filepath_train']
# initialize the features dict
features = dict([])
n_features = 20
first_row = np.nan*np.ones((1,n_features))
for key in settings['labels'].keys():
features[key] = first_row
# loop through each site
for site in train_sites:
sitename = site[:site.find('.')]
filepath = os.path.join(filepath_train,sitename)
if os.path.exists(filepath):
list_files = os.listdir(filepath)
else:
continue
# make a new list with only the .pkl files (no .jpg)
list_files_pkl = []
for file in list_files:
if '.pkl' in file:
list_files_pkl.append(file)
# load and append the training data to the features dict
for file in list_files_pkl:
# read file
with open(os.path.join(filepath, file), 'rb') as f:
labelled_data = pickle.load(f)
for key in labelled_data['features'].keys():
if len(labelled_data['features'][key])>0: # check that is not empty
# append rows
features[key] = np.append(features[key],
labelled_data['features'][key], axis=0)
# remove the first row (initialized with nans) and print how many pixels
print('Number of pixels per class in training data:')
for key in features.keys():
features[key] = features[key][1:,:]
print('%s : %d pixels'%(key,len(features[key])))
return features
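
# Minimal sketch of load_labels usage (illustrative only): the site names are
# placeholders and are expected to include a file extension, since the sitename
# is taken as everything before the first '.'.
#
#   train_sites = ['SITE1.kml', 'SITE2.kml']
#   features = load_labels(train_sites, settings)
#   features['sand'].shape        # (number of labelled sand pixels, 20 features)
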
def format_training_data(features, classes, labels):
"""
Format the labelled data in an X features matrix and a y labels vector, so
that it can be used for training an ML model.
KV WRL 2019
Arguments:
-----------
features: dict
contains the features for each labelled pixel
classes: list of str
names of the classes
labels: list of int
int value associated with each class (in the same order as classes)
Returns:
-----------
X: np.array
matrix features along the columns and pixels along the rows
y: np.array
vector with the labels corresponding to each row of X
"""
# initialize X and y
X = np.nan*np.ones((1,features[classes[0]].shape[1]))
y = np.nan*np.ones((1,1))
# append row of features to X and corresponding label to y
for i,key in enumerate(classes):
y = np.append(y, labels[i]*np.ones((features[key].shape[0],1)), axis=0)
X = np.append(X, features[key], axis=0)
# remove first row
X = X[1:,:]; y = y[1:]
    # replace NaNs with a value close to 0
    # (training algorithms cannot handle NaNs)
X[np.isnan(X)] = 1e-9
return X, y
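
# Minimal sketch of formatting the labelled data and fitting a classifier
# (illustrative only): the class order and label numbers must match
# settings['labels'], and the scikit-learn MLPClassifier hyper-parameters
# shown here are placeholder choices.
#
#   classes = ['sand', 'white-water', 'water', 'other land features']
#   labels = [1, 2, 3, 0]
#   X, y = format_training_data(features, classes, labels)
#   from sklearn.neural_network import MLPClassifier
#   classifier = MLPClassifier(hidden_layer_sizes=(100,), solver='adam')
#   classifier.fit(X, y.ravel())
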
def plot_confusion_matrix(y_true,y_pred,classes,normalize=False,cmap=plt.cm.Blues):
"""
Function copied from the scikit-learn examples (https://scikit-learn.org/stable/)
This function plots a confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
# compute confusion matrix
cm = confusion_matrix(y_true, y_pred)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
# plot confusion matrix
fig, ax = plt.subplots(figsize=(6,6), tight_layout=True)
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
# ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]), ylim=[3.5,-0.5],
xticklabels=classes, yticklabels=classes,
ylabel='True label',
xlabel='Predicted label')
# rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# loop over data dimensions and create text annotations.
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black",
fontsize=12)
fig.tight_layout()
return ax
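
# Minimal sketch of evaluating the classifier on a hold-out set (illustrative
# only). confusion_matrix orders rows/columns by sorted label value, so the
# class names passed here must follow that order; the 30% split is a
# placeholder choice.
#
#   from sklearn.model_selection import train_test_split
#   X_train, X_test, y_train, y_test = train_test_split(X, y.ravel(), test_size=0.3)
#   classifier.fit(X_train, y_train)
#   y_pred = classifier.predict(X_test)
#   plot_confusion_matrix(y_test, y_pred,
#                         classes=['other land features', 'sand', 'white-water', 'water'],
#                         normalize=True)
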
def evaluate_classifier(classifier, metadata, settings):
"""
Apply the image classifier to all the images and save the classified images.
KV WRL 2019
Arguments:
-----------
classifier: joblib object
classifier model to be used for image classification
metadata: dict
contains all the information about the satellite images that were downloaded
settings: dict with the following keys
'inputs': dict
input parameters (sitename, filepath, polygon, dates, sat_list)
'cloud_thresh': float
value between 0 and 1 indicating the maximum cloud fraction in
the cropped image that is accepted
'cloud_mask_issue': boolean
True if there is an issue with the cloud mask and sand pixels
are erroneously being masked on the images
'output_epsg': int
output spatial reference system as EPSG code
'buffer_size': int
size of the buffer (m) around the sandy pixels over which the pixels
are considered in the thresholding algorithm
'min_beach_area': int
minimum allowable object area (in metres^2) for the class 'sand',
the area is converted to number of connected pixels
'min_length_sl': int
minimum length (in metres) of shoreline contour to be valid
Returns:
-----------
    Saves .jpg images with the output of the classification in the folder ./evaluation
"""
# create folder called evaluation
fp = os.path.join(os.getcwd(), 'evaluation')
if not os.path.exists(fp):
os.makedirs(fp)
# initialize figure (not interactive)
plt.ioff()
fig,ax = plt.subplots(1,2,figsize=[17,10],sharex=True, sharey=True,
constrained_layout=True)
# create colormap for labels
cmap = cm.get_cmap('tab20c')
colorpalette = cmap(np.arange(0,13,1))
colours = np.zeros((3,4))
colours[0,:] = colorpalette[5]
colours[1,:] = np.array([204/255,1,1,1])
colours[2,:] = np.array([0,91/255,1,1])
# loop through satellites
for satname in metadata.keys():
filepath = SDS_tools.get_filepath(settings['inputs'],satname)
filenames = metadata[satname]['filenames']
        # set the pixel size for this satellite mission (15 m for Landsat, 10 m for Sentinel-2)
if satname in ['L5','L7','L8']:
pixel_size = 15
elif satname == 'S2':
pixel_size = 10
# convert settings['min_beach_area'] and settings['buffer_size'] from metres to pixels
buffer_size_pixels = np.ceil(settings['buffer_size']/pixel_size)
min_beach_area_pixels = np.ceil(settings['min_beach_area']/pixel_size**2)
# loop through images
for i in range(len(filenames)):
# image filename
fn = SDS_tools.get_filenames(filenames[i],filepath, satname)
# read and preprocess image
im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata = SDS_preprocess.preprocess_single(fn, satname, settings['cloud_mask_issue'])
image_epsg = metadata[satname]['epsg'][i]
# calculate cloud cover
cloud_cover = np.divide(sum(sum(cloud_mask.astype(int))),
(cloud_mask.shape[0]*cloud_mask.shape[1]))
# skip image if cloud cover is above threshold
if cloud_cover > settings['cloud_thresh']:
continue
# calculate a buffer around the reference shoreline (if any has been digitised)
im_ref_buffer = SDS_shoreline.create_shoreline_buffer(cloud_mask.shape, georef, image_epsg,
pixel_size, settings)
# classify image in 4 classes (sand, whitewater, water, other) with NN classifier
im_classif, im_labels = SDS_shoreline.classify_image_NN(im_ms, im_extra, cloud_mask,
min_beach_area_pixels, classifier)
            # there are two options to map the contours:
            # if enough pixels are labelled 'sand' --> use find_wl_contours2 (enhanced method)
            # otherwise --> use find_wl_contours1 (traditional MNDWI thresholding)
try: # use try/except structure for long runs
if sum(sum(im_labels[:,:,0])) < 10 :
# compute MNDWI image (SWIR-G)
im_mndwi = SDS_tools.nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask)
# find water contours on MNDWI grayscale image
contours_mwi = SDS_shoreline.find_wl_contours1(im_mndwi, cloud_mask, im_ref_buffer)
else:
# use classification to refine threshold and extract the sand/water interface
contours_wi, contours_mwi = SDS_shoreline.find_wl_contours2(im_ms, im_labels,
cloud_mask, buffer_size_pixels, im_ref_buffer)
except:
print('Could not map shoreline for this image: ' + filenames[i])
continue
# process the water contours into a shoreline
shoreline = SDS_shoreline.process_shoreline(contours_mwi, cloud_mask, georef, image_epsg, settings)
try:
sl_pix = SDS_tools.convert_world2pix(SDS_tools.convert_epsg(shoreline,
settings['output_epsg'],
image_epsg)[:,[0,1]], georef)
except:
# if try fails, just add nan into the shoreline vector so the next parts can still run
sl_pix = np.array([[np.nan, np.nan],[np.nan, np.nan]])
# make a plot
im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9)
# create classified image
im_class = np.copy(im_RGB)
for k in range(0,im_labels.shape[2]):
im_class[im_labels[:,:,k],0] = colours[k,0]
im_class[im_labels[:,:,k],1] = colours[k,1]
im_class[im_labels[:,:,k],2] = colours[k,2]
# show images
ax[0].imshow(im_RGB)
ax[1].imshow(im_RGB)
ax[1].imshow(im_class, alpha=0.5)
ax[0].axis('off')
ax[1].axis('off')
filename = filenames[i][:filenames[i].find('.')][:-4]
ax[0].set_title(filename)
ax[0].plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
ax[1].plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3)
# save figure
fig.savefig(os.path.join(fp,settings['inputs']['sitename'] + filename[:19] +'.jpg'), dpi=150)
# clear axes
for cax in fig.axes:
cax.clear()
# close the figure at the end
plt.close()
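
# Minimal sketch of running evaluate_classifier (illustrative only): the model
# file path, EPSG code and threshold values are placeholder assumptions and
# should be replaced with site-specific settings.
#
#   import joblib
#   classifier = joblib.load(os.path.join('classification', 'models', 'my_classifier.pkl'))
#   settings['output_epsg'] = 28356        # placeholder projected coordinate system
#   settings['buffer_size'] = 150          # metres around sandy pixels
#   settings['min_beach_area'] = 4500      # metres^2
#   settings['min_length_sl'] = 200        # metres
#   evaluate_classifier(classifier, metadata, settings)
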