You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

894 lines
38 KiB
Python

"""
This module contains all the functions needed to preprocess the satellite images
before the shorelines can be extracted. This includes creating a cloud mask and
pansharpening/downsampling the multispectral bands.
Author: Kilian Vos, Water Research Laboratory, University of New South Wales
"""
# load modules
import os
import numpy as np
import matplotlib.pyplot as plt
import pdb
# image processing modules
import skimage.transform as transform
import skimage.morphology as morphology
import sklearn.decomposition as decomposition
import skimage.exposure as exposure
# other modules
from osgeo import gdal
from pylab import ginput
import pickle
import geopandas as gpd
from shapely import geometry
# CoastSat modules
from coastsat import SDS_tools
# Ignore every numpy floating-point error category (divide-by-zero, overflow,
# underflow and invalid operations such as 0/0 -> NaN) instead of warning or raising.
# NOTE: this is not scoped to this module; it changes numpy's error-handling state
# for any other code that runs after this module is imported.
np.seterr(all='ignore')
def create_cloud_mask(im_QA, satname, cloud_mask_issue):
    """
    Creates a cloud mask using the information contained in the QA band.

    KV WRL 2018

    Arguments:
    -----------
    im_QA: np.array
        Image containing the QA band
    satname: string
        short name for the satellite: ```'L4', 'L5', 'L7', 'L8' or 'S2'```
    cloud_mask_issue: boolean
        True if there is an issue with the cloud mask and sand pixels are being
        erroneously masked on the images

    Returns:
    -----------
    cloud_mask : np.array
        boolean array with True if a pixel is cloudy and False otherwise

    Raises:
    -----------
    ValueError
        if satname is not one of the supported missions
    """
    # QA values that flag cloudy pixels (the bits allocated to cloud cover vary
    # depending on the satellite mission)
    if satname == 'L8':
        cloud_values = [2800, 2804, 2808, 2812, 6896, 6900, 6904, 6908]
    elif satname in ('L7', 'L5', 'L4'):
        cloud_values = [752, 756, 760, 764]
    elif satname == 'S2':
        cloud_values = [1024, 2048]  # 1024 = dense cloud, 2048 = cirrus clouds
    else:
        raise ValueError("satname must be one of 'L4', 'L5', 'L7', 'L8' or 'S2', got %r"
                         % (satname,))

    # find which pixels have bits corresponding to cloud values
    cloud_mask = np.isin(im_QA, cloud_values)

    # Remove cloud pixels that form very thin features. These are beach or swash pixels that are
    # erroneously identified as clouds by the CFMASK algorithm applied to the images by the USGS.
    # Only needed when the mask contains both cloudy and clear pixels.
    if cloud_mask.any() and not cloud_mask.all():
        # use the returned array: the `in_place` argument is deprecated/removed in
        # recent scikit-image releases, while the return value works in all of them
        cloud_mask = morphology.remove_small_objects(cloud_mask, min_size=10, connectivity=1)

        if cloud_mask_issue:
            elem = morphology.square(3)  # use a square of width 3 pixels
            cloud_mask = morphology.binary_opening(cloud_mask, elem)  # perform image opening
            # remove objects with less than 25 connected pixels
            cloud_mask = morphology.remove_small_objects(cloud_mask, min_size=25, connectivity=1)

    return cloud_mask
def hist_match(source, template):
    """
    Remap the values of `source` so that its histogram matches that of `template`.

    Both inputs are flattened before their empirical distributions are compared;
    the result is reshaped back to the shape of `source`.

    Arguments:
    -----------
    source: np.array
        Image whose pixel values are transformed
    template: np.array
        Reference image providing the target distribution; its shape may
        differ from that of source

    Returns:
    -----------
    matched: np.array
        Array with the shape of source holding the remapped values
    """
    # unique source values, the inverse index that rebuilds the flat array from
    # them, and how many times each value occurs
    _, src_inverse, src_counts = np.unique(source.ravel(), return_inverse=True,
                                           return_counts=True)
    tmpl_vals, tmpl_counts = np.unique(template.ravel(), return_counts=True)

    def _ecdf(counts):
        # empirical cumulative distribution (pixel value -> quantile)
        cdf = np.cumsum(counts).astype(np.float64)
        return cdf / cdf[-1]

    # for every source quantile, linearly interpolate the template value that
    # sits at the same quantile
    mapped = np.interp(_ecdf(src_counts), _ecdf(tmpl_counts), tmpl_vals)
    return mapped[src_inverse].reshape(source.shape)
def pansharpen(im_ms, im_pan, cloud_mask):
    """
    Pansharpen a multispectral image with its panchromatic band (PCA substitution).

    A PCA is fitted to the clear (non-cloudy) multispectral pixels. The first
    principal component is then replaced by the panchromatic values, after the
    latter have been histogram-matched to that component, and the PCA is
    inverted. Matching the histograms before the substitution is essential.

    KV WRL 2018

    Arguments:
    -----------
    im_ms: np.array
        Multispectral image to pansharpen (3D)
    im_pan: np.array
        Panchromatic band (2D)
    cloud_mask: np.array
        2D cloud mask with True where cloud pixels are

    Returns:
    -----------
    im_ms_ps: np.ndarray
        Pansharpened multispectral image (3D); cloudy pixels are NaN
    """
    n_rows, n_cols, n_bands = im_ms.shape[0], im_ms.shape[1], im_ms.shape[2]
    n_pix = n_rows * n_cols
    is_clear = ~cloud_mask.reshape(n_pix)

    # principal components of the clear multispectral pixels
    pca = decomposition.PCA()
    components = pca.fit_transform(im_ms.reshape(n_pix, n_bands)[is_clear, :])

    # substitute PC1 with the pan band (histogram-matched to PC1) and invert the PCA
    pan_clear = im_pan.reshape(im_pan.shape[0] * im_pan.shape[1])[is_clear]
    components[:, 0] = hist_match(pan_clear, components[:, 0])
    sharpened = pca.inverse_transform(components)

    # scatter the sharpened pixels back onto the full grid (cloudy pixels stay NaN)
    full = np.full((n_pix, n_bands), np.nan)
    full[is_clear, :] = sharpened
    return full.reshape(n_rows, n_cols, n_bands)
def rescale_image_intensity(im, cloud_mask, prob_high):
    """
    Stretch the contrast of an image for visualisation purposes only.

    Each band is rescaled independently, using only the pixels outside the cloud
    mask, between 0 and its `prob_high` percentile. Cloudy pixels are set to NaN.

    KV WRL 2018

    Arguments:
    -----------
    im: np.array
        Image to rescale, can be 3D (multispectral) or 2D (single band)
    cloud_mask: np.array
        2D cloud mask with True where cloud pixels are
    prob_high: float
        probability of exceedence used to calculate the upper percentile

    Returns:
    -----------
    im_adj: np.array
        rescaled image, same shape as im
    """
    n_rows, n_cols = im.shape[0], im.shape[1]
    n_pix = n_rows * n_cols
    is_multiband = len(im.shape) > 2
    n_bands = im.shape[2] if is_multiband else 1

    # flatten the mask and the image (a single band is treated as one column)
    is_clear = ~cloud_mask.reshape(n_pix)
    columns = im.reshape(n_pix, n_bands)
    stretched = np.full((n_pix, n_bands), np.nan)

    for band in range(n_bands):
        clear_values = columns[is_clear, band]
        # the lower bound is fixed at 0; the upper bound is the prob_high percentile
        upper = np.percentile(clear_values, prob_high)
        stretched[is_clear, band] = exposure.rescale_intensity(clear_values,
                                                               in_range=(0, upper))

    if is_multiband:
        return stretched.reshape(n_rows, n_cols, n_bands)
    return stretched.reshape(n_rows, n_cols)
def preprocess_single(fn, satname, cloud_mask_issue):
    """
    Reads the image and outputs the pansharpened/resampled multispectral bands,
    the georeferencing vector of the image (coordinates of the upper left pixel),
    the cloud mask, the QA band and a no_data image.
    For Landsat 7-8 it also outputs the panchromatic band and for Sentinel-2 it
    also outputs the 20m SWIR band.

    KV WRL 2018

    Arguments:
    -----------
    fn: str or list of str
        filename of the .TIF file containing the image. For L7, L8 and S2 this
        is a list of filenames, one filename for each band at different
        resolution (30m and 15m for Landsat 7-8, 10m, 20m, 60m for Sentinel-2)
    satname: str
        name of the satellite mission: 'L5', 'L7', 'L8' or 'S2'
    cloud_mask_issue: boolean
        True if there is an issue with the cloud mask and sand pixels are being masked on the images

    Returns:
    -----------
    im_ms: np.array
        3D array containing the pansharpened/resampled bands (B,G,R,NIR,SWIR1)
    georef: np.array
        vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale] defining the
        coordinates of the top-left pixel of the image
    cloud_mask: np.array
        2D boolean mask, True for cloudy pixels and for the no-data pixels of im_nodata
    im_extra : np.array
        2D array containing the 20m resolution SWIR band for Sentinel-2 and the 15m resolution
        panchromatic band for Landsat 7 and Landsat 8. This field is empty for Landsat 5.
    im_QA: np.array
        2D array containing the QA band, from which the cloud_mask can be computed.
    im_nodata: np.array
        2D boolean array, True where any band is -inf or NaN, or where the Green,
        NIR and SWIR1 bands are all 0

    For an S2 image whose scaled 10m bands sum to less than 1 (empty image), the
    image is skipped and ([], [], <all-True cloud mask>, [], [], []) is returned.

    Raises:
    -----------
    ValueError
        if satname is not 'L5', 'L7', 'L8' or 'S2'
    """
    #=============================================================================================#
    # L5 images
    #=============================================================================================#
    if satname == 'L5':
        # read all bands
        im_ms, georef = _read_geotiff_bands(fn)
        # resample to 15 m (half of the original pixel size, i.e. twice the rows and columns)
        nrows = im_ms.shape[0] * 2
        ncols = im_ms.shape[1] * 2
        # create cloud mask from the QA band (band index 5), then drop the last band
        im_QA = im_ms[:, :, 5]
        im_ms = im_ms[:, :, :-1]
        cloud_mask = create_cloud_mask(im_QA, satname, cloud_mask_issue)
        # resize the image using bilinear interpolation (order 1)
        im_ms = transform.resize(im_ms, (nrows, ncols), order=1, preserve_range=True,
                                 mode='constant')
        cloud_mask = _resize_cloud_mask(cloud_mask, nrows, ncols)
        # adjust georeferencing vector to the new image size
        # scale becomes 15m and the origin is adjusted to the center of new top left pixel
        georef[1] = 15
        georef[5] = -15
        georef[0] = georef[0] + 7.5
        georef[3] = georef[3] - 7.5
        cloud_mask, im_nodata = _mask_nodata_pixels(im_ms, cloud_mask)
        # no extra image for Landsat 5 (they are all 30 m bands)
        im_extra = []

    #=============================================================================================#
    # L7 and L8 images
    #=============================================================================================#
    elif satname in ('L7', 'L8'):
        # read pan image; its size and georeferencing define the output grid
        im_pan, georef = _read_geotiff_bands(fn[0])
        im_pan = im_pan[:, :, 0]
        nrows, ncols = im_pan.shape[0], im_pan.shape[1]
        # read ms image
        im_ms, _ = _read_geotiff_bands(fn[1])
        # create cloud mask from the QA band (band index 5)
        im_QA = im_ms[:, :, 5]
        cloud_mask = create_cloud_mask(im_QA, satname, cloud_mask_issue)
        # keep the first 5 bands and resize them using bilinear interpolation (order 1)
        im_ms = transform.resize(im_ms[:, :, :5], (nrows, ncols), order=1,
                                 preserve_range=True, mode='constant')
        cloud_mask = _resize_cloud_mask(cloud_mask, nrows, ncols)
        cloud_mask, im_nodata = _mask_nodata_pixels(im_ms, cloud_mask)
        # bands overlapping with the pan band: Green, Red, NIR in L7; Blue, Green, Red in L8
        ps_bands = [1, 2, 3] if satname == 'L7' else [0, 1, 2]
        try:
            im_ms_ps = pansharpen(im_ms[:, :, ps_bands], im_pan, cloud_mask)
        except Exception:
            # if pansharpening fails, keep the resampled bands so long runs are not
            # interrupted (Exception, not a bare except, so Ctrl-C still stops the run)
            im_ms_ps = im_ms[:, :, ps_bands]
        if satname == 'L7':
            # resampled Blue + pansharpened Green, Red, NIR + resampled SWIR1
            im_ms = np.concatenate([im_ms[:, :, [0]], im_ms_ps, im_ms[:, :, [4]]], axis=2)
        else:
            # pansharpened Blue, Green, Red + resampled NIR and SWIR1
            im_ms = np.concatenate([im_ms_ps, im_ms[:, :, [3, 4]]], axis=2)
        # the extra image is the 15m panchromatic band
        im_extra = im_pan

    #=============================================================================================#
    # S2 images
    #=============================================================================================#
    elif satname == 'S2':
        # read 10m bands (B,G,R,NIR, matching the band order of im_ms)
        im10, georef = _read_geotiff_bands(fn[0])
        im10 = im10 / 10000  # TOA scaled to 10000
        # if the image is (nearly) all zeros (can happen with S2), skip it by
        # giving it a full cloud_mask
        if np.sum(im10) < 1:
            cloud_mask = np.ones((im10.shape[0], im10.shape[1])).astype('bool')
            return [], [], cloud_mask, [], [], []
        # size of 10m bands
        nrows, ncols = im10.shape[0], im10.shape[1]
        # read 20m band (SWIR1) and resize it using bilinear interpolation (order 1)
        im20, _ = _read_geotiff_bands(fn[1])
        im20 = im20[:, :, 0] / 10000  # TOA scaled to 10000
        im_swir = transform.resize(im20, (nrows, ncols), order=1, preserve_range=True,
                                   mode='constant')
        # append resampled SWIR1 band to the other 10m bands
        im_ms = np.append(im10, im_swir[:, :, np.newaxis], axis=2)
        # create cloud mask using 60m QA band (not as good as Landsat cloud cover)
        im60, _ = _read_geotiff_bands(fn[2])
        im_QA = im60[:, :, 0]
        cloud_mask = create_cloud_mask(im_QA, satname, cloud_mask_issue)
        cloud_mask = _resize_cloud_mask(cloud_mask, nrows, ncols)
        cloud_mask, im_nodata = _mask_nodata_pixels(im_ms, cloud_mask)
        # the extra image is the 20m SWIR band
        im_extra = im20

    else:
        raise ValueError("satname must be one of 'L5', 'L7', 'L8' or 'S2', got %r" % (satname,))

    return im_ms, georef, cloud_mask, im_extra, im_QA, im_nodata


def _read_geotiff_bands(filename):
    """Read every raster band of `filename` into a (rows, cols, bands) array and return it with the geotransform."""
    data = gdal.Open(filename, gdal.GA_ReadOnly)
    # gdal.Open returns None instead of raising when GDAL exceptions are not enabled
    if data is None:
        raise OSError('Could not open image file: ' + str(filename))
    georef = np.array(data.GetGeoTransform())
    bands = [data.GetRasterBand(k + 1).ReadAsArray() for k in range(data.RasterCount)]
    return np.stack(bands, 2), georef


def _resize_cloud_mask(cloud_mask, nrows, ncols):
    """Resize a 2D mask to (nrows, ncols) with nearest-neighbour interpolation (order 0), returned as bool."""
    return transform.resize(cloud_mask, (nrows, ncols), order=0, preserve_range=True,
                            mode='constant').astype(bool)


def _mask_nodata_pixels(im_ms, cloud_mask):
    """Flag unusable pixels of im_ms and add them to cloud_mask; returns (cloud_mask, im_nodata)."""
    # -inf or NaN values on any band
    im_nodata = np.any(np.isneginf(im_ms) | np.isnan(im_ms), axis=2)
    # pixels with 0 intensity in all of the Green, NIR and SWIR1 bands (indices 1, 3, 4),
    # as otherwise they will cause errors when calculating the NDWI and MNDWI
    im_zeros = np.all(im_ms[:, :, [1, 3, 4]] == 0, axis=2)
    im_nodata = np.logical_or(im_nodata, im_zeros)
    cloud_mask = np.logical_or(cloud_mask, im_nodata)
    return cloud_mask, im_nodata
def create_jpg(im_ms, cloud_mask, date, satname, filepath):
    """
    Saves a .jpg file with the RGB image of a preprocessed satellite image.
    This function can be modified to obtain different visualisations of the
    multispectral images.

    KV WRL 2018

    Arguments:
    -----------
    im_ms: np.array
        3D array containing the pansharpened/down-sampled bands (B,G,R,NIR,SWIR1)
    cloud_mask: np.array
        2D cloud mask with True where cloud pixels are
    date: str
        string containing the date at which the image was acquired
    satname: str
        name of the satellite mission (e.g., 'L5')
    filepath: str
        folder where the image is written, as '<date>_<satname>.jpg'

    Returns:
    -----------
    Saves a .jpg image corresponding to the preprocessed satellite image
    """
    # rescale image intensity for display purposes; bands [2,1,0] of the
    # (B,G,R,NIR,SWIR1) stack give an RGB image
    im_RGB = rescale_image_intensity(im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)

    # make figure (just RGB)
    fig = plt.figure(figsize=[18, 9], tight_layout=True)
    ax1 = fig.add_subplot(111)
    ax1.axis('off')
    ax1.imshow(im_RGB)
    ax1.set_title(date + ' ' + satname, fontsize=16)

    # save figure at maximum JPEG quality. The quality is passed to Pillow through
    # pil_kwargs because the 'savefig.jpeg_quality' rcParam no longer exists in
    # recent matplotlib releases (assigning it raises KeyError).
    fig.savefig(os.path.join(filepath, date + '_' + satname + '.jpg'), dpi=150,
                pil_kwargs={'quality': 100})
    # close this specific figure so figures do not accumulate when called in a loop
    plt.close(fig)
def save_jpg(metadata, settings, **kwargs):
    """
    Preprocess every image listed in metadata and save it as a .jpg.

    Images whose cloud cover exceeds settings['cloud_thresh'] (or that are
    entirely masked) are skipped. Files are written to
    <filepath>/<sitename>/jpg_files/preprocessed.

    KV WRL 2018

    Arguments:
    -----------
    metadata: dict
        contains all the information about the satellite images that were downloaded
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in
            the cropped image that is accepted
        'cloud_mask_issue': boolean
            True if there is an issue with the cloud mask and sand pixels
            are erroneously being masked on the images
    kwargs:
        accepted for backward compatibility; not used

    Returns:
    -----------
    Stores the images as .jpg in a folder named /preprocessed
    """
    inputs = settings['inputs']
    sitename = inputs['sitename']
    filepath_data = inputs['filepath']
    cloud_thresh = settings['cloud_thresh']

    # output subfolder for the jpg files
    filepath_jpg = os.path.join(filepath_data, sitename, 'jpg_files', 'preprocessed')
    os.makedirs(filepath_jpg, exist_ok=True)

    for satname in metadata:
        sat_folder = SDS_tools.get_filepath(inputs, satname)
        for image_name in metadata[satname]['filenames']:
            fn = SDS_tools.get_filenames(image_name, sat_folder, satname)
            outputs = preprocess_single(fn, satname, settings['cloud_mask_issue'])
            im_ms, cloud_mask = outputs[0], outputs[2]
            # fraction of masked pixels in the image
            cloud_cover = np.sum(cloud_mask) / (cloud_mask.shape[0] * cloud_mask.shape[1])
            if cloud_cover > cloud_thresh or cloud_cover == 1:
                continue
            # the first 19 characters of the filename hold the acquisition date
            date = image_name[:19]
            plt.ioff()  # turning interactive plotting off
            create_jpg(im_ms, cloud_mask, date, satname, filepath_jpg)

    # print the location where the images have been saved
    print('Satellite images saved as .jpg in ' + filepath_jpg)
def get_reference_sl(metadata, settings):
    """
    Allows the user to manually digitize a reference shoreline that is used to seed
    the shoreline detection algorithm. The reference shoreline helps to detect
    the outliers, making the shoreline detection more robust.

    If '<sitename>_reference_shoreline.pkl' already exists in <filepath>/<sitename>,
    it is loaded and returned without any user interaction.

    KV WRL 2018

    Arguments:
    -----------
    metadata: dict
        contains all the information about the satellite images that were downloaded
    settings: dict with the following keys
        'inputs': dict
            input parameters (sitename, filepath, polygon, dates, sat_list)
        'cloud_thresh': float
            value between 0 and 1 indicating the maximum cloud fraction in
            the cropped image that is accepted
        'cloud_mask_issue': boolean
            True if there is an issue with the cloud mask and sand pixels
            are erroneously being masked on the images
        'output_epsg': int
            output spatial reference system as EPSG code

    Returns:
    -----------
    reference_shoreline: np.array
        coordinates of the reference shoreline that was manually digitized.
        This is also saved as a .pkl and .geojson file.

    Raises:
    -----------
    Exception
        if metadata has no S2, L8 or L5 images, or if no image was digitised
    StopIteration
        if the user presses <esc> while choosing an image
    """
    sitename = settings['inputs']['sitename']
    filepath_data = settings['inputs']['filepath']
    pts_coords = []
    # check if reference shoreline already exists in the corresponding folder
    filepath = os.path.join(filepath_data, sitename)
    filename = sitename + '_reference_shoreline.pkl'
    # if it exists, load it and return it
    if filename in os.listdir(filepath):
        print('Reference shoreline already exists and was loaded')
        with open(os.path.join(filepath, filename), 'rb') as f:
            refsl = pickle.load(f)
        return refsl

    # otherwise get the user to manually digitise a shoreline, preferring S2 (10m res),
    # then L8 (15m res in the RGB with pansharpening), then L5. L7 is never used
    # because its black diagonal bands make it hard to digitize a shoreline.
    satname = next((sat for sat in ('S2', 'L8', 'L5') if sat in metadata), None)
    if satname is None:
        raise Exception('You cannot digitize the shoreline on L7 images (because of gaps in the images), add another L8, S2 or L5 to your dataset.')
    filepath = SDS_tools.get_filepath(settings['inputs'], satname)
    filenames = metadata[satname]['filenames']

    # create figure
    fig, ax = plt.subplots(1, 1, figsize=[18, 9], tight_layout=True)
    mng = plt.get_current_fig_manager()
    # window maximisation is backend-specific (showMaximized is a Qt window method);
    # on other backends keep the default window size rather than crashing
    try:
        mng.window.showMaximized()
    except AttributeError:
        pass

    # loop through the images
    for i, image_name in enumerate(filenames):
        # read image
        fn = SDS_tools.get_filenames(image_name, filepath, satname)
        im_ms, georef, cloud_mask = preprocess_single(fn, satname,
                                                      settings['cloud_mask_issue'])[:3]
        # calculate cloud cover (fraction of masked pixels)
        cloud_cover = np.sum(cloud_mask) / (cloud_mask.shape[0] * cloud_mask.shape[1])
        # skip image if cloud cover is above threshold; fully masked images
        # (e.g. skipped empty S2 images) cannot be displayed either
        if cloud_cover > settings['cloud_thresh'] or cloud_cover == 1:
            continue
        # rescale image intensity for display purposes
        im_RGB = rescale_image_intensity(im_ms[:, :, [2, 1, 0]], cloud_mask, 99.9)
        # plot the image RGB on a figure
        ax.axis('off')
        ax.imshow(im_RGB)
        # decide if the image is good enough for digitizing the shoreline
        ax.set_title('Press <right arrow> if image is clear enough to digitize the shoreline.\n' +
                     'If the image is cloudy press <left arrow> to get another image', fontsize=14)

        # record the last key pressed (see https://stackoverflow.com/a/15033071); a dict is
        # used so that the nested callback can update it
        key_event = {}

        def press(event):
            # store what key was pressed in the dictionary
            key_event['pressed'] = event.key

        cid = fig.canvas.mpl_connect('key_press_event', press)
        # let the user press a key: right arrow keeps the image, left arrow skips it,
        # 'escape' aborts; any other input simply shows the prompt again
        while True:
            btn_keep = plt.text(1.1, 0.9, 'keep ⇨', size=12, ha="right", va="top",
                                transform=ax.transAxes,
                                bbox=dict(boxstyle="square", ec='k', fc='w'))
            btn_skip = plt.text(-0.1, 0.9, '⇦ skip', size=12, ha="left", va="top",
                                transform=ax.transAxes,
                                bbox=dict(boxstyle="square", ec='k', fc='w'))
            btn_esc = plt.text(0.5, 0, '<esc> to quit', size=12, ha="center", va="top",
                               transform=ax.transAxes,
                               bbox=dict(boxstyle="square", ec='k', fc='w'))
            plt.draw()
            plt.waitforbuttonpress()
            # after button is pressed, remove the buttons
            btn_skip.remove()
            btn_keep.remove()
            btn_esc.remove()
            pressed = key_event.get('pressed')
            if pressed in ('right', 'left'):
                break
            if pressed == 'escape':
                plt.close()
                raise StopIteration('User cancelled checking shoreline detection')
        fig.canvas.mpl_disconnect(cid)

        if pressed == 'left':
            ax.clear()
            continue

        # create two new buttons
        add_button = plt.text(0, 0.9, 'add', size=16, ha="left", va="top",
                              transform=plt.gca().transAxes,
                              bbox=dict(boxstyle="square", ec='k', fc='w'))
        end_button = plt.text(1, 0.9, 'end', size=16, ha="right", va="top",
                              transform=plt.gca().transAxes,
                              bbox=dict(boxstyle="square", ec='k', fc='w'))
        # add multiple reference shorelines (until user clicks on <end> button)
        lines_world = []
        geoms = []
        while True:
            add_button.set_visible(False)
            end_button.set_visible(False)
            # update title (instructions)
            ax.set_title('Click points along the shoreline (enough points to capture the beach curvature).\n' +
                         'Start at one end of the beach.\n' + 'When finished digitizing, click <ENTER>',
                         fontsize=14)
            plt.draw()
            # let user click on the shoreline
            pts = ginput(n=50000, timeout=1e9, show_clicks=True)
            # a line needs at least two clicks; otherwise ask again instead of crashing
            if len(pts) < 2:
                continue
            pts_pix = np.array(pts)
            # convert pixel coordinates to world coordinates (the (x, y) click columns
            # are swapped before the conversion)
            pts_world = SDS_tools.convert_pix2world(pts_pix[:, [1, 0]], georef)
            # interpolate between points clicked by the user (1m resolution)
            pts_world_interp = _interpolate_shoreline(pts_world)
            if len(pts_world_interp) < 2:
                continue
            # save as geometry (to create .geojson file later)
            geoms.append(geometry.LineString(pts_world_interp))
            lines_world.append(pts_world_interp)
            # convert to pixel coordinates and plot
            pts_pix_interp = SDS_tools.convert_world2pix(pts_world_interp, georef)
            ax.plot(pts_pix_interp[:, 0], pts_pix_interp[:, 1], 'r--')
            ax.plot(pts_pix_interp[0, 0], pts_pix_interp[0, 1], 'ko')
            ax.plot(pts_pix_interp[-1, 0], pts_pix_interp[-1, 1], 'ko')
            # update title and buttons
            add_button.set_visible(True)
            end_button.set_visible(True)
            ax.set_title('click on <add> to digitize another shoreline or on <end> to finish and save the shoreline(s)',
                         fontsize=14)
            plt.draw()
            # let the user click again (<add> another shoreline or <end>)
            pt_input = np.array(ginput(n=1, timeout=1e9, show_clicks=False))
            # a click on the right half of the image is <end>: save the points and stop
            if len(pt_input) > 0 and pt_input[0][0] > im_ms.shape[1] / 2:
                add_button.set_visible(False)
                end_button.set_visible(False)
                plt.title('Reference shoreline saved as ' + sitename + '_reference_shoreline.pkl and ' + sitename + '_reference_shoreline.geojson')
                plt.draw()
                ginput(n=1, timeout=3, show_clicks=False)
                plt.close()
                break

        pts_sl = np.concatenate(lines_world, axis=0)
        # convert world image coordinates to user-defined coordinate system
        image_epsg = metadata[satname]['epsg'][i]
        pts_coords = SDS_tools.convert_epsg(pts_sl, image_epsg, settings['output_epsg'])
        # save the reference shoreline as .pkl
        filepath = os.path.join(filepath_data, sitename)
        with open(os.path.join(filepath, sitename + '_reference_shoreline.pkl'), 'wb') as f:
            pickle.dump(pts_coords, f)
        # also store as .geojson in case user wants to drag-and-drop on GIS for verification.
        # The GeoDataFrame is built in one call because DataFrame.append was removed in pandas 2.0.
        gdf_all = gpd.GeoDataFrame(
            {'name': ['reference shoreline ' + str(k + 1) for k in range(len(geoms))]},
            geometry=geoms, crs='EPSG:' + str(image_epsg))
        # convert from image_epsg to user-defined coordinate system
        gdf_all = gdf_all.to_crs(epsg=settings['output_epsg'])
        # save as geojson
        gdf_all.to_file(os.path.join(filepath, sitename + '_reference_shoreline.geojson'),
                        driver='GeoJSON', encoding='utf-8')
        print('Reference shoreline has been saved in ' + filepath)
        break

    # check if a shoreline was digitised
    if len(pts_coords) == 0:
        plt.close(fig)
        raise Exception('No cloud free images are available to digitise the reference shoreline, ' +
                        'download more images and try again')
    return pts_coords


def _interpolate_shoreline(pts_world):
    """Densify a polyline of world coordinates with points every 1 world unit (1 m) along each segment."""
    segments = []
    for start, end in zip(pts_world[:-1], pts_world[1:]):
        seg_length = np.linalg.norm(start - end)
        # points along the x-axis at unit spacing, then rotated/translated onto the segment
        xvals = np.arange(0, seg_length)
        pt_coords = np.column_stack([xvals, np.zeros(len(xvals))])
        deltax = end[0] - start[0]
        deltay = end[1] - start[1]
        # np.arctan2 instead of np.math.atan2 (np.math was removed in NumPy 2.0)
        phi = np.pi / 2 - np.arctan2(deltax, deltay)
        tf = transform.EuclideanTransform(rotation=phi, translation=start)
        segments.append(tf(pt_coords))
    if not segments:
        return np.empty((0, 2))
    return np.concatenate(segments, axis=0)