From a0b49c7dcfc2f20c622c528259ec4990093ef454 Mon Sep 17 00:00:00 2001 From: Kilian Vos Date: Tue, 27 Nov 2018 12:58:00 +1100 Subject: [PATCH] major updates --- .gitignore | 3 +- NARRA.kml | 62 +++ SDS_download.py | 497 +++++++++++++++++--- SDS_preprocess.py | 346 +++++++++----- SDS_shoreline.py | 575 ++++++++++++++---------- SDS_tools.py | 203 ++++++++- gdal_merge.py | 540 ++++++++++++++++++++++ main_test.py | 285 ++++++++++++ shoreline_extraction.ipynb | 2 +- main_spyder.py => test_spyder_simple.py | 63 +-- 10 files changed, 2119 insertions(+), 457 deletions(-) create mode 100644 NARRA.kml create mode 100644 gdal_merge.py create mode 100644 main_test.py rename main_spyder.py => test_spyder_simple.py (51%) diff --git a/.gitignore b/.gitignore index 8b83274..903a401 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ *.mp4 *.gif *.jpg -*.pkl \ No newline at end of file +*.pkl +*.xml \ No newline at end of file diff --git a/NARRA.kml b/NARRA.kml new file mode 100644 index 0000000..ab57b18 --- /dev/null +++ b/NARRA.kml @@ -0,0 +1,62 @@ + + + + NARRA + + + + + normal + #poly-000000-1200-77-nodesc-normal + + + highlight + #poly-000000-1200-77-nodesc-highlight + + + + Polygon 1 + #poly-000000-1200-77-nodesc + + + + 1 + + 151.2957545,-33.7012561,0 + 151.297557,-33.7388075,0 + 151.312234,-33.7390216,0 + 151.311204,-33.701399,0 + 151.2957545,-33.7012561,0 + + + + + + + diff --git a/SDS_download.py b/SDS_download.py index 138a865..a209f34 100644 --- a/SDS_download.py +++ b/SDS_download.py @@ -1,25 +1,37 @@ -"""This module contains all the functions needed to download the satellite images from GEE +"""This module contains all the functions needed to download the satellite images from the Google +Earth Engine Server Author: Kilian Vos, Water Research Laboratory, University of New South Wales """ -# Initial settings +# load modules import os import numpy as np import matplotlib.pyplot as plt import pdb + +# earth engine modules import ee from urllib.request import urlretrieve +import zipfile +import copy +import gdal_merge + +# additional modules from datetime import datetime import pytz import pickle -import zipfile +import skimage.morphology as morphology + +# own modules +import SDS_preprocess, SDS_tools + +np.seterr(all='ignore') # raise/ignore divisions by 0 and nans + # initialise connection with GEE server ee.Initialize() -# Functions - def download_tif(image, polygon, bandsId, filepath): """ Downloads a .TIF image from the ee server and stores it in a temp file @@ -49,34 +61,48 @@ def download_tif(image, polygon, bandsId, filepath): return local_zipfile.extract('data.tif', filepath) -def get_images(sitename,polygon,dates,sat): +def get_images(inputs): """ - Downloads all images from Landsat 5, Landsat 7, Landsat 8 and Sentinel-2 covering the given - polygon and acquired during the given dates. The images are organised in subfolders and divided - by satellite mission and pixel resolution. + Downloads all images from Landsat 5, Landsat 7, Landsat 8 and Sentinel-2 covering the area of + interest and acquired between the specified dates. + The downloaded images are in .TIF format and organised in subfolders, divided by satellite + mission and pixel resolution. 
KV WRL 2018 Arguments: ----------- - sitename: str + inputs: dict + dictionary that contains the following fields: + 'sitename': str String containing the name of the site - polygon: list + 'polygon': list polygon containing the lon/lat coordinates to be extracted longitudes in the first column and latitudes in the second column - dates: list of str + 'dates': list of str list that contains 2 strings with the initial and final dates in format 'yyyy-mm-dd' e.g. ['1987-01-01', '2018-01-01'] - sat: list of str + 'sat_list': list of str list that contains the names of the satellite missions to include e.g. ['L5', 'L7', 'L8', 'S2'] - + + Returns: + ----------- + metadata: dict + contains all the information about the satellite images that were downloaded + """ + # read inputs dictionary + sitename = inputs['sitename'] + polygon = inputs['polygon'] + dates = inputs['dates'] + sat_list = inputs['sat_list'] + # format in which the images are downloaded suffix = '.tif' - # initialise metadata dictionnary (stores timestamps and georefencing accuracy of each image) + # initialize metadata dictionary (stores timestamps and georeferencing accuracy of each image) metadata = dict([]) # create directories @@ -89,7 +115,7 @@ def get_images(sitename,polygon,dates,sat): # download L5 images #=============================================================================================# - if 'L5' in sat or 'Landsat5' in sat: + if 'L5' in sat_list or 'Landsat5' in sat_list: satname = 'L5' # create a subfolder to store L5 images @@ -105,19 +131,27 @@ def get_images(sitename,polygon,dates,sat): flt_col = input_col.filterBounds(ee.Geometry.Polygon(polygon)).filterDate(dates[0],dates[1]) # get all images in the filtered collection im_all = flt_col.getInfo().get('features') - # print how many images there are for the user - n_img = flt_col.size().getInfo() - print('Number of ' + satname + ' images covering ' + sitename + ':', n_img) + # remove very cloudy images (>95% cloud) + cloud_cover = [_['properties']['CLOUD_COVER'] for _ in im_all] + if np.any([_ > 95 for _ in cloud_cover]): + idx_delete = np.where([_ > 95 for _ in cloud_cover])[0] + im_all_cloud = [x for k,x in enumerate(im_all) if k not in idx_delete] + else: + im_all_cloud = im_all + n_img = len(im_all_cloud) + # print how many images there are + print('Number of ' + satname + ' images covering ' + sitename + ':', n_img) # loop through images timestamps = [] acc_georef = [] + filenames = [] all_names = [] im_epsg = [] for i in range(n_img): # find each image in ee database - im = ee.Image(im_all[i].get('id')) + im = ee.Image(im_all_cloud[i].get('id')) # read metadata im_dic = im.getInfo() # get bands @@ -136,32 +170,38 @@ def get_images(sitename,polygon,dates,sat): except: # default value of accuracy (RMSE = 12m) acc_georef.append(12) - print('No geometric rmse model property') # delete dimensions key from dictionary, otherwise the entire image is extracted for j in range(len(im_bands)): del im_bands[j]['dimensions'] # bands for L5 ms_bands = [im_bands[0], im_bands[1], im_bands[2], im_bands[3], im_bands[4], im_bands[7]] # filenames for the images filename = im_date + '_' + satname + '_' + sitename + suffix - # if two images taken at the same date add 'dup' in the name + # if two images taken at the same date add 'dup' in the name (duplicate) if any(filename in _ for _ in all_names): filename = im_date + '_' + satname + '_' + sitename + '_dup' + suffix all_names.append(filename) + filenames.append(filename) # download .TIF image local_data = download_tif(im,
polygon, ms_bands, filepath) # update filename - os.rename(local_data, os.path.join(filepath, filename)) - print(i, end='..') + try: + os.rename(local_data, os.path.join(filepath, filename)) + except: + os.remove(os.path.join(filepath, filename)) + os.rename(local_data, os.path.join(filepath, filename)) + + print(i+1, end='..') - # sort timestamps and georef accuracy (dowloaded images are sorted by date in directory) + # sort timestamps and georef accuracy (downloaded images are sorted by date in directory) timestamps_sorted = sorted(timestamps) idx_sorted = sorted(range(len(timestamps)), key=timestamps.__getitem__) acc_georef_sorted = [acc_georef[j] for j in idx_sorted] + filenames_sorted = [filenames[j] for j in idx_sorted] im_epsg_sorted = [im_epsg[j] for j in idx_sorted] # save into dict metadata[satname] = {'dates':timestamps_sorted, 'acc_georef':acc_georef_sorted, - 'epsg':im_epsg_sorted} - print('Finished with ' + satname) + 'epsg':im_epsg_sorted, 'filenames':filenames_sorted} + print('\nFinished with ' + satname) @@ -169,7 +209,7 @@ def get_images(sitename,polygon,dates,sat): # download L7 images #=============================================================================================# - if 'L7' in sat or 'Landsat7' in sat: + if 'L7' in sat_list or 'Landsat7' in sat_list: satname = 'L7' # create subfolders (one for 30m multispectral bands and one for 15m pan bands) @@ -188,19 +228,27 @@ def get_images(sitename,polygon,dates,sat): flt_col = input_col.filterBounds(ee.Geometry.Polygon(polygon)).filterDate(dates[0],dates[1]) # get all images in the filtered collection im_all = flt_col.getInfo().get('features') - # print how many images there are for the user - n_img = flt_col.size().getInfo() - print('Number of ' + satname + ' images covering ' + sitename + ':', n_img) + # remove very cloudy images (>95% cloud) + cloud_cover = [_['properties']['CLOUD_COVER'] for _ in im_all] + if np.any([_ > 95 for _ in cloud_cover]): + idx_delete = np.where([_ > 95 for _ in cloud_cover])[0] + im_all_cloud = [x for k,x in enumerate(im_all) if k not in idx_delete] + else: + im_all_cloud = im_all + n_img = len(im_all_cloud) + # print how many images there are + print('Number of ' + satname + ' images covering ' + sitename + ':', n_img) # loop trough images timestamps = [] acc_georef = [] + filenames = [] all_names = [] im_epsg = [] for i in range(n_img): # find each image in ee database - im = ee.Image(im_all[i].get('id')) + im = ee.Image(im_all_cloud[i].get('id')) # read metadata im_dic = im.getInfo() # get bands @@ -219,7 +267,6 @@ def get_images(sitename,polygon,dates,sat): except: # default value of accuracy (RMSE = 12m) acc_georef.append(12) - print('No geometric rmse model property') # delete dimensions key from dictionnary, otherwise the entire image is extracted for j in range(len(im_bands)): del im_bands[j]['dimensions'] # bands for L7 @@ -232,31 +279,42 @@ def get_images(sitename,polygon,dates,sat): if any(filename_pan in _ for _ in all_names): filename_pan = im_date + '_' + satname + '_' + sitename + '_pan' + '_dup' + suffix filename_ms = im_date + '_' + satname + '_' + sitename + '_ms' + '_dup' + suffix - all_names.append(filename_pan) + all_names.append(filename_pan) + filenames.append(filename_pan) # download .TIF image local_data_pan = download_tif(im, polygon, pan_band, filepath_pan) local_data_ms = download_tif(im, polygon, ms_bands, filepath_ms) # update filename - os.rename(local_data_pan, os.path.join(filepath_pan, filename_pan)) - os.rename(local_data_ms, 
os.path.join(filepath_ms, filename_ms)) - print(i, end='..') + try: + os.rename(local_data_pan, os.path.join(filepath_pan, filename_pan)) + except: + os.remove(os.path.join(filepath_pan, filename_pan)) + os.rename(local_data_pan, os.path.join(filepath_pan, filename_pan)) + try: + os.rename(local_data_ms, os.path.join(filepath_ms, filename_ms)) + except: + os.remove(os.path.join(filepath_ms, filename_ms)) + os.rename(local_data_ms, os.path.join(filepath_ms, filename_ms)) + + print(i+1, end='..') # sort timestamps and georef accuracy (dowloaded images are sorted by date in directory) timestamps_sorted = sorted(timestamps) idx_sorted = sorted(range(len(timestamps)), key=timestamps.__getitem__) acc_georef_sorted = [acc_georef[j] for j in idx_sorted] + filenames_sorted = [filenames[j] for j in idx_sorted] im_epsg_sorted = [im_epsg[j] for j in idx_sorted] # save into dict metadata[satname] = {'dates':timestamps_sorted, 'acc_georef':acc_georef_sorted, - 'epsg':im_epsg_sorted} - print('Finished with ' + satname) + 'epsg':im_epsg_sorted, 'filenames':filenames_sorted} + print('\nFinished with ' + satname) #=============================================================================================# # download L8 images #=============================================================================================# - if 'L8' in sat or 'Landsat8' in sat: + if 'L8' in sat_list or 'Landsat8' in sat_list: satname = 'L8' # create subfolders (one for 30m multispectral bands and one for 15m pan bands) @@ -275,19 +333,27 @@ def get_images(sitename,polygon,dates,sat): flt_col = input_col.filterBounds(ee.Geometry.Polygon(polygon)).filterDate(dates[0],dates[1]) # get all images in the filtered collection im_all = flt_col.getInfo().get('features') - # print how many images there are for the user - n_img = flt_col.size().getInfo() - print('Number of ' + satname + ' images covering ' + sitename + ':', n_img) + # remove very cloudy images (>95% cloud) + cloud_cover = [_['properties']['CLOUD_COVER'] for _ in im_all] + if np.any([_ > 95 for _ in cloud_cover]): + idx_delete = np.where([_ > 95 for _ in cloud_cover])[0] + im_all_cloud = [x for k,x in enumerate(im_all) if k not in idx_delete] + else: + im_all_cloud = im_all + n_img = len(im_all_cloud) + # print how many images there are + print('Number of ' + satname + ' images covering ' + sitename + ':', n_img) # loop trough images timestamps = [] acc_georef = [] + filenames = [] all_names = [] im_epsg = [] for i in range(n_img): # find each image in ee database - im = ee.Image(im_all[i].get('id')) + im = ee.Image(im_all_cloud[i].get('id')) # read metadata im_dic = im.getInfo() # get bands @@ -306,7 +372,6 @@ def get_images(sitename,polygon,dates,sat): except: # default value of accuracy (RMSE = 12m) acc_georef.append(12) - print('No geometric rmse model property') # delete dimensions key from dictionnary, otherwise the entire image is extracted for j in range(len(im_bands)): del im_bands[j]['dimensions'] # bands for L8 @@ -319,30 +384,41 @@ def get_images(sitename,polygon,dates,sat): if any(filename_pan in _ for _ in all_names): filename_pan = im_date + '_' + satname + '_' + sitename + '_pan' + '_dup' + suffix filename_ms = im_date + '_' + satname + '_' + sitename + '_ms' + '_dup' + suffix - all_names.append(filename_pan) + all_names.append(filename_pan) + filenames.append(filename_pan) # download .TIF image local_data_pan = download_tif(im, polygon, pan_band, filepath_pan) local_data_ms = download_tif(im, polygon, ms_bands, filepath_ms) # update filename - 
os.rename(local_data_pan, os.path.join(filepath_pan, filename_pan)) - os.rename(local_data_ms, os.path.join(filepath_ms, filename_ms)) - print(i, end='..') + try: + os.rename(local_data_pan, os.path.join(filepath_pan, filename_pan)) + except: + os.remove(os.path.join(filepath_pan, filename_pan)) + os.rename(local_data_pan, os.path.join(filepath_pan, filename_pan)) + try: + os.rename(local_data_ms, os.path.join(filepath_ms, filename_ms)) + except: + os.remove(os.path.join(filepath_ms, filename_ms)) + os.rename(local_data_ms, os.path.join(filepath_ms, filename_ms)) + + print(i+1, end='..') # sort timestamps and georef accuracy (dowloaded images are sorted by date in directory) timestamps_sorted = sorted(timestamps) idx_sorted = sorted(range(len(timestamps)), key=timestamps.__getitem__) acc_georef_sorted = [acc_georef[j] for j in idx_sorted] + filenames_sorted = [filenames[j] for j in idx_sorted] im_epsg_sorted = [im_epsg[j] for j in idx_sorted] metadata[satname] = {'dates':timestamps_sorted, 'acc_georef':acc_georef_sorted, - 'epsg':im_epsg_sorted} - print('Finished with ' + satname) + 'epsg':im_epsg_sorted, 'filenames':filenames_sorted} + print('\nFinished with ' + satname) #=============================================================================================# # download S2 images #=============================================================================================# - if 'S2' in sat or 'Sentinel2' in sat: + if 'S2' in sat_list or 'Sentinel2' in sat_list: satname = 'S2' # create subfolders for the 10m, 20m and 60m multipectral bands @@ -359,20 +435,60 @@ def get_images(sitename,polygon,dates,sat): # filter by location and dates flt_col = input_col.filterBounds(ee.Geometry.Polygon(polygon)).filterDate(dates[0],dates[1]) # get all images in the filtered collection - im_all = flt_col.getInfo().get('features') + im_all = flt_col.getInfo().get('features') + # remove duplicates in the collection (there are many in S2 collection) + timestamps = [datetime.fromtimestamp(_['properties']['system:time_start']/1000, + tz=pytz.utc) for _ in im_all] + # utm zone projection + utm_zones = np.array([int(_['bands'][0]['crs'][5:]) for _ in im_all]) + utm_zone_selected = np.max(np.unique(utm_zones)) + # find the images that were acquired at the same time but have different utm zones + idx_all = np.arange(0,len(im_all),1) + idx_covered = np.ones(len(im_all)).astype(bool) + idx_delete = [] + i = 0 + while 1: + same_time = np.abs([(timestamps[i]-_).total_seconds() for _ in timestamps]) < 60*60*24 + idx_same_time = np.where(same_time)[0] + same_utm = utm_zones == utm_zone_selected + idx_temp = np.where([same_time[j] == True and same_utm[j] == False for j in idx_all])[0] + idx_keep = idx_same_time[[_ not in idx_temp for _ in idx_same_time ]] + # if more than 2 images with same date and same utm, drop the last ones + if len(idx_keep) > 2: + idx_temp = np.append(idx_temp,idx_keep[-(len(idx_keep)-2):]) + for j in idx_temp: + idx_delete.append(j) + idx_covered[idx_same_time] = False + if np.any(idx_covered): + i = np.where(idx_covered)[0][0] + else: + break + # update the collection by deleting all those images that have same timestamp and different + # utm projection + im_all_updated = [x for k,x in enumerate(im_all) if k not in idx_delete] + + # remove very cloudy images (>95% cloud) + cloud_cover = [_['properties']['CLOUDY_PIXEL_PERCENTAGE'] for _ in im_all_updated] + if np.any([_ > 95 for _ in cloud_cover]): + idx_delete = np.where([_ > 95 for _ in cloud_cover])[0] + im_all_cloud = [x for k,x in 
enumerate(im_all_updated) if k not in idx_delete] + else: + im_all_cloud = im_all_updated + + n_img = len(im_all_cloud) # print how many images there are - n_img = flt_col.size().getInfo() print('Number of ' + satname + ' images covering ' + sitename + ':', n_img) # loop trough images timestamps = [] acc_georef = [] + filenames = [] all_names = [] im_epsg = [] for i in range(n_img): # find each image in ee database - im = ee.Image(im_all[i].get('id')) + im = ee.Image(im_all_cloud[i].get('id')) # read metadata im_dic = im.getInfo() # get bands @@ -394,39 +510,290 @@ def get_images(sitename,polygon,dates,sat): filename60 = im_date + '_' + satname + '_' + sitename + '_' + '60m' + suffix # if two images taken at the same date skip the second image (they are the same) if any(filename10 in _ for _ in all_names): - continue + filename10 = filename10[:filename10.find('.')] + '_dup' + suffix + filename20 = filename20[:filename20.find('.')] + '_dup' + suffix + filename60 = filename60[:filename60.find('.')] + '_dup' + suffix all_names.append(filename10) + filenames.append(filename10) + # download .TIF image and update filename local_data = download_tif(im, polygon, bands10, os.path.join(filepath, '10m')) - os.rename(local_data, os.path.join(filepath, '10m', filename10)) + try: + os.rename(local_data, os.path.join(filepath, '10m', filename10)) + except: + os.remove(os.path.join(filepath, '10m', filename10)) + os.rename(local_data, os.path.join(filepath, '10m', filename10)) + local_data = download_tif(im, polygon, bands20, os.path.join(filepath, '20m')) - os.rename(local_data, os.path.join(filepath, '20m', filename20)) + try: + os.rename(local_data, os.path.join(filepath, '20m', filename20)) + except: + os.remove(os.path.join(filepath, '20m', filename20)) + os.rename(local_data, os.path.join(filepath, '20m', filename20)) + local_data = download_tif(im, polygon, bands60, os.path.join(filepath, '60m')) - os.rename(local_data, os.path.join(filepath, '60m', filename60)) + try: + os.rename(local_data, os.path.join(filepath, '60m', filename60)) + except: + os.remove(os.path.join(filepath, '60m', filename60)) + os.rename(local_data, os.path.join(filepath, '60m', filename60)) # save timestamp, epsg code and georeferencing accuracy (1 if passed 0 if not passed) timestamps.append(im_timestamp) im_epsg.append(int(im_dic['bands'][0]['crs'][5:])) + # Sentinel-2 products don't provide a georeferencing accuracy (RMSE as in Landsat) + # but they have a flag indicating if the geometric quality control was passed or failed + # if passed a value of 1 is stored if faile a value of -1 is stored in the metadata try: if im_dic['properties']['GEOMETRIC_QUALITY_FLAG'] == 'PASSED': acc_georef.append(1) else: - acc_georef.append(0) + acc_georef.append(-1) except: - acc_georef.append(0) - print(i, end='..') + acc_georef.append(-1) + print(i+1, end='..') # sort timestamps and georef accuracy (dowloaded images are sorted by date in directory) timestamps_sorted = sorted(timestamps) idx_sorted = sorted(range(len(timestamps)), key=timestamps.__getitem__) acc_georef_sorted = [acc_georef[j] for j in idx_sorted] + filenames_sorted = [filenames[j] for j in idx_sorted] im_epsg_sorted = [im_epsg[j] for j in idx_sorted] metadata[satname] = {'dates':timestamps_sorted, 'acc_georef':acc_georef_sorted, - 'epsg':im_epsg_sorted} - print('Finished with ' + satname) + 'epsg':im_epsg_sorted, 'filenames':filenames_sorted} + print('\nFinished with ' + satname) + + # merge overlapping images (only if polygon is at the edge of an image) + if 'S2' in 
metadata.keys(): + metadata = merge_overlapping_images(metadata,inputs) # save metadata dict filepath = os.path.join(os.getcwd(), 'data', sitename) with open(os.path.join(filepath, sitename + '_metadata' + '.pkl'), 'wb') as f: - pickle.dump(metadata, f) \ No newline at end of file + pickle.dump(metadata, f) + + return metadata + + +def merge_overlapping_images(metadata,inputs): + """ + When the area of interest is located at the boundary between 2 images, there will be overlap + between the 2 images and both will be downloaded from Google Earth Engine. This function + merges the 2 images, so that the area of interest is covered by only 1 image. + + KV WRL 2018 + + Arguments: + ----------- + metadata: dict + contains all the information about the satellite images that were downloaded + inputs: dict + dictionnary that contains the following fields: + 'sitename': str + String containig the name of the site + 'polygon': list + polygon containing the lon/lat coordinates to be extracted + longitudes in the first column and latitudes in the second column + 'dates': list of str + list that contains 2 strings with the initial and final dates in format 'yyyy-mm-dd' + e.g. ['1987-01-01', '2018-01-01'] + 'sat_list': list of str + list that contains the names of the satellite missions to include + e.g. ['L5', 'L7', 'L8', 'S2'] + + Returns: + ----------- + metadata: dict + updated metadata with the information of the merged images + + """ + + # only for Sentinel-2 at this stage (could be implemented for Landsat as well) + sat = 'S2' + filepath = os.path.join(os.getcwd(), 'data', inputs['sitename']) + + # find the images that are overlapping (same date in S2 filenames) + filenames = metadata[sat]['filenames'] + filenames_copy = filenames.copy() + + # loop through all the filenames and find the pairs of overlapping images (same date and time of acquisition) + pairs = [] + for i,fn in enumerate(filenames): + filenames_copy[i] = [] + # find duplicate + boolvec = [fn[:22] == _[:22] for _ in filenames_copy] + if np.any(boolvec): + idx_dup = np.where(boolvec)[0][0] + if len(filenames[i]) > len(filenames[idx_dup]): + pairs.append([idx_dup,i]) + else: + pairs.append([i,idx_dup]) + + msg = 'Merging %d pairs of overlapping images...' % len(pairs) + print(msg) + + # for each pair of images, merge them into one complete image + for i,pair in enumerate(pairs): + print(i+1, end='..') + + fn_im = [] + for index in range(len(pair)): + # read image + fn_im.append([os.path.join(filepath, 'S2', '10m', filenames[pair[index]]), + os.path.join(filepath, 'S2', '20m', filenames[pair[index]].replace('10m','20m')), + os.path.join(filepath, 'S2', '60m', filenames[pair[index]].replace('10m','60m'))]) + im_ms, georef, cloud_mask, im_extra, imQA = SDS_preprocess.preprocess_single(fn_im[index], sat) + + # in Sentinel2 images close to the edge of the image there are some artefacts, + # that are squares with constant pixel intensities. They need to be masked in the + # raster (GEOTIFF). It can be done using the image standard deviation, which + # indicates values close to 0 for the artefacts. 
+ + # First mask the 10m bands + if len(im_ms) > 0: + im_std = SDS_tools.image_std(im_ms[:,:,0],1) + im_binary = np.logical_or(im_std < 1e-6, np.isnan(im_std)) + mask = morphology.dilation(im_binary, morphology.square(3)) + for k in range(im_ms.shape[2]): + im_ms[mask,k] = np.nan + + SDS_tools.mask_raster(fn_im[index][0], mask) + + # Then mask the 20m band + im_std = SDS_tools.image_std(im_extra,1) + im_binary = np.logical_or(im_std < 1e-6, np.isnan(im_std)) + mask = morphology.dilation(im_binary, morphology.square(3)) + im_extra[mask] = np.nan + + SDS_tools.mask_raster(fn_im[index][1], mask) + else: + continue + + # make a figure for quality control +# plt.figure() +# plt.subplot(221) +# plt.imshow(im_ms[:,:,[2,1,0]]) +# plt.title('imRGB') +# plt.subplot(222) +# plt.imshow(im20, cmap='gray') +# plt.title('im20') +# plt.subplot(223) +# plt.imshow(imQA, cmap='gray') +# plt.title('imQA') +# plt.subplot(224) +# plt.title(fn_im[index][0][-30:]) + + # merge masked 10m bands + fn_merged = os.path.join(os.getcwd(), 'merged.tif') + gdal_merge.main(['', '-o', fn_merged, '-n', '0', fn_im[0][0], fn_im[1][0]]) + os.chmod(fn_im[0][0], 0o777) + os.remove(fn_im[0][0]) + os.chmod(fn_im[1][0], 0o777) + os.remove(fn_im[1][0]) + os.rename(fn_merged, fn_im[0][0]) + + # merge masked 20m band (SWIR band) + fn_merged = os.path.join(os.getcwd(), 'merged.tif') + gdal_merge.main(['', '-o', fn_merged, '-n', '0', fn_im[0][1], fn_im[1][1]]) + os.chmod(fn_im[0][1], 0o777) + os.remove(fn_im[0][1]) + os.chmod(fn_im[1][1], 0o777) + os.remove(fn_im[1][1]) + os.rename(fn_merged, fn_im[0][1]) + + # merge QA band (60m band) + fn_merged = os.path.join(os.getcwd(), 'merged.tif') + gdal_merge.main(['', '-o', fn_merged, '-n', 'nan', fn_im[0][2], fn_im[1][2]]) + os.chmod(fn_im[0][2], 0o777) + os.remove(fn_im[0][2]) + os.chmod(fn_im[1][2], 0o777) + os.remove(fn_im[1][2]) + os.rename(fn_merged, fn_im[0][2]) + + # update the metadata dict (delete all the duplicates) + metadata2 = copy.deepcopy(metadata) + filenames_copy = metadata2[sat]['filenames'] + index_list = [] + for i in range(len(filenames_copy)): + if filenames_copy[i].find('dup') == -1: + index_list.append(i) + for key in metadata2[sat].keys(): + metadata2[sat][key] = [metadata2[sat][key][_] for _ in index_list] + + return metadata2 + +def remove_cloudy_images(metadata,inputs,cloud_thresh): + """ + Deletes the .TIF file of images that have a cloud cover percentage that is above the cloud + threshold. + + KV WRL 2018 + + Arguments: + ----------- + metadata: dict + contains all the information about the satellite images that were downloaded + inputs: dict + dictionnary that contains the following fields: + 'sitename': str + String containig the name of the site + 'polygon': list + polygon containing the lon/lat coordinates to be extracted + longitudes in the first column and latitudes in the second column + 'dates': list of str + list that contains 2 strings with the initial and final dates in format 'yyyy-mm-dd' + e.g. ['1987-01-01', '2018-01-01'] + 'sat_list': list of str + list that contains the names of the satellite missions to include + e.g. 
['L5', 'L7', 'L8', 'S2'] cloud_thresh: float value between 0 and 1 indicating the maximum cloud fraction in the image that is accepted Returns: ----------- metadata: dict updated metadata without the images that were removed for exceeding the cloud threshold """ # create a deep copy metadata2 = copy.deepcopy(metadata) for satname in metadata.keys(): # get the image filenames filepath = SDS_tools.get_filepath(inputs,satname) filenames = metadata[satname]['filenames'] # loop through images idx_good = [] for i in range(len(filenames)): # image filename fn = SDS_tools.get_filenames(filenames[i],filepath, satname) # preprocess image (cloud mask + pansharpening/downsampling) im_ms, georef, cloud_mask, im_extra, imQA = SDS_preprocess.preprocess_single(fn, satname) # calculate cloud cover cloud_cover = np.divide(sum(sum(cloud_mask.astype(int))), (cloud_mask.shape[0]*cloud_mask.shape[1])) # skip image if cloud cover is above threshold if cloud_cover > cloud_thresh or cloud_cover == 1: # remove image files if satname == 'L5': os.chmod(fn, 0o777) os.remove(fn) else: for j in range(len(fn)): os.chmod(fn[j], 0o777) os.remove(fn[j]) else: idx_good.append(i) msg = '\n%d cloudy images were removed for %s.' % (len(filenames)-len(idx_good), satname) print(msg) # update the metadata dict (delete all cloudy images) for key in metadata2[satname].keys(): metadata2[satname][key] = [metadata2[satname][key][_] for _ in idx_good] return metadata2 \ No newline at end of file diff --git a/SDS_preprocess.py b/SDS_preprocess.py index c98c501..68b6245 100644 --- a/SDS_preprocess.py +++ b/SDS_preprocess.py @@ -1,28 +1,36 @@ -"""This module contains all the functions needed to preprocess the satellite images: creating a -cloud mask and pansharpening/downsampling the images. +"""This module contains all the functions needed to preprocess the satellite images before the +shoreline can be extracted. This includes creating a cloud mask and +pansharpening/downsampling the multispectral bands. Author: Kilian Vos, Water Research Laboratory, University of New South Wales """ -# Initial settings +# load modules import os import numpy as np import matplotlib.pyplot as plt -from osgeo import gdal, ogr, osr +import pdb + +# image processing modules import skimage.transform as transform import skimage.morphology as morphology import sklearn.decomposition as decomposition import skimage.exposure as exposure + +# other modules +from osgeo import gdal, ogr, osr from pylab import ginput import pickle -import pdb +import matplotlib.path as mpltPath + +# own modules import SDS_tools -# Functions +np.seterr(all='ignore') # raise/ignore divisions by 0 and nans def create_cloud_mask(im_qa, satname): """ - Creates a cloud mask from the image containing the QA band information. + Creates a cloud mask using the information contained in the QA band.
KV WRL 2018 @@ -31,15 +39,15 @@ def create_cloud_mask(im_qa, satname): im_qa: np.array Image containing the QA band satname: string - short name for the satellite (L8, L7, S2) + short name for the satellite (L5, L7, L8 or S2) Returns: ----------- - cloud_mask : np.ndarray of booleans - A boolean array with True where the cloud are present + cloud_mask : np.array + A boolean array with True if a pixel is cloudy and False otherwise """ - # convert QA bits depending on the satellite mission + # convert QA bits (the bits allocated to cloud cover vary depending on the satellite mission) if satname == 'L8': cloud_values = [2800, 2804, 2808, 2812, 6896, 6900, 6904, 6908] elif satname == 'L7' or satname == 'L5' or satname == 'L4': @@ -50,7 +58,7 @@ def create_cloud_mask(im_qa, satname): # find which pixels have bits corresponding to cloud values cloud_mask = np.isin(im_qa, cloud_values) - # remove isolated cloud pixels (there are some in the swash zone and they can cause problems) + # remove isolated cloud pixels (there are some in the swash zone and they are not clouds) if sum(sum(cloud_mask)) > 0 and sum(sum(~cloud_mask)) > 0: morphology.remove_small_objects(cloud_mask, min_size=10, connectivity=1, in_place=True) @@ -100,8 +108,10 @@ def hist_match(source, template): def pansharpen(im_ms, im_pan, cloud_mask): """ - Pansharpens a multispectral image (3D), using the panchromatic band (2D) and a cloud mask. + Pansharpens a multispectral image, using the panchromatic band and a cloud mask. A PCA is applied to the image, then the 1st PC is replaced with the panchromatic band. + Note that it is essential to match the histrograms of the 1st PC and the panchromatic band + before replacing and inverting the PCA. KV WRL 2018 @@ -117,14 +127,14 @@ def pansharpen(im_ms, im_pan, cloud_mask): Returns: ----------- im_ms_ps: np.ndarray - Pansharpened multisoectral image (3D) + Pansharpened multispectral image (3D) """ # reshape image into vector and apply cloud mask vec = im_ms.reshape(im_ms.shape[0] * im_ms.shape[1], im_ms.shape[2]) vec_mask = cloud_mask.reshape(im_ms.shape[0] * im_ms.shape[1]) vec = vec[~vec_mask, :] - # apply PCA to RGB bands + # apply PCA to multispectral bands pca = decomposition.PCA() vec_pcs = pca.fit_transform(vec) @@ -146,7 +156,7 @@ def rescale_image_intensity(im, cloud_mask, prob_high): """ Rescales the intensity of an image (multispectral or single band) by applying a cloud mask and clipping the prob_high upper percentile. This functions allows - to stretch the contrast of an image for visualisation purposes. + to stretch the contrast of an image, only for visualisation purposes. KV WRL 2018 @@ -201,7 +211,10 @@ def rescale_image_intensity(im, cloud_mask, prob_high): def preprocess_single(fn, satname): """ - Creates a cloud mask using the QA band and performs pansharpening/down-sampling of the image. + Reads the image and outputs the pansharpened/down-sampled multispectral bands, the + georeferencing vector of the image (coordinates of the upper left pixel), the cloud mask and + the QA band. For Landsat 7-8 it also outputs the panchromatic band and for Sentinel-2 it also + outputs the 20m SWIR band. 
KV WRL 2018 @@ -209,7 +222,8 @@ def preprocess_single(fn, satname): ----------- fn: str or list of str filename of the .TIF file containing the image - for L7, L8 and S2 there is a filename for the bands at different resolutions + for L7, L8 and S2 this is a list of filenames, one filename for each band at different + resolution (30m and 15m for Landsat 7-8, 10m, 20m, 60m for Sentinel-2) satname: str name of the satellite mission (e.g., 'L5') @@ -222,6 +236,11 @@ def preprocess_single(fn, satname): coordinates of the top-left pixel of the image cloud_mask: np.array 2D cloud mask with True where cloud pixels are + im_extra : np.array + 2D array containing the 20m resolution SWIR band for Sentinel-2 and the 15m resolution + panchromatic band for Landsat 7 and Landsat 8. This field is empty for Landsat 5. + imQA: np.array + 2D array containing the QA band, from which the cloud_mask can be computed. """ @@ -267,6 +286,9 @@ def preprocess_single(fn, satname): # calculate cloud cover cloud_cover = sum(sum(cloud_mask.astype(int)))/(cloud_mask.shape[0]*cloud_mask.shape[1]) + # no extra image for Landsat 5 (they are all 30 m bands) + im_extra = [] + imQA = im_qa #=============================================================================================# # L7 images @@ -324,6 +346,9 @@ def preprocess_single(fn, satname): im_ms_ps = np.append(im_ms_ps, im_ms[:,:,[4]], axis=2) im_ms = im_ms_ps.copy() + # the extra image is the 15m panchromatic band + im_extra = im_pan + imQA = im_qa #=============================================================================================# # L8 images @@ -380,6 +405,9 @@ def preprocess_single(fn, satname): im_ms_ps = np.append(im_ms_ps, im_ms[:,:,[3,4]], axis=2) im_ms = im_ms_ps.copy() + # the extra image is the 15m panchromatic band + im_extra = im_pan + imQA = im_qa #=============================================================================================# # S2 images @@ -400,7 +428,7 @@ def preprocess_single(fn, satname): georef = [] # skip the image by giving it a full cloud_mask cloud_mask = np.ones((im10.shape[0],im10.shape[1])).astype('bool') - return im_ms, georef, cloud_mask + return im_ms, georef, cloud_mask, [], [] # size of 10m bands nrows = im10.shape[0] @@ -427,8 +455,8 @@ def preprocess_single(fn, satname): data = gdal.Open(fn60, gdal.GA_ReadOnly) bands = [data.GetRasterBand(k + 1).ReadAsArray() for k in range(data.RasterCount)] im60 = np.stack(bands, 2) - im_qa = im60[:,:,0] - cloud_mask = create_cloud_mask(im_qa, satname) + imQA = im60[:,:,0] + cloud_mask = create_cloud_mask(imQA, satname) # resize the cloud mask using nearest neighbour interpolation (order 0) cloud_mask = transform.resize(cloud_mask,(nrows, ncols), order=0, preserve_range=True, mode='constant') @@ -440,8 +468,10 @@ def preprocess_single(fn, satname): # calculate cloud cover cloud_cover = sum(sum(cloud_mask.astype(int)))/(cloud_mask.shape[0]*cloud_mask.shape[1]) + # the extra image is the 20m SWIR band + im_extra = im20 - return im_ms, georef, cloud_mask + return im_ms, georef, cloud_mask, im_extra, imQA def create_jpg(im_ms, cloud_mask, date, satname, filepath): @@ -476,21 +506,32 @@ def create_jpg(im_ms, cloud_mask, date, satname, filepath): fig = plt.figure() fig.set_size_inches([18,9]) fig.set_tight_layout(True) - # RGB - plt.subplot(131) - plt.axis('off') - plt.imshow(im_RGB) - plt.title(date + ' ' + satname, fontsize=16) - # NIR - plt.subplot(132) - plt.axis('off') - plt.imshow(im_NIR, cmap='seismic') - plt.title('Near Infrared', fontsize=16) - # SWIR - 
plt.subplot(133) - plt.axis('off') - plt.imshow(im_SWIR, cmap='seismic') - plt.title('Short-wave Infrared', fontsize=16) + ax1 = fig.add_subplot(111) + ax1.axis('off') + ax1.imshow(im_RGB) + ax1.set_title(date + ' ' + satname, fontsize=16) + +# if im_RGB.shape[1] > 2*im_RGB.shape[0]: +# ax1 = fig.add_subplot(311) +# ax2 = fig.add_subplot(312) +# ax3 = fig.add_subplot(313) +# else: +# ax1 = fig.add_subplot(131) +# ax2 = fig.add_subplot(132) +# ax3 = fig.add_subplot(133) +# # RGB +# ax1.axis('off') +# ax1.imshow(im_RGB) +# ax1.set_title(date + ' ' + satname, fontsize=16) +# # NIR +# ax2.axis('off') +# ax2.imshow(im_NIR, cmap='seismic') +# ax2.set_title('Near Infrared', fontsize=16) +# # SWIR +# ax3.axis('off') +# ax3.imshow(im_SWIR, cmap='seismic') +# ax3.set_title('Short-wave Infrared', fontsize=16) + # save figure plt.rcParams['savefig.jpeg_quality'] = 100 fig.savefig(os.path.join(filepath, @@ -498,28 +539,29 @@ def create_jpg(im_ms, cloud_mask, date, satname, filepath): plt.close() -def preprocess_all_images(metadata, settings): +def save_jpg(metadata, settings): """ - Saves a .jpg image for all the file contained in metadata. + Saves a .jpg image for all the images contained in metadata. KV WRL 2018 Arguments: ----------- - sitename: str - name of the site (and corresponding folder) metadata: dict contains all the information about the satellite images that were downloaded - cloud_thresh: float - maximum fraction of cloud cover allowed in the images - + settings: dict + contains the following fields: + 'cloud_thresh': float + value between 0 and 1 indicating the maximum cloud fraction in the image that is accepted + 'sitename': string + name of the site (also name of the folder where the images are stored) + Returns: ----------- - Generates .jpg files for all the satellite images avaialble """ - sitename = settings['sitename'] + sitename = settings['inputs']['sitename'] cloud_thresh = settings['cloud_thresh'] # create subfolder to store the jpg files @@ -531,65 +573,57 @@ def preprocess_all_images(metadata, settings): # loop through satellite list for satname in metadata.keys(): - # access the images - if satname == 'L5': - # access downloaded Landsat 5 images - filepath = os.path.join(os.getcwd(), 'data', sitename, satname, '30m') - filenames = os.listdir(filepath) - elif satname == 'L7': - # access downloaded Landsat 7 images - filepath_pan = os.path.join(os.getcwd(), 'data', sitename, 'L7', 'pan') - filepath_ms = os.path.join(os.getcwd(), 'data', sitename, 'L7', 'ms') - filenames_pan = os.listdir(filepath_pan) - filenames_ms = os.listdir(filepath_ms) - if (not len(filenames_pan) == len(filenames_ms)): - raise 'error: not the same amount of files for pan and ms' - filepath = [filepath_pan, filepath_ms] - filenames = filenames_pan - elif satname == 'L8': - # access downloaded Landsat 7 images - filepath_pan = os.path.join(os.getcwd(), 'data', sitename, 'L8', 'pan') - filepath_ms = os.path.join(os.getcwd(), 'data', sitename, 'L8', 'ms') - filenames_pan = os.listdir(filepath_pan) - filenames_ms = os.listdir(filepath_ms) - if (not len(filenames_pan) == len(filenames_ms)): - raise 'error: not the same amount of files for pan and ms' - filepath = [filepath_pan, filepath_ms] - filenames = filenames_pan - elif satname == 'S2': - # access downloaded Sentinel 2 images - filepath10 = os.path.join(os.getcwd(), 'data', sitename, satname, '10m') - filenames10 = os.listdir(filepath10) - filepath20 = os.path.join(os.getcwd(), 'data', sitename, satname, '20m') - filenames20 = os.listdir(filepath20) - 
filepath60 = os.path.join(os.getcwd(), 'data', sitename, satname, '60m') - filenames60 = os.listdir(filepath60) - if (not len(filenames10) == len(filenames20)) or (not len(filenames20) == len(filenames60)): - raise 'error: not the same amount of files for 10, 20 and 60 m' - filepath = [filepath10, filepath20, filepath60] - filenames = filenames10 + + filepath = SDS_tools.get_filepath(settings['inputs'],satname) + filenames = metadata[satname]['filenames'] # loop through images for i in range(len(filenames)): # image filename fn = SDS_tools.get_filenames(filenames[i],filepath, satname) - # preprocess image (cloud mask + pansharpening/downsampling) - im_ms, georef, cloud_mask = preprocess_single(fn, satname) + # read and preprocess image + im_ms, georef, cloud_mask, im_extra, imQA = preprocess_single(fn, satname) # calculate cloud cover cloud_cover = np.divide(sum(sum(cloud_mask.astype(int))), (cloud_mask.shape[0]*cloud_mask.shape[1])) # skip image if cloud cover is above threshold - if cloud_cover > cloud_thresh: + if cloud_cover > cloud_thresh or cloud_cover == 1: continue # save .jpg with date and satellite in the title date = filenames[i][:10] create_jpg(im_ms, cloud_mask, date, satname, filepath_jpg) -def get_reference_sl(metadata, settings): +def get_reference_sl_manual(metadata, settings): + """ + Allows the user to manually digitize a reference shoreline that is used seed the shoreline + detection algorithm. The reference shoreline helps to detect the outliers, making the shoreline + detection more robust. + + KV WRL 2018 + + Arguments: + ----------- + metadata: dict + contains all the information about the satellite images that were downloaded + settings: dict + contains the following fields: + 'cloud_thresh': float + value between 0 and 1 indicating the maximum cloud fraction in the image that is accepted + 'sitename': string + name of the site (also name of the folder where the images are stored) + 'output_epsg': int + epsg code of the desired spatial reference system + + Returns: + ----------- + ref_sl: np.array + coordinates of the reference shoreline that was manually digitized + + """ - sitename = settings['sitename'] + sitename = settings['inputs']['sitename'] - # check if reference shoreline already exists + # check if reference shoreline already exists in the corresponding folder filepath = os.path.join(os.getcwd(), 'data', sitename) filename = sitename + '_ref_sl.pkl' if filename in os.listdir(filepath): @@ -599,23 +633,31 @@ def get_reference_sl(metadata, settings): return refsl else: - satname = 'S2' - # access downloaded Sentinel 2 images - filepath10 = os.path.join(os.getcwd(), 'data', sitename, satname, '10m') - filenames10 = os.listdir(filepath10) - filepath20 = os.path.join(os.getcwd(), 'data', sitename, satname, '20m') - filenames20 = os.listdir(filepath20) - filepath60 = os.path.join(os.getcwd(), 'data', sitename, satname, '60m') - filenames60 = os.listdir(filepath60) - if (not len(filenames10) == len(filenames20)) or (not len(filenames20) == len(filenames60)): - raise 'error: not the same amount of files for 10, 20 and 60 m' - for i in range(len(filenames10)): - # image filename - fn = [os.path.join(filepath10, filenames10[i]), - os.path.join(filepath20, filenames20[i]), - os.path.join(filepath60, filenames60[i])] - # preprocess image (cloud mask + pansharpening/downsampling) - im_ms, georef, cloud_mask = preprocess_single(fn, satname) + # first try to use S2 images (10m res for manually digitizing the reference shoreline) + if 'S2' in metadata.keys(): + satname 
= 'S2' + filepath = SDS_tools.get_filepath(settings['inputs'],satname) + filenames = metadata[satname]['filenames'] + # if no S2 images, try L8 (15m res in the RGB with pansharpening) + elif not 'S2' in metadata.keys() and 'L8' in metadata.keys(): + satname = 'L8' + filepath = SDS_tools.get_filepath(settings['inputs'],satname) + filenames = metadata[satname]['filenames'] + # if no S2 images and no L8, use L5 images (L7 images have black diagonal bands making it + # hard to manually digitize a shoreline) + elif not 'S2' in metadata.keys() and not 'L8' in metadata.keys() and 'L5' in metadata.keys(): + satname = 'L5' + filepath = SDS_tools.get_filepath(settings['inputs'],satname) + filenames = metadata[satname]['filenames'] + else: + print('You cannot digitize the shoreline on L7 images, add another L8, S2 or L5 to your dataset.') + + # loop trhough the images + for i in range(len(filenames)): + + # read image + fn = SDS_tools.get_filenames(filenames[i],filepath, satname) + im_ms, georef, cloud_mask, im_extra, imQA = preprocess_single(fn, satname) # calculate cloud cover cloud_cover = np.divide(sum(sum(cloud_mask.astype(int))), (cloud_mask.shape[0]*cloud_mask.shape[1])) @@ -624,44 +666,110 @@ def get_reference_sl(metadata, settings): continue # rescale image intensity for display purposes im_RGB = rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9) - # make figure + # plot the image RGB on a figure fig = plt.figure() fig.set_size_inches([18,9]) fig.set_tight_layout(True) - # RGB plt.axis('off') plt.imshow(im_RGB) - plt.title('click if image is not clear enough to digitize the shoreline.\n' + - 'Otherwise click on and start digitizing the shoreline.\n' + - 'When finished digitizing the shoreline click on the scroll wheel ' + - '(middle click).', fontsize=14) - plt.text(0, 0.9, 'keep', size=16, ha="left", va="top", + # decide if the image if good enough for digitizing the shoreline + plt.title('click if image is clear enough to digitize the shoreline.\n' + + 'If not (too cloudy) click on to get another image', fontsize=14) + keep_button = plt.text(0, 0.9, 'keep', size=16, ha="left", va="top", transform=plt.gca().transAxes, bbox=dict(boxstyle="square", ec='k',fc='w')) - plt.text(1, 0.9, 'skip', size=16, ha="right", va="top", + skip_button = plt.text(1, 0.9, 'skip', size=16, ha="right", va="top", transform=plt.gca().transAxes, bbox=dict(boxstyle="square", ec='k',fc='w')) mng = plt.get_current_fig_manager() mng.window.showMaximized() # let user click on the image once - pt_keep = ginput(n=1, timeout=100, show_clicks=True) - pt_keep = np.array(pt_keep) + pt_input = ginput(n=1, timeout=1000000, show_clicks=True) + pt_input = np.array(pt_input) # if clicks next to , show another image - if pt_keep[0][0] > im_ms.shape[1]/2: + if pt_input[0][0] > im_ms.shape[1]/2: plt.close() continue else: + # remove keep and skip buttons + keep_button.set_visible(False) + skip_button.set_visible(False) + # update title (instructions) + plt.title('Digitize the shoreline on this image by clicking on it.\n' + + 'When finished digitizing the shoreline click on the scroll wheel ' + + '(middle click).', fontsize=14) + plt.draw() # let user click on the shoreline - pts = ginput(n=5000, timeout=100000, show_clicks=True) + pts = ginput(n=50000, timeout=100000, show_clicks=True) pts_pix = np.array(pts) plt.close() # convert image coordinates to world coordinates pts_world = SDS_tools.convert_pix2world(pts_pix[:,[1,0]], georef) image_epsg = metadata[satname]['epsg'][i] pts_coords = SDS_tools.convert_epsg(pts_world, 
image_epsg, settings['output_epsg']) + + # save the reference shoreline + filepath = os.path.join(os.getcwd(), 'data', sitename) with open(os.path.join(filepath, sitename + '_ref_sl.pkl'), 'wb') as f: pickle.dump(pts_coords, f) print('Reference shoreline has been saved') break - return pts_coords \ No newline at end of file + return pts_coords + +def get_reference_sl_Australia(settings): + """ + Automatically finds a reference shoreline from a high resolution coastline of Australia + (Smartline from Geoscience Australia). It finds the points of the national coastline vector + that are situated inside the area of interest (polygon). + + KV WRL 2018 + + Arguments: + ----------- + settings: dict + contains the following fields: + 'cloud_thresh': float + value between 0 and 1 indicating the maximum cloud fraction in the image that is accepted + 'sitename': string + name of the site (also name of the folder where the images are stored) + 'output_epsg': int + epsg code of the desired spatial reference system + + Returns: + ----------- + ref_sl: np.array + coordinates of the reference shoreline found in the shapefile + + """ + + # load high-resolution shoreline of Australia + filename = os.path.join(os.getcwd(), 'data', 'shoreline_Australia.pkl') + with open(filename, 'rb') as f: + sl = pickle.load(f) + # spatial reference system of this shoreline + sl_epsg = 4283 # GDA94 geographic + + # only select the points that sit inside the area of interest (polygon) + polygon = settings['inputs']['polygon'] + # spatial reference system of the polygon (latitudes and longitudes) + polygon_epsg = 4326 # WGS84 geographic + polygon = SDS_tools.convert_epsg(np.array(polygon[0]), polygon_epsg, sl_epsg)[:,:-1] + + # use matplotlib function Path + path = mpltPath.Path(polygon) + sl_inside = sl[np.where(path.contains_points(sl))] + + # convert to desired output coordinate system + ref_sl = SDS_tools.convert_epsg(sl_inside, sl_epsg, settings['output_epsg'])[:,:-1] + + # make a figure for quality control + plt.figure() + plt.axis('equal') + plt.xlabel('Eastings [m]') + plt.ylabel('Northings [m]') + plt.plot(ref_sl[:,0], ref_sl[:,1], 'r.') + polygon = SDS_tools.convert_epsg(polygon, sl_epsg, settings['output_epsg'])[:,:-1] + plt.plot(polygon[:,0], polygon[:,1], 'k-') + + return ref_sl \ No newline at end of file diff --git a/SDS_shoreline.py b/SDS_shoreline.py index b583aab..884836f 100644 --- a/SDS_shoreline.py +++ b/SDS_shoreline.py @@ -3,23 +3,12 @@ Author: Kilian Vos, Water Research Laboratory, University of New South Wales """ -# Initial settings +# load modules import os import numpy as np import matplotlib.pyplot as plt import pdb -# other modules -from osgeo import gdal, ogr, osr -import scipy.interpolate as interpolate -from datetime import datetime, timedelta -import matplotlib.patches as mpatches -import matplotlib.lines as mlines -import matplotlib.cm as cm -from matplotlib import gridspec -from pylab import ginput -import pickle - # image processing modules import skimage.filters as filters import skimage.exposure as exposure @@ -32,7 +21,20 @@ import skimage.morphology as morphology from sklearn.externals import joblib from shapely.geometry import LineString +# other modules +from osgeo import gdal, ogr, osr +import scipy.interpolate as interpolate +from datetime import datetime, timedelta +import matplotlib.patches as mpatches +import matplotlib.lines as mlines +import matplotlib.cm as cm +from matplotlib import gridspec +from pylab import ginput +import pickle + +# own modules import SDS_tools, 
SDS_preprocess + np.seterr(all='ignore') # raise/ignore divisions by 0 and nans @@ -70,139 +72,154 @@ def nd_index(im1, im2, cloud_mask): return im_nd -def classify_image_NN(im_ms_ps, im_pan, cloud_mask, min_beach_size): +def calculate_features(im_ms, cloud_mask, im_bool): """ - Classifies every pixel in the image in one of 4 classes: - - sand --> label = 1 - - whitewater (breaking waves and swash) --> label = 2 - - water --> label = 3 - - other (vegetation, buildings, rocks...) --> label = 0 - - The classifier is a Neural Network, trained with 7000 pixels for the class SAND and 1500 - pixels for each of the other classes. This is because the class of interest for my application - is SAND and I wanted to minimize the classification error for that class. + Calculates a range of features on the image that are used for the supervised classification. + The features include spectral normalized-difference indices and standard deviation of the image. KV WRL 2018 Arguments: ----------- - im_ms_ps: np.array - Pansharpened RGB + downsampled NIR and SWIR - im_pan: - Panchromatic band + im_ms: np.array + RGB + downsampled NIR and SWIR cloud_mask: np.array 2D cloud mask with True where cloud pixels are - plot_bool: boolean - True if plot is wanted - + im_bool: np.array + 2D array of boolean indicating where on the image to calculate the features + Returns: ----------- - im_classif: np.array - 2D image containing labels - im_labels: np.array of booleans - 3D image containing a boolean image for each class (im_classif == label) - - """ + features: np.array + matrix containing each feature (columns) calculated for all + the pixels (rows) indicated in im_bool + """ - # load classifier - clf = joblib.load('.\\classifiers\\NN_4classes_withpan.pkl') - - # calculate features - n_features = 10 - im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features)) - im_features[:,:,[0,1,2,3,4]] = im_ms_ps - im_features[:,:,5] = im_pan - im_features[:,:,6] = nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask, False) # (NIR-G) - im_features[:,:,7] = nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,2], cloud_mask, False) # ND(NIR-R) - im_features[:,:,8] = nd_index(im_ms_ps[:,:,0], im_ms_ps[:,:,2], cloud_mask, False) # ND(B-R) - im_features[:,:,9] = nd_index(im_ms_ps[:,:,4], im_ms_ps[:,:,1], cloud_mask, False) # ND(SWIR-G) - # remove NaNs and clouds - vec_features = im_features.reshape((im_ms_ps.shape[0] * im_ms_ps.shape[1], n_features)) - vec_cloud = cloud_mask.reshape(cloud_mask.shape[0]*cloud_mask.shape[1]) - vec_nan = np.any(np.isnan(vec_features), axis=1) - vec_mask = np.logical_or(vec_cloud, vec_nan) - vec_features = vec_features[~vec_mask, :] - # predict with NN classifier - labels = clf.predict(vec_features) - # recompose image - vec_classif = np.zeros((cloud_mask.shape[0]*cloud_mask.shape[1])) - vec_classif[~vec_mask] = labels - im_classif = vec_classif.reshape((im_ms_ps.shape[0], im_ms_ps.shape[1])) - - # labels - im_sand = im_classif == 1 - # remove small patches of sand - im_sand = morphology.remove_small_objects(im_sand, min_size=min_beach_size, connectivity=2) - im_swash = im_classif == 2 - im_water = im_classif == 3 - im_labels = np.stack((im_sand,im_swash,im_water), axis=-1) - - return im_classif, im_labels - - -def classify_image_NN_nopan(im_ms_ps, cloud_mask, min_beach_size): + # add all the multispectral bands + features = np.expand_dims(im_ms[im_bool,0],axis=1) + for k in range(1,im_ms.shape[2]): + feature = np.expand_dims(im_ms[im_bool,k],axis=1) + features = np.append(features, feature, axis=-1) 
+ # NIR-G + im_NIRG = nd_index(im_ms[:,:,3], im_ms[:,:,1], cloud_mask) + features = np.append(features, np.expand_dims(im_NIRG[im_bool],axis=1), axis=-1) + # SWIR-G + im_SWIRG = nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask) + features = np.append(features, np.expand_dims(im_SWIRG[im_bool],axis=1), axis=-1) + # NIR-R + im_NIRR = nd_index(im_ms[:,:,3], im_ms[:,:,2], cloud_mask) + features = np.append(features, np.expand_dims(im_NIRR[im_bool],axis=1), axis=-1) + # SWIR-NIR + im_SWIRNIR = nd_index(im_ms[:,:,4], im_ms[:,:,3], cloud_mask) + features = np.append(features, np.expand_dims(im_SWIRNIR[im_bool],axis=1), axis=-1) + # B-R + im_BR = nd_index(im_ms[:,:,0], im_ms[:,:,2], cloud_mask) + features = np.append(features, np.expand_dims(im_BR[im_bool],axis=1), axis=-1) + # calculate standard deviation of individual bands + for k in range(im_ms.shape[2]): + im_std = SDS_tools.image_std(im_ms[:,:,k], 1) + features = np.append(features, np.expand_dims(im_std[im_bool],axis=1), axis=-1) + # calculate standard deviation of the spectral indices + im_std = SDS_tools.image_std(im_NIRG, 1) + features = np.append(features, np.expand_dims(im_std[im_bool],axis=1), axis=-1) + im_std = SDS_tools.image_std(im_SWIRG, 1) + features = np.append(features, np.expand_dims(im_std[im_bool],axis=1), axis=-1) + im_std = SDS_tools.image_std(im_NIRR, 1) + features = np.append(features, np.expand_dims(im_std[im_bool],axis=1), axis=-1) + im_std = SDS_tools.image_std(im_SWIRNIR, 1) + features = np.append(features, np.expand_dims(im_std[im_bool],axis=1), axis=-1) + im_std = SDS_tools.image_std(im_BR, 1) + features = np.append(features, np.expand_dims(im_std[im_bool],axis=1), axis=-1) + + return features + +def classify_image_NN(im_ms, im_extra, cloud_mask, min_beach_size, satname): """ - To be used for multispectral images that do not have a panchromatic band (L5 and S2). Classifies every pixel in the image in one of 4 classes: - sand --> label = 1 - whitewater (breaking waves and swash) --> label = 2 - water --> label = 3 - other (vegetation, buildings, rocks...) --> label = 0 - The classifier is a Neural Network, trained with 7000 pixels for the class SAND and 1500 - pixels for each of the other classes. This is because the class of interest for my application - is SAND and I wanted to minimize the classification error for that class. + The classifier is a Neural Network, trained on several sites in New South Wales, Australia. 
KV WRL 2018 Arguments: ----------- - im_ms_ps: np.array + im_ms: np.array Pansharpened RGB + downsampled NIR and SWIR - im_pan: - Panchromatic band + im_extra: + only used for Landsat 7 and 8 where im_extra is the panchromatic band cloud_mask: np.array 2D cloud mask with True where cloud pixels are + min_beach_size: int + minimum number of pixels that have to be connected in the SAND class Returns: ----------- - im_classif: np.ndarray + im_classif: np.array 2D image containing labels - im_labels: np.ndarray of booleans + im_labels: np.array of booleans 3D image containing a boolean image for each class (im_classif == label) """ - # load classifier - clf = joblib.load('.\\classifiers\\NN_4classes_nopan.pkl') - - # calculate features - n_features = 9 - im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features)) - im_features[:,:,[0,1,2,3,4]] = im_ms_ps - im_features[:,:,5] = nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask) # (NIR-G) - im_features[:,:,6] = nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,2], cloud_mask) # ND(NIR-R) - im_features[:,:,7] = nd_index(im_ms_ps[:,:,0], im_ms_ps[:,:,2], cloud_mask) # ND(B-R) - im_features[:,:,8] = nd_index(im_ms_ps[:,:,4], im_ms_ps[:,:,1], cloud_mask) # ND(SWIR-G) - # remove NaNs and clouds - vec_features = im_features.reshape((im_ms_ps.shape[0] * im_ms_ps.shape[1], n_features)) + if satname == 'L5': + # load classifier (without panchromatic band) + clf = joblib.load(os.path.join(os.getcwd(), 'classifiers', 'NN_4classes_nopan.pkl')) + # calculate features + n_features = 9 + im_features = np.zeros((im_ms.shape[0], im_ms.shape[1], n_features)) + im_features[:,:,[0,1,2,3,4]] = im_ms + im_features[:,:,5] = nd_index(im_ms[:,:,3], im_ms[:,:,1], cloud_mask) # (NIR-G) + im_features[:,:,6] = nd_index(im_ms[:,:,3], im_ms[:,:,2], cloud_mask) # ND(NIR-R) + im_features[:,:,7] = nd_index(im_ms[:,:,0], im_ms[:,:,2], cloud_mask) # ND(B-R) + im_features[:,:,8] = nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask) # ND(SWIR-G) + vec_features = im_features.reshape((im_ms.shape[0] * im_ms.shape[1], n_features)) + + elif satname in ['L7','L8']: + # load classifier (with panchromatic band) + clf = joblib.load(os.path.join(os.getcwd(), 'classifiers', 'NN_4classes_withpan.pkl')) + # calculate features + n_features = 10 + im_features = np.zeros((im_ms.shape[0], im_ms.shape[1], n_features)) + im_features[:,:,[0,1,2,3,4]] = im_ms + im_features[:,:,5] = im_extra + im_features[:,:,6] = nd_index(im_ms[:,:,3], im_ms[:,:,1], cloud_mask) # (NIR-G) + im_features[:,:,7] = nd_index(im_ms[:,:,3], im_ms[:,:,2], cloud_mask) # ND(NIR-R) + im_features[:,:,8] = nd_index(im_ms[:,:,0], im_ms[:,:,2], cloud_mask) # ND(B-R) + im_features[:,:,9] = nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask) # ND(SWIR-G) + vec_features = im_features.reshape((im_ms.shape[0] * im_ms.shape[1], n_features)) + + elif satname == 'S2': + # load classifier (special classifier for Sentinel-2 images) + clf = joblib.load(os.path.join(os.getcwd(), 'classifiers', 'NN_4classes_S2.pkl')) + # calculate features + vec_features = calculate_features(im_ms, cloud_mask, np.ones(cloud_mask.shape).astype(bool)) + vec_features[np.isnan(vec_features)] = 1e-9 # NaN values are create when std is too close to 0 + + # remove NaNs and cloudy pixels vec_cloud = cloud_mask.reshape(cloud_mask.shape[0]*cloud_mask.shape[1]) vec_nan = np.any(np.isnan(vec_features), axis=1) vec_mask = np.logical_or(vec_cloud, vec_nan) vec_features = vec_features[~vec_mask, :] - # predict with NN classifier + + # classify pixels labels = 
clf.predict(vec_features) # recompose image - vec_classif = np.zeros((cloud_mask.shape[0]*cloud_mask.shape[1])) + vec_classif = np.nan*np.ones((cloud_mask.shape[0]*cloud_mask.shape[1])) vec_classif[~vec_mask] = labels - im_classif = vec_classif.reshape((im_ms_ps.shape[0], im_ms_ps.shape[1])) + im_classif = vec_classif.reshape((cloud_mask.shape[0], cloud_mask.shape[1])) - # labels + # create a stack of boolean images for each label im_sand = im_classif == 1 - # remove small patches of sand - im_sand = morphology.remove_small_objects(im_sand, min_size=min_beach_size, connectivity=2) im_swash = im_classif == 2 im_water = im_classif == 3 - im_labels = np.stack((im_sand,im_swash,im_water), axis=-1) + # remove small patches of sand or water that could be around the image (usually noise) + im_sand = morphology.remove_small_objects(im_sand, min_size=min_beach_size, connectivity=2) + im_water = morphology.remove_small_objects(im_water, min_size=min_beach_size, connectivity=2) + + im_labels = np.stack((im_sand,im_swash,im_water), axis=-1) return im_classif, im_labels @@ -237,7 +254,7 @@ def find_wl_contours1(im_ndwi, cloud_mask): # use Marching Squares algorithm to detect contours on ndwi image contours = measure.find_contours(im_ndwi, t_otsu) - # remove contours that have nans (due to cloud pixels in the contour) + # remove contours that contain NaNs (due to cloud pixels in the contour) contours_nonans = [] for k in range(len(contours)): if np.any(np.isnan(contours[k])): @@ -251,31 +268,32 @@ def find_wl_contours1(im_ndwi, cloud_mask): return contours -def find_wl_contours2(im_ms_ps, im_labels, cloud_mask, buffer_size): +def find_wl_contours2(im_ms, im_labels, cloud_mask, buffer_size): """ New robust method for extracting shorelines. Incorporates the classification component to - refube the treshold and make it specific to the sand/water interface. + refine the treshold and make it specific to the sand/water interface. KV WRL 2018 Arguments: ----------- - im_ms_ps: np.array - Pansharpened RGB + downsampled NIR and SWIR + im_ms: np.array + RGB + downsampled NIR and SWIR im_labels: np.array 3D image containing a boolean image for each class in the order (sand, swash, water) cloud_mask: np.array 2D cloud mask with True where cloud pixels are buffer_size: int - size of the buffer around the sandy beach + size of the buffer around the sandy beach over which the pixels are considered in the + thresholding algorithm. 
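The idea behind this function can be illustrated on synthetic data: dilate the sand mask to obtain a buffer around the sand/water interface, derive an image-specific Otsu threshold from the index values inside that buffer, and contour the index at that level. The snippet below is only an illustration with made-up arrays, not the body of find_wl_contours2 (which additionally balances the sand and water samples before thresholding):

import numpy as np
from skimage import filters, measure, morphology

# synthetic water-index image varying smoothly from +1 on one side to -1 on the other
im_index = np.tile(np.linspace(1, -1, 60), (40, 1)) + 0.05*np.random.randn(40, 60)
im_sand = im_index < -0.5                              # pretend the low-index side is the sand class

# buffer around the sandy pixels (radius in pixels, like settings['buffer_size'])
im_buffer = morphology.binary_dilation(im_sand, morphology.disk(6))

# threshold computed only from the pixels inside the buffer, then marching-squares contours
t_mwi = filters.threshold_otsu(im_index[im_buffer])
contours_mwi = measure.find_contours(im_index, t_mwi)
print('%d contour(s) found at threshold %.2f' % (len(contours_mwi), t_mwi))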
Returns: ----------- contours_wi: list of np.arrays - contains the (row,column) coordinates of the contour lines extracted with the - NDWI (Normalized Difference Water Index) + contains the (row,column) coordinates of the contour lines extracted from the + NDWI (Normalized Difference Water Index) image contours_mwi: list of np.arrays - contains the (row,column) coordinates of the contour lines extracted with the - MNDWI (Modified Normalized Difference Water Index) + contains the (row,column) coordinates of the contour lines extracted from the + MNDWI (Modified Normalized Difference Water Index) image """ @@ -283,9 +301,9 @@ def find_wl_contours2(im_ms_ps, im_labels, cloud_mask, buffer_size): ncols = cloud_mask.shape[1] # calculate Normalized Difference Modified Water Index (SWIR - G) - im_mwi = nd_index(im_ms_ps[:,:,4], im_ms_ps[:,:,1], cloud_mask) + im_mwi = nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask) # calculate Normalized Difference Modified Water Index (NIR - G) - im_wi = nd_index(im_ms_ps[:,:,3], im_ms_ps[:,:,1], cloud_mask) + im_wi = nd_index(im_ms[:,:,3], im_ms[:,:,1], cloud_mask) # stack indices together im_ind = np.stack((im_wi, im_mwi), axis=-1) vec_ind = im_ind.reshape(nrows*ncols,2) @@ -306,16 +324,14 @@ def find_wl_contours2(im_ms_ps, im_labels, cloud_mask, buffer_size): # make sure both classes have the same number of pixels before thresholding if len(int_water) > 0 and len(int_sand) > 0: if np.argmin([int_sand.shape[0],int_water.shape[0]]) == 1: - if (int_sand.shape[0] - int_water.shape[0])/int_water.shape[0] > 0.5: - int_sand = int_sand[np.random.randint(0,int_sand.shape[0],int_water.shape[0]),:] + int_sand = int_sand[np.random.choice(int_sand.shape[0],int_water.shape[0], replace=False),:] else: - if (int_water.shape[0] - int_sand.shape[0])/int_sand.shape[0] > 0.5: - int_water = int_water[np.random.randint(0,int_water.shape[0],int_sand.shape[0]),:] + int_water = int_water[np.random.choice(int_water.shape[0],int_sand.shape[0], replace=False),:] # threshold the sand/water intensities int_all = np.append(int_water,int_sand, axis=0) t_mwi = filters.threshold_otsu(int_all[:,0]) - t_wi = filters.threshold_otsu(int_all[:,1]) + t_wi = filters.threshold_otsu(int_all[:,1]) # find contour with MS algorithm im_wi_buffer = np.copy(im_wi) @@ -325,7 +341,7 @@ def find_wl_contours2(im_ms_ps, im_labels, cloud_mask, buffer_size): contours_wi = measure.find_contours(im_wi_buffer, t_wi) contours_mwi = measure.find_contours(im_mwi, t_mwi) - # remove contour points that are nans (around clouds) + # remove contour points that are NaNs (around clouds) contours = contours_wi contours_nonans = [] for k in range(len(contours)): @@ -337,7 +353,7 @@ def find_wl_contours2(im_ms_ps, im_labels, cloud_mask, buffer_size): else: contours_nonans.append(contours[k]) contours_wi = contours_nonans - + # repeat for MNDWI contours contours = contours_mwi contours_nonans = [] for k in range(len(contours)): @@ -353,6 +369,33 @@ def find_wl_contours2(im_ms_ps, im_labels, cloud_mask, buffer_size): return contours_wi, contours_mwi def process_shoreline(contours, georef, image_epsg, settings): + """ + Converts the contours from image coordinates to world coordinates. 
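Assuming the usual GDAL geotransform convention for the 6-element georef vector described here ([Xtr, Xscale, Xshear, Ytr, Yshear, Yscale]), the image-to-world mapping is a simple affine relation. The helper below is an illustrative stand-in, not SDS_tools.convert_pix2world itself, and the georef numbers are made up:

import numpy as np

georef = [342800.0, 10.0, 0.0, 6269600.0, 0.0, -10.0]    # hypothetical 10 m, north-up image

def pix2world(points, georef):
    """points are (row, column) pairs, e.g. as returned by skimage.measure.find_contours."""
    rows, cols = points[:, 0], points[:, 1]
    x = georef[0] + cols*georef[1] + rows*georef[2]
    y = georef[3] + cols*georef[4] + rows*georef[5]
    return np.column_stack((x, y))

print(pix2world(np.array([[0.0, 0.0], [100.0, 200.0]]), georef))
# [[ 342800. 6269600.]
#  [ 344800. 6268600.]]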
This function also removes + the contours that are too small to be a shoreline (based on the parameter + settings['min_length_sl']) + + KV WRL 2018 + + Arguments: + ----------- + contours: np.array or list of np.array + image contours as detected by the function find_contours + georef: np.array + vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale] + image_epsg: int + spatial reference system of the image from which the contours were extracted + settings: dict + contains important parameters for processing the shoreline: + output_epsg: output spatial reference system + min_length_sl: minimum length of shoreline perimeter to be kept (in meters) + reference_sl: [optional] reference shoreline coordinates + max_dist_ref: max distance (in meters) allowed from a reference shoreline + + Returns: ----------- + shoreline: np.array + array of points with the X and Y coordinates of the shoreline + + """ # convert pixel coordinates to world coordinates contours_world = SDS_tools.convert_pix2world(contours, georef) @@ -390,13 +433,47 @@ def process_shoreline(contours, georef, image_epsg, settings): def show_detection(im_ms, cloud_mask, im_labels, shoreline,image_epsg, georef, settings, date, satname): + """ + Shows the detected shoreline to the user for visual quality control. The user can select "keep" + if the shoreline detection is correct or "skip" if it is incorrect. - # subfolder to store the .jpg files - filepath = os.path.join(os.getcwd(), 'data', settings['sitename'], 'jpg_files', 'detection') + KV WRL 2018 + + Arguments: + ----------- + im_ms: np.array + RGB + downsampled NIR and SWIR + cloud_mask: np.array + 2D cloud mask with True where cloud pixels are + im_labels: np.array + 3D image containing a boolean image for each class in the order (sand, swash, water) + shoreline: np.array + array of points with the X and Y coordinates of the shoreline + image_epsg: int + spatial reference system of the image from which the contours were extracted + georef: np.array + vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale] + settings: dict + contains important parameters for processing the shoreline + date: string + date at which the image was taken + satname: string + indicates the satname (L5,L7,L8 or S2) + + Returns: ----------- + skip_image: boolean + True if the user wants to skip the image, False otherwise. 
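A reference-shoreline filter of the kind implied by max_dist_ref can be written as a nearest-neighbour query; the helper below is a hypothetical sketch using scipy's cKDTree and is not necessarily how process_shoreline implements it:

import numpy as np
from scipy.spatial import cKDTree

def filter_by_reference(points, reference_sl, max_dist_ref):
    """Keep only the candidate shoreline points within max_dist_ref (metres) of the reference."""
    dist, _ = cKDTree(reference_sl).query(points, k=1)
    return points[dist <= max_dist_ref]

# toy example: a straight reference shoreline along y = 0 and two candidate points
reference_sl = np.column_stack((np.linspace(0, 1000, 101), np.zeros(101)))
candidates = np.array([[500.0, 30.0], [500.0, 500.0]])
print(filter_by_reference(candidates, reference_sl, max_dist_ref=100))    # keeps [500., 30.] only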
+ + """ + + sitename = settings['inputs']['sitename'] + + # subfolder where the .jpg file is stored if the user accepts the shoreline detection + filepath = os.path.join(os.getcwd(), 'data', sitename, 'jpg_files', 'detection') - # display RGB image im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9) - # display classified image + + # compute classified image im_class = np.copy(im_RGB) cmap = cm.get_cmap('tab20c') colorpalette = cmap(np.arange(0,13,1)) @@ -408,60 +485,86 @@ def show_detection(im_ms, cloud_mask, im_labels, shoreline,image_epsg, georef, im_class[im_labels[:,:,k],0] = colours[k,0] im_class[im_labels[:,:,k],1] = colours[k,1] im_class[im_labels[:,:,k],2] = colours[k,2] - # display MNDWI grayscale image + + # compute MNDWI grayscale image im_mwi = nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask) + # transform world coordinates of shoreline into pixel coordinates - sl_pix = SDS_tools.convert_world2pix(SDS_tools.convert_epsg(shoreline, settings['output_epsg'], - image_epsg)[:,[0,1]], georef) - # make figure + # use try/except in case there are no coordinates to be transformed (shoreline = []) + try: + sl_pix = SDS_tools.convert_world2pix(SDS_tools.convert_epsg(shoreline, + settings['output_epsg'], + image_epsg)[:,[0,1]], georef) + except: + # if try fails, just add nan into the shoreline vector so the next parts can still run + sl_pix = np.array([[np.nan, np.nan],[np.nan, np.nan]]) + + # according to the image shape, decide whether it is better to have the images in the subplot + # in different rows or different columns fig = plt.figure() - gs = gridspec.GridSpec(1, 3) - gs.update(bottom=0.05, top=0.95) - ax1 = fig.add_subplot(gs[0,0]) - plt.imshow(im_RGB) - plt.plot(sl_pix[:,0], sl_pix[:,1], 'k--') - plt.axis('off') - ax1.set_anchor('W') + if im_RGB.shape[1] > 2*im_RGB.shape[0]: + # vertical subplots + gs = gridspec.GridSpec(3, 1) + gs.update(bottom=0.03, top=0.97, left=0.03, right=0.97) + ax1 = fig.add_subplot(gs[0,0]) + ax2 = fig.add_subplot(gs[1,0]) + ax3 = fig.add_subplot(gs[2,0]) + else: + # horizontal subplots + gs = gridspec.GridSpec(1, 3) + gs.update(bottom=0.05, top=0.95, left=0.05, right=0.95) + ax1 = fig.add_subplot(gs[0,0]) + ax2 = fig.add_subplot(gs[0,1]) + ax3 = fig.add_subplot(gs[0,2]) + + # create image 1 (RGB) + ax1.imshow(im_RGB) + ax1.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3) + ax1.axis('off') btn_keep = plt.text(0, 0.9, 'keep', size=16, ha="left", va="top", transform=ax1.transAxes, bbox=dict(boxstyle="square", ec='k',fc='w')) btn_skip = plt.text(1, 0.9, 'skip', size=16, ha="right", va="top", transform=ax1.transAxes, bbox=dict(boxstyle="square", ec='k',fc='w')) - plt.title('Click on if shoreline detection is correct. 
Click on if false detection') - ax2 = fig.add_subplot(gs[0,1]) - plt.imshow(im_class) - plt.plot(sl_pix[:,0], sl_pix[:,1], 'k--') - plt.axis('off') - ax2.set_anchor('W') + ax1.set_title(sitename + ' ' + date + ' ' + satname, fontweight='bold', fontsize=16) + + # create image 2 (classification) + ax2.imshow(im_class) + ax2.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3) + ax2.axis('off') orange_patch = mpatches.Patch(color=colours[0,:], label='sand') white_patch = mpatches.Patch(color=colours[1,:], label='whitewater') blue_patch = mpatches.Patch(color=colours[2,:], label='water') black_line = mlines.Line2D([],[],color='k',linestyle='--', label='shoreline') - plt.legend(handles=[orange_patch,white_patch,blue_patch, black_line], bbox_to_anchor=(1, 0.5), fontsize=9) - ax3 = fig.add_subplot(gs[0,2]) - plt.imshow(im_mwi, cmap='bwr') - plt.plot(sl_pix[:,0], sl_pix[:,1], 'k--') - plt.axis('off') - cb = plt.colorbar() - cb.ax.tick_params(labelsize=10) - cb.set_label('MNDWI values') - ax3.set_anchor('W') + ax2.legend(handles=[orange_patch,white_patch,blue_patch, black_line], + bbox_to_anchor=(1, 0.5), fontsize=9) + # create image 3 (MNDWI) + ax3.imshow(im_mwi, cmap='bwr') + ax3.plot(sl_pix[:,0], sl_pix[:,1], 'k.', markersize=3) + ax3.axis('off') + +# additional options +# ax1.set_anchor('W') +# ax2.set_anchor('W') +# cb = plt.colorbar() +# cb.ax.tick_params(labelsize=10) +# cb.set_label('MNDWI values') +# ax3.set_anchor('W') + fig.set_size_inches([12.53, 9.3]) - fig.set_tight_layout(True) mng = plt.get_current_fig_manager() mng.window.showMaximized() - # wait for user's selection ( or ) - pt = ginput(n=1, timeout=100, show_clicks=True) + # wait for user's selection: or + pt = ginput(n=1, timeout=100000, show_clicks=True) pt = np.array(pt) - # if clicks next to , return skip_image = True + # if user clicks around the button, return skip_image = True if pt[0][0] > im_ms.shape[1]/2: skip_image = True plt.close() else: skip_image = False - ax1.set_title(date + ' ' + satname) btn_skip.set_visible(False) btn_keep.set_visible(False) fig.savefig(os.path.join(filepath, date + '_' + satname + '.jpg'), dpi=150) @@ -471,11 +574,41 @@ def show_detection(im_ms, cloud_mask, im_labels, shoreline,image_epsg, georef, def extract_shorelines(metadata, settings): - - sitename = settings['sitename'] + """ + Extracts shorelines from satellite images. + + KV WRL 2018 + + Arguments: + ----------- + metadata: dict + contains all the information about the satellite images that were downloaded + + inputs: dict + contains the following fields: + sitename: str + String containig the name of the site + polygon: list + polygon containing the lon/lat coordinates to be extracted + longitudes in the first column and latitudes in the second column + dates: list of str + list that contains 2 strings with the initial and final dates in format + 'yyyy-mm-dd' e.g. ['1987-01-01', '2018-01-01'] + sat_list: list of str + list that contains the names of the satellite missions to include + e.g. ['L5', 'L7', 'L8', 'S2'] + + Returns: + ----------- + output: dict + contains the extracted shorelines and corresponding dates. 
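After a run, the structure saved by this function can be reloaded and inspected along the following lines ('NARRA' is just the example site used in test_spyder_simple.py):

import os
import pickle

sitename = 'NARRA'    # example site
filepath = os.path.join(os.getcwd(), 'data', sitename)
with open(os.path.join(filepath, sitename + '_output.pkl'), 'rb') as f:
    output = pickle.load(f)

# one entry per satellite mission plus a 'meta' entry describing the fields
for satname in output.keys():
    if satname == 'meta':
        continue
    print(satname, ':', len(output[satname]['shoreline']), 'shorelines extracted')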
+ + """ + + sitename = settings['inputs']['sitename'] # initialise output structure - out = dict([]) + output = dict([]) # create a subfolder to store the .jpg images showing the detection filepath_jpg = os.path.join(os.getcwd(), 'data', sitename, 'jpg_files', 'detection') try: @@ -486,58 +619,25 @@ def extract_shorelines(metadata, settings): # loop through satellite list for satname in metadata.keys(): - # access the images - if satname == 'L5': - # access downloaded Landsat 5 images - filepath = os.path.join(os.getcwd(), 'data', sitename, satname, '30m') - filenames = os.listdir(filepath) - elif satname == 'L7': - # access downloaded Landsat 7 images - filepath_pan = os.path.join(os.getcwd(), 'data', sitename, 'L7', 'pan') - filepath_ms = os.path.join(os.getcwd(), 'data', sitename, 'L7', 'ms') - filenames_pan = os.listdir(filepath_pan) - filenames_ms = os.listdir(filepath_ms) - if (not len(filenames_pan) == len(filenames_ms)): - raise 'error: not the same amount of files for pan and ms' - filepath = [filepath_pan, filepath_ms] - filenames = filenames_pan - elif satname == 'L8': - # access downloaded Landsat 7 images - filepath_pan = os.path.join(os.getcwd(), 'data', sitename, 'L8', 'pan') - filepath_ms = os.path.join(os.getcwd(), 'data', sitename, 'L8', 'ms') - filenames_pan = os.listdir(filepath_pan) - filenames_ms = os.listdir(filepath_ms) - if (not len(filenames_pan) == len(filenames_ms)): - raise 'error: not the same amount of files for pan and ms' - filepath = [filepath_pan, filepath_ms] - filenames = filenames_pan - elif satname == 'S2': - # access downloaded Sentinel 2 images - filepath10 = os.path.join(os.getcwd(), 'data', sitename, satname, '10m') - filenames10 = os.listdir(filepath10) - filepath20 = os.path.join(os.getcwd(), 'data', sitename, satname, '20m') - filenames20 = os.listdir(filepath20) - filepath60 = os.path.join(os.getcwd(), 'data', sitename, satname, '60m') - filenames60 = os.listdir(filepath60) - if (not len(filenames10) == len(filenames20)) or (not len(filenames20) == len(filenames60)): - raise 'error: not the same amount of files for 10, 20 and 60 m' - filepath = [filepath10, filepath20, filepath60] - filenames = filenames10 - + # get images + filepath = SDS_tools.get_filepath(settings['inputs'],satname) + filenames = metadata[satname]['filenames'] + # initialise some variables - out_timestamp = [] # datetime at which the image was acquired (UTC time) - out_shoreline = [] # vector of shoreline points - out_filename = [] # filename of the images from which the shorelines where derived - out_cloudcover = [] # cloud cover of the images - out_geoaccuracy = []# georeferencing accuracy of the images - out_idxkeep = [] # index that were kept during the analysis (cloudy images are skipped) + output_timestamp = [] # datetime at which the image was acquired (UTC time) + output_shoreline = [] # vector of shoreline points + output_filename = [] # filename of the images from which the shorelines where derived + output_cloudcover = [] # cloud cover of the images + output_geoaccuracy = []# georeferencing accuracy of the images + output_idxkeep = [] # index that were kept during the analysis (cloudy images are skipped) # loop through the images for i in range(len(filenames)): + # get image filename fn = SDS_tools.get_filenames(filenames[i],filepath, satname) # preprocess image (cloud mask + pansharpening/downsampling) - im_ms, georef, cloud_mask = SDS_preprocess.preprocess_single(fn, satname) + im_ms, georef, cloud_mask, im_extra, imQA = SDS_preprocess.preprocess_single(fn, 
satname) # get image spatial reference system (epsg code) from metadata dict image_epsg = metadata[satname]['epsg'][i] # calculate cloud cover @@ -546,22 +646,28 @@ def extract_shorelines(metadata, settings): # skip image if cloud cover is above threshold if cloud_cover > settings['cloud_thresh']: continue + # classify image in 4 classes (sand, whitewater, water, other) with NN classifier - im_classif, im_labels = classify_image_NN_nopan(im_ms, cloud_mask, - settings['min_beach_size']) + im_classif, im_labels = classify_image_NN(im_ms, im_extra, cloud_mask, + settings['min_beach_size'], satname) + # extract water line contours # if there aren't any sandy pixels, use find_wl_contours1 (traditional method), # otherwise use find_wl_contours2 (enhanced method with classification) - if sum(sum(im_labels[:,:,0])) == 0 : - # compute MNDWI (SWIR-Green normalized index) grayscale image - im_mndwi = nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask) - # find water contourson MNDWI grayscale image - contours_mwi = find_wl_contours1(im_mndwi, cloud_mask) - else: - # use classification to refine threshold and extract sand/water interface - contours_wi, contours_mwi = find_wl_contours2(im_ms, im_labels, - cloud_mask, settings['buffer_size']) - # extract clean shoreline from water contours + try: # use try/except structure for long runs + if sum(sum(im_labels[:,:,0])) == 0 : + # compute MNDWI (SWIR-Green normalized index) grayscale image + im_mndwi = nd_index(im_ms[:,:,4], im_ms[:,:,1], cloud_mask) + # find water contourson MNDWI grayscale image + contours_mwi = find_wl_contours1(im_mndwi, cloud_mask) + else: + # use classification to refine threshold and extract sand/water interface + contours_wi, contours_mwi = find_wl_contours2(im_ms, im_labels, + cloud_mask, settings['buffer_size']) + except: + continue + + # process water contours into shorelines shoreline = process_shoreline(contours_mwi, georef, image_epsg, settings) if settings['check_detection']: @@ -571,34 +677,35 @@ def extract_shorelines(metadata, settings): if skip_image: continue - # fill and save output structure - out_timestamp.append(metadata[satname]['dates'][i]) - out_shoreline.append(shoreline) - out_filename.append(filenames[i]) - out_cloudcover.append(cloud_cover) - out_geoaccuracy.append(metadata[satname]['acc_georef'][i]) - out_idxkeep.append(i) + # fill and save outputput structure + output_timestamp.append(metadata[satname]['dates'][i]) + output_shoreline.append(shoreline) + output_filename.append(filenames[i]) + output_cloudcover.append(cloud_cover) + output_geoaccuracy.append(metadata[satname]['acc_georef'][i]) + output_idxkeep.append(i) - out[satname] = { - 'timestamp': out_timestamp, - 'shoreline': out_shoreline, - 'filename': out_filename, - 'cloudcover': out_cloudcover, - 'geoaccuracy': out_geoaccuracy, - 'idxkeep': out_idxkeep + output[satname] = { + 'timestamp': output_timestamp, + 'shoreline': output_shoreline, + 'filename': output_filename, + 'cloudcover': output_cloudcover, + 'geoaccuracy': output_geoaccuracy, + 'idxkeep': output_idxkeep } # add some metadata - out['meta'] = { + output['meta'] = { 'timestamp': 'UTC time', 'shoreline': 'coordinate system epsg : ' + str(settings['output_epsg']), 'cloudcover': 'calculated on the cropped image', 'geoaccuracy': 'RMSE error based on GCPs', 'idxkeep': 'indices of the images that were kept to extract a shoreline' } - # save output structure as out.pkl + + # save outputput structure as output.pkl filepath = os.path.join(os.getcwd(), 'data', sitename) - with 
open(os.path.join(filepath, sitename + '_out.pkl'), 'wb') as f: - pickle.dump(out, f) + with open(os.path.join(filepath, sitename + '_output.pkl'), 'wb') as f: + pickle.dump(output, f) - return out \ No newline at end of file + return output \ No newline at end of file diff --git a/SDS_tools.py b/SDS_tools.py index 4b743e9..9eb00cf 100644 --- a/SDS_tools.py +++ b/SDS_tools.py @@ -3,15 +3,17 @@ Author: Kilian Vos, Water Research Laboratory, University of New South Wales """ -# Initial settings +# load modules import os import numpy as np +import matplotlib.pyplot as plt +import pdb + +# other modules from osgeo import gdal, ogr, osr import skimage.transform as transform import simplekml -import pdb - -# Functions +from scipy.ndimage.filters import uniform_filter def convert_pix2world(points, georef): """ @@ -143,6 +145,21 @@ def convert_epsg(points, epsg_in, epsg_out): return points_converted def coords_from_kml(fn): + """ + Extracts coordinates from a .kml file. + + KV WRL 2018 + + Arguments: + ----------- + fn: str + filepath + filename of the kml file to be read + + Returns: ----------- + polygon: list + coordinates extracted from the .kml file + + """ # read .kml file with open(fn) as kmlFile: @@ -152,6 +169,7 @@ def coords_from_kml(fn): str2 = '' subdoc = doc[doc.find(str1)+len(str1):doc.find(str2)] coordlist = subdoc.split('\n') + # read coordinates polygon = [] for i in range(1,len(coordlist)-1): polygon.append([float(coordlist[i].split(',')[0]), float(coordlist[i].split(',')[1])]) @@ -159,29 +177,196 @@ def coords_from_kml(fn): return [polygon] def save_kml(coords, epsg): + """ + Saves coordinates with specified spatial reference system into a .kml file in WGS84. + + KV WRL 2018 + + Arguments: + ----------- + coords: np.array + coordinates (2 columns) to be converted into a .kml file + + Returns: + ----------- + Saves 'coords.kml' in the current folder. + + """ kml = simplekml.Kml() coords_wgs84 = convert_epsg(coords, epsg, 4326) kml.newlinestring(name='coords', coords=coords_wgs84) kml.save('coords.kml') +def get_filepath(inputs,satname): + """ + Create filepath to the different folders containing the satellite images. + + KV WRL 2018 + + Arguments: + ----------- + inputs: dict + dictionnary that contains the following fields: + 'sitename': str + String containig the name of the site + 'polygon': list + polygon containing the lon/lat coordinates to be extracted + longitudes in the first column and latitudes in the second column + 'dates': list of str + list that contains 2 strings with the initial and final dates in format 'yyyy-mm-dd' + e.g. ['1987-01-01', '2018-01-01'] + 'sat_list': list of str + list that contains the names of the satellite missions to include + e.g. 
['L5', 'L7', 'L8', 'S2'] + satname: str + short name of the satellite mission + + Returns: + ----------- + filepath: str or list of str + contains the filepath(s) to the folder(s) containing the satellite images + + """ + + sitename = inputs['sitename'] + # access the images + if satname == 'L5': + # access downloaded Landsat 5 images + filepath = os.path.join(os.getcwd(), 'data', sitename, satname, '30m') + elif satname == 'L7': + # access downloaded Landsat 7 images + filepath_pan = os.path.join(os.getcwd(), 'data', sitename, 'L7', 'pan') + filepath_ms = os.path.join(os.getcwd(), 'data', sitename, 'L7', 'ms') + filenames_pan = os.listdir(filepath_pan) + filenames_ms = os.listdir(filepath_ms) + if (not len(filenames_pan) == len(filenames_ms)): + raise 'error: not the same amount of files for pan and ms' + filepath = [filepath_pan, filepath_ms] + elif satname == 'L8': + # access downloaded Landsat 8 images + filepath_pan = os.path.join(os.getcwd(), 'data', sitename, 'L8', 'pan') + filepath_ms = os.path.join(os.getcwd(), 'data', sitename, 'L8', 'ms') + filenames_pan = os.listdir(filepath_pan) + filenames_ms = os.listdir(filepath_ms) + if (not len(filenames_pan) == len(filenames_ms)): + raise 'error: not the same amount of files for pan and ms' + filepath = [filepath_pan, filepath_ms] + elif satname == 'S2': + # access downloaded Sentinel 2 images + filepath10 = os.path.join(os.getcwd(), 'data', sitename, satname, '10m') + filenames10 = os.listdir(filepath10) + filepath20 = os.path.join(os.getcwd(), 'data', sitename, satname, '20m') + filenames20 = os.listdir(filepath20) + filepath60 = os.path.join(os.getcwd(), 'data', sitename, satname, '60m') + filenames60 = os.listdir(filepath60) + if (not len(filenames10) == len(filenames20)) or (not len(filenames20) == len(filenames60)): + raise 'error: not the same amount of files for 10, 20 and 60 m bands' + filepath = [filepath10, filepath20, filepath60] + + return filepath + def get_filenames(filename, filepath, satname): + """ + Creates filepath + filename for all the bands belonging to the same image. + + KV WRL 2018 + + Arguments: + ----------- + filename: str + name of the downloaded satellite image as found in the metadata + filepath: str or list of str + contains the filepath(s) to the folder(s) containing the satellite images + satname: str + short name of the satellite mission + + Returns: + ----------- + fn: str or list of str + contains the filepath + filenames to access the satellite image + + """ if satname == 'L5': fn = os.path.join(filepath, filename) if satname == 'L7' or satname == 'L8': - idx = filename.find('.tif') - filename_ms = filename[:idx-3] + 'ms.tif' + filename_ms = filename.replace('pan','ms') fn = [os.path.join(filepath[0], filename), os.path.join(filepath[1], filename_ms)] if satname == 'S2': - idx = filename.find('.tif') - filename20 = filename[:idx-3] + '20m.tif' - filename60 = filename[:idx-3] + '60m.tif' + filename20 = filename.replace('10m','20m') + filename60 = filename.replace('10m','60m') fn = [os.path.join(filepath[0], filename), os.path.join(filepath[1], filename20), os.path.join(filepath[2], filename60)] + return fn +def image_std(image, radius): + """ + Calculates the standard deviation of an image, using a moving window of specified radius. + + Arguments: + ----------- + image: np.array + 2D array containing the pixel intensities of a single-band image + radius: int + radius defining the moving window used to calculate the standard deviation. For example, + radius = 1 will produce a 3x3 moving window. 
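The implementation that follows computes the moving-window standard deviation as sqrt(E[x^2] - E[x]^2) with scipy's uniform_filter applied to a reflect-padded image; a quick sanity check of that trick against a direct computation:

import numpy as np
from scipy.ndimage import uniform_filter

image = np.arange(25, dtype=float).reshape(5, 5)
radius = 1                                            # 3x3 moving window

image_padded = np.pad(image, radius, 'reflect')
win_mean = uniform_filter(image_padded, radius*2 + 1)
win_sqr_mean = uniform_filter(image_padded**2, radius*2 + 1)
win_std = np.sqrt(win_sqr_mean - win_mean**2)[radius:-radius, radius:-radius]

# direct computation over the 3x3 window centred on pixel (2, 2)
print(np.isclose(win_std[2, 2], np.std(image[1:4, 1:4])))    # True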
+ + Returns: + ----------- + win_std: np.array + 2D array containing the standard deviation of the image + """ + + # convert to float + image = image.astype(float) + # first pad the image + image_padded = np.pad(image, radius, 'reflect') + # window size + win_rows, win_cols = radius*2 + 1, radius*2 + 1 + # calculate std + win_mean = uniform_filter(image_padded, (win_rows, win_cols)) + win_sqr_mean = uniform_filter(image_padded**2, (win_rows, win_cols)) + win_var = win_sqr_mean - win_mean**2 + win_std = np.sqrt(win_var) + # remove padding + win_std = win_std[radius:-radius, radius:-radius] + + return win_std + +def mask_raster(fn, mask): + """ + Masks a .tif raster using GDAL. + + Arguments: + ----------- + fn: str + filepath + filename of the .tif raster + mask: np.array + array of boolean where True indicates the pixels that are to be masked + + Returns: + ----------- + overwrites the .tif file directly + + """ + + # open raster + raster = gdal.Open(fn, gdal.GA_Update) + # mask raster + for i in range(raster.RasterCount): + out_band = raster.GetRasterBand(i+1) + out_data = out_band.ReadAsArray() + out_band.SetNoDataValue(0) + no_data_value = out_band.GetNoDataValue() + out_data[mask] = no_data_value + out_band.WriteArray(out_data) + # close dataset and flush cache + raster = None + + \ No newline at end of file diff --git a/gdal_merge.py b/gdal_merge.py new file mode 100644 index 0000000..7dec201 --- /dev/null +++ b/gdal_merge.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python +############################################################################### +# $Id$ +# +# Project: InSAR Peppers +# Purpose: Module to extract data from many rasters into one output. +# Author: Frank Warmerdam, warmerdam@pobox.com +# +############################################################################### +# Copyright (c) 2000, Atlantis Scientific Inc. (www.atlsci.com) +# Copyright (c) 2009-2011, Even Rouault +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Library General Public +# License as published by the Free Software Foundation; either +# version 2 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Library General Public License for more details. +# +# You should have received a copy of the GNU Library General Public +# License along with this library; if not, write to the +# Free Software Foundation, Inc., 59 Temple Place - Suite 330, +# Boston, MA 02111-1307, USA. +############################################################################### +# changes 29Apr2011 +# If the input image is a multi-band one, use all the channels in +# building the stack. +# anssi.pekkarinen@fao.org + +import math +import sys +import time + +from osgeo import gdal + +try: + progress = gdal.TermProgress_nocb +except: + progress = gdal.TermProgress + +__version__ = '$id$'[5:-1] +verbose = 0 +quiet = 0 + + +# ============================================================================= +def raster_copy( s_fh, s_xoff, s_yoff, s_xsize, s_ysize, s_band_n, + t_fh, t_xoff, t_yoff, t_xsize, t_ysize, t_band_n, + nodata=None ): + + if verbose != 0: + print('Copy %d,%d,%d,%d to %d,%d,%d,%d.' 
+ % (s_xoff, s_yoff, s_xsize, s_ysize, + t_xoff, t_yoff, t_xsize, t_ysize )) + + if nodata is not None: + return raster_copy_with_nodata( + s_fh, s_xoff, s_yoff, s_xsize, s_ysize, s_band_n, + t_fh, t_xoff, t_yoff, t_xsize, t_ysize, t_band_n, + nodata ) + + s_band = s_fh.GetRasterBand( s_band_n ) + m_band = None + # Works only in binary mode and doesn't take into account + # intermediate transparency values for compositing. + if s_band.GetMaskFlags() != gdal.GMF_ALL_VALID: + m_band = s_band.GetMaskBand() + elif s_band.GetColorInterpretation() == gdal.GCI_AlphaBand: + m_band = s_band + if m_band is not None: + return raster_copy_with_mask( + s_fh, s_xoff, s_yoff, s_xsize, s_ysize, s_band_n, + t_fh, t_xoff, t_yoff, t_xsize, t_ysize, t_band_n, + m_band ) + + s_band = s_fh.GetRasterBand( s_band_n ) + t_band = t_fh.GetRasterBand( t_band_n ) + + data = s_band.ReadRaster( s_xoff, s_yoff, s_xsize, s_ysize, + t_xsize, t_ysize, t_band.DataType ) + t_band.WriteRaster( t_xoff, t_yoff, t_xsize, t_ysize, + data, t_xsize, t_ysize, t_band.DataType ) + + return 0 + +# ============================================================================= +def raster_copy_with_nodata( s_fh, s_xoff, s_yoff, s_xsize, s_ysize, s_band_n, + t_fh, t_xoff, t_yoff, t_xsize, t_ysize, t_band_n, + nodata ): + try: + import numpy as Numeric + except ImportError: + import Numeric + + s_band = s_fh.GetRasterBand( s_band_n ) + t_band = t_fh.GetRasterBand( t_band_n ) + + data_src = s_band.ReadAsArray( s_xoff, s_yoff, s_xsize, s_ysize, + t_xsize, t_ysize ) + data_dst = t_band.ReadAsArray( t_xoff, t_yoff, t_xsize, t_ysize ) + + nodata_test = Numeric.equal(data_src,nodata) + to_write = Numeric.choose( nodata_test, (data_src, data_dst) ) + + t_band.WriteArray( to_write, t_xoff, t_yoff ) + + return 0 + +# ============================================================================= +def raster_copy_with_mask( s_fh, s_xoff, s_yoff, s_xsize, s_ysize, s_band_n, + t_fh, t_xoff, t_yoff, t_xsize, t_ysize, t_band_n, + m_band ): + try: + import numpy as Numeric + except ImportError: + import Numeric + + s_band = s_fh.GetRasterBand( s_band_n ) + t_band = t_fh.GetRasterBand( t_band_n ) + + data_src = s_band.ReadAsArray( s_xoff, s_yoff, s_xsize, s_ysize, + t_xsize, t_ysize ) + data_mask = m_band.ReadAsArray( s_xoff, s_yoff, s_xsize, s_ysize, + t_xsize, t_ysize ) + data_dst = t_band.ReadAsArray( t_xoff, t_yoff, t_xsize, t_ysize ) + + mask_test = Numeric.equal(data_mask, 0) + to_write = Numeric.choose( mask_test, (data_src, data_dst) ) + + t_band.WriteArray( to_write, t_xoff, t_yoff ) + + return 0 + +# ============================================================================= +def names_to_fileinfos( names ): + """ + Translate a list of GDAL filenames, into file_info objects. + + names -- list of valid GDAL dataset names. + + Returns a list of file_info objects. There may be less file_info objects + than names if some of the names could not be opened as GDAL files. + """ + + file_infos = [] + for name in names: + fi = file_info() + if fi.init_from_name( name ) == 1: + file_infos.append( fi ) + + return file_infos + +# ***************************************************************************** +class file_info: + """A class holding information about a GDAL file.""" + + def init_from_name(self, filename): + """ + Initialize file_info from filename + + filename -- Name of file to read. + + Returns 1 on success or 0 if the file can't be opened. 
+ """ + fh = gdal.Open( filename ) + if fh is None: + return 0 + + self.filename = filename + self.bands = fh.RasterCount + self.xsize = fh.RasterXSize + self.ysize = fh.RasterYSize + self.band_type = fh.GetRasterBand(1).DataType + self.projection = fh.GetProjection() + self.geotransform = fh.GetGeoTransform() + self.ulx = self.geotransform[0] + self.uly = self.geotransform[3] + self.lrx = self.ulx + self.geotransform[1] * self.xsize + self.lry = self.uly + self.geotransform[5] * self.ysize + + ct = fh.GetRasterBand(1).GetRasterColorTable() + if ct is not None: + self.ct = ct.Clone() + else: + self.ct = None + + return 1 + + def report( self ): + print('Filename: '+ self.filename) + print('File Size: %dx%dx%d' + % (self.xsize, self.ysize, self.bands)) + print('Pixel Size: %f x %f' + % (self.geotransform[1],self.geotransform[5])) + print('UL:(%f,%f) LR:(%f,%f)' + % (self.ulx,self.uly,self.lrx,self.lry)) + + def copy_into( self, t_fh, s_band = 1, t_band = 1, nodata_arg=None ): + """ + Copy this files image into target file. + + This method will compute the overlap area of the file_info objects + file, and the target gdal.Dataset object, and copy the image data + for the common window area. It is assumed that the files are in + a compatible projection ... no checking or warping is done. However, + if the destination file is a different resolution, or different + image pixel type, the appropriate resampling and conversions will + be done (using normal GDAL promotion/demotion rules). + + t_fh -- gdal.Dataset object for the file into which some or all + of this file may be copied. + + Returns 1 on success (or if nothing needs to be copied), and zero one + failure. + """ + t_geotransform = t_fh.GetGeoTransform() + t_ulx = t_geotransform[0] + t_uly = t_geotransform[3] + t_lrx = t_geotransform[0] + t_fh.RasterXSize * t_geotransform[1] + t_lry = t_geotransform[3] + t_fh.RasterYSize * t_geotransform[5] + + # figure out intersection region + tgw_ulx = max(t_ulx,self.ulx) + tgw_lrx = min(t_lrx,self.lrx) + if t_geotransform[5] < 0: + tgw_uly = min(t_uly,self.uly) + tgw_lry = max(t_lry,self.lry) + else: + tgw_uly = max(t_uly,self.uly) + tgw_lry = min(t_lry,self.lry) + + # do they even intersect? + if tgw_ulx >= tgw_lrx: + return 1 + if t_geotransform[5] < 0 and tgw_uly <= tgw_lry: + return 1 + if t_geotransform[5] > 0 and tgw_uly >= tgw_lry: + return 1 + + # compute target window in pixel coordinates. + tw_xoff = int((tgw_ulx - t_geotransform[0]) / t_geotransform[1] + 0.1) + tw_yoff = int((tgw_uly - t_geotransform[3]) / t_geotransform[5] + 0.1) + tw_xsize = int((tgw_lrx - t_geotransform[0])/t_geotransform[1] + 0.5) \ + - tw_xoff + tw_ysize = int((tgw_lry - t_geotransform[3])/t_geotransform[5] + 0.5) \ + - tw_yoff + + if tw_xsize < 1 or tw_ysize < 1: + return 1 + + # Compute source window in pixel coordinates. + sw_xoff = int((tgw_ulx - self.geotransform[0]) / self.geotransform[1]) + sw_yoff = int((tgw_uly - self.geotransform[3]) / self.geotransform[5]) + sw_xsize = int((tgw_lrx - self.geotransform[0]) \ + / self.geotransform[1] + 0.5) - sw_xoff + sw_ysize = int((tgw_lry - self.geotransform[3]) \ + / self.geotransform[5] + 0.5) - sw_yoff + + if sw_xsize < 1 or sw_ysize < 1: + return 1 + + # Open the source file, and copy the selected region. 
+ s_fh = gdal.Open( self.filename ) + + return raster_copy( s_fh, sw_xoff, sw_yoff, sw_xsize, sw_ysize, s_band, + t_fh, tw_xoff, tw_yoff, tw_xsize, tw_ysize, t_band, + nodata_arg ) + + +# ============================================================================= +def Usage(): + print('Usage: gdal_merge.py [-o out_filename] [-of out_format] [-co NAME=VALUE]*') + print(' [-ps pixelsize_x pixelsize_y] [-tap] [-separate] [-q] [-v] [-pct]') + print(' [-ul_lr ulx uly lrx lry] [-init "value [value...]"]') + print(' [-n nodata_value] [-a_nodata output_nodata_value]') + print(' [-ot datatype] [-createonly] input_files') + print(' [--help-general]') + print('') + +# ============================================================================= +# +# Program mainline. +# + +def main( argv=None ): + + global verbose, quiet + verbose = 0 + quiet = 0 + names = [] + format = 'GTiff' + out_file = 'out.tif' + + ulx = None + psize_x = None + separate = 0 + copy_pct = 0 + nodata = None + a_nodata = None + create_options = [] + pre_init = [] + band_type = None + createonly = 0 + bTargetAlignedPixels = False + start_time = time.time() + + gdal.AllRegister() + if argv is None: + argv = sys.argv + argv = gdal.GeneralCmdLineProcessor( argv ) + if argv is None: + sys.exit( 0 ) + + # Parse command line arguments. + i = 1 + while i < len(argv): + arg = argv[i] + + if arg == '-o': + i = i + 1 + out_file = argv[i] + + elif arg == '-v': + verbose = 1 + + elif arg == '-q' or arg == '-quiet': + quiet = 1 + + elif arg == '-createonly': + createonly = 1 + + elif arg == '-separate': + separate = 1 + + elif arg == '-seperate': + separate = 1 + + elif arg == '-pct': + copy_pct = 1 + + elif arg == '-ot': + i = i + 1 + band_type = gdal.GetDataTypeByName( argv[i] ) + if band_type == gdal.GDT_Unknown: + print('Unknown GDAL data type: %s' % argv[i]) + sys.exit( 1 ) + + elif arg == '-init': + i = i + 1 + str_pre_init = argv[i].split() + for x in str_pre_init: + pre_init.append(float(x)) + + elif arg == '-n': + i = i + 1 + nodata = float(argv[i]) + + elif arg == '-a_nodata': + i = i + 1 + a_nodata = float(argv[i]) + + elif arg == '-f': + # for backward compatibility. + i = i + 1 + format = argv[i] + + elif arg == '-of': + i = i + 1 + format = argv[i] + + elif arg == '-co': + i = i + 1 + create_options.append( argv[i] ) + + elif arg == '-ps': + psize_x = float(argv[i+1]) + psize_y = -1 * abs(float(argv[i+2])) + i = i + 2 + + elif arg == '-tap': + bTargetAlignedPixels = True + + elif arg == '-ul_lr': + ulx = float(argv[i+1]) + uly = float(argv[i+2]) + lrx = float(argv[i+3]) + lry = float(argv[i+4]) + i = i + 4 + + elif arg[:1] == '-': + print('Unrecognized command option: %s' % arg) + Usage() + sys.exit( 1 ) + + else: + names.append(arg) + + i = i + 1 + + if len(names) == 0: + print('No input files selected.') + Usage() + sys.exit( 1 ) + + Driver = gdal.GetDriverByName(format) + if Driver is None: + print('Format driver %s not found, pick a supported driver.' % format) + sys.exit( 1 ) + + DriverMD = Driver.GetMetadata() + if 'DCAP_CREATE' not in DriverMD: + print('Format driver %s does not support creation and piecewise writing.\nPlease select a format that does, such as GTiff (the default) or HFA (Erdas Imagine).' % format) + sys.exit( 1 ) + + # Collect information on all the source files. 
+ file_infos = names_to_fileinfos( names ) + + if ulx is None: + ulx = file_infos[0].ulx + uly = file_infos[0].uly + lrx = file_infos[0].lrx + lry = file_infos[0].lry + + for fi in file_infos: + ulx = min(ulx, fi.ulx) + uly = max(uly, fi.uly) + lrx = max(lrx, fi.lrx) + lry = min(lry, fi.lry) + + if psize_x is None: + psize_x = file_infos[0].geotransform[1] + psize_y = file_infos[0].geotransform[5] + + if band_type is None: + band_type = file_infos[0].band_type + + # Try opening as an existing file. + gdal.PushErrorHandler( 'CPLQuietErrorHandler' ) + t_fh = gdal.Open( out_file, gdal.GA_Update ) + gdal.PopErrorHandler() + + # Create output file if it does not already exist. + if t_fh is None: + + if bTargetAlignedPixels: + ulx = math.floor(ulx / psize_x) * psize_x + lrx = math.ceil(lrx / psize_x) * psize_x + lry = math.floor(lry / -psize_y) * -psize_y + uly = math.ceil(uly / -psize_y) * -psize_y + + geotransform = [ulx, psize_x, 0, uly, 0, psize_y] + + xsize = int((lrx - ulx) / geotransform[1] + 0.5) + ysize = int((lry - uly) / geotransform[5] + 0.5) + + + if separate != 0: + bands=0 + + for fi in file_infos: + bands=bands + fi.bands + else: + bands = file_infos[0].bands + + + t_fh = Driver.Create( out_file, xsize, ysize, bands, + band_type, create_options ) + if t_fh is None: + print('Creation failed, terminating gdal_merge.') + sys.exit( 1 ) + + t_fh.SetGeoTransform( geotransform ) + t_fh.SetProjection( file_infos[0].projection ) + + if copy_pct: + t_fh.GetRasterBand(1).SetRasterColorTable(file_infos[0].ct) + else: + if separate != 0: + bands=0 + for fi in file_infos: + bands=bands + fi.bands + if t_fh.RasterCount < bands : + print('Existing output file has less bands than the input files. You should delete it before. Terminating gdal_merge.') + sys.exit( 1 ) + else: + bands = min(file_infos[0].bands,t_fh.RasterCount) + + # Do we need to set nodata value ? + if a_nodata is not None: + for i in range(t_fh.RasterCount): + t_fh.GetRasterBand(i+1).SetNoDataValue(a_nodata) + + # Do we need to pre-initialize the whole mosaic file to some value? + if pre_init is not None: + if t_fh.RasterCount <= len(pre_init): + for i in range(t_fh.RasterCount): + t_fh.GetRasterBand(i+1).Fill( pre_init[i] ) + elif len(pre_init) == 1: + for i in range(t_fh.RasterCount): + t_fh.GetRasterBand(i+1).Fill( pre_init[0] ) + + # Copy data from source files into output file. + t_band = 1 + + if quiet == 0 and verbose == 0: + progress( 0.0 ) + fi_processed = 0 + + for fi in file_infos: + if createonly != 0: + continue + + if verbose != 0: + print("") + print("Processing file %5d of %5d, %6.3f%% completed in %d minutes." + % (fi_processed+1,len(file_infos), + fi_processed * 100.0 / len(file_infos), + int(round((time.time() - start_time)/60.0)) )) + fi.report() + + if separate == 0 : + for band in range(1, bands+1): + fi.copy_into( t_fh, band, band, nodata ) + else: + for band in range(1, fi.bands+1): + fi.copy_into( t_fh, band, t_band, nodata ) + t_band = t_band+1 + + fi_processed = fi_processed+1 + if quiet == 0 and verbose == 0: + progress( fi_processed / float(len(file_infos)) ) + + # Force file to be closed. 
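# Note: besides the command line, this module can be driven programmatically by passing an
# argv-style list to main(); element 0 stands in for sys.argv[0] and is skipped by the parser
# above. A hypothetical call merging two tiles into one GeoTIFF (file names made up):
#
#   import gdal_merge
#   gdal_merge.main(['gdal_merge.py', '-o', 'merged.tif', '-n', '0', 'tile_1.tif', 'tile_2.tif'])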
+ t_fh = None + +if __name__ == '__main__': + sys.exit(main()) diff --git a/main_test.py b/main_test.py new file mode 100644 index 0000000..06341a4 --- /dev/null +++ b/main_test.py @@ -0,0 +1,285 @@ +#==========================================================# +# Create a classifier for satellite images +#==========================================================# + +# load modules +import os +import pickle +import warnings +import numpy as np +import matplotlib.cm as cm +warnings.filterwarnings("ignore") +import matplotlib.pyplot as plt +from pylab import ginput + +import SDS_download, SDS_preprocess, SDS_shoreline, SDS_tools, SDS_classification + +filepath_sites = os.path.join(os.getcwd(), 'polygons') +sites = os.listdir(filepath_sites) + +for site in sites: + + polygon = SDS_tools.coords_from_kml(os.path.join(filepath_sites,site)) + + # load Sentinel-2 images + inputs = { + 'polygon': polygon, + 'dates': ['2016-10-01', '2016-11-01'], + 'sat_list': ['S2'], + 'sitename': site[:site.find('.')] + } + + satname = inputs['sat_list'][0] + + metadata = SDS_download.get_images(inputs) + metadata = SDS_download.remove_cloudy_images(metadata,inputs,0.2) + filepath = os.path.join(os.getcwd(), 'data', inputs['sitename']) + with open(os.path.join(filepath, inputs['sitename'] + '_metadata_' + satname + '.pkl'), 'wb') as f: + pickle.dump(metadata, f) + #with open(os.path.join(filepath, inputs['sitename'] + '_metadata_' + satname + '.pkl'), 'rb') as f: + # metadata = pickle.load(f) + + # settings needed to run the shoreline extraction + settings = { + + # general parameters: + 'cloud_thresh': 0.1, # threshold on maximum cloud cover + 'output_epsg': 28356, # epsg code of spatial reference system desired for the output + + # shoreline detection parameters: + 'min_beach_size': 20, # minimum number of connected pixels for a beach + 'buffer_size': 7, # radius (in pixels) of disk for buffer around sandy pixels + 'min_length_sl': 200, # minimum length of shoreline perimeter to be kept + 'max_dist_ref': 100, # max distance (in meters) allowed from a reference shoreline + + # quality control: + 'check_detection': True, # if True, shows each shoreline detection and lets the user + # decide which ones are correct and which ones are false due to + # the presence of clouds + # also add the inputs + 'inputs': inputs + } + # preprocess images (cloud masking, pansharpening/down-sampling) + SDS_preprocess.preprocess_all_images(metadata, settings) + + training_data = dict([]) + training_data['sand'] = dict([]) + training_data['swash'] = dict([]) + training_data['water'] = dict([]) + training_data['land'] = dict([]) + + # read images + filepath = SDS_tools.get_filepath(inputs,satname) + filenames = metadata[satname]['filenames'] + + for i in range(len(filenames)): + + fn = SDS_tools.get_filenames(filenames[i],filepath,satname) + im_ms, georef, cloud_mask, im20, imQA = SDS_preprocess.preprocess_single(fn,satname) + + nrow = im_ms.shape[0] + ncol = im_ms.shape[1] + + im_RGB = SDS_preprocess.rescale_image_intensity(im_ms[:,:,[2,1,0]], cloud_mask, 99.9) + plt.figure() + mng = plt.get_current_fig_manager() + mng.window.showMaximized() + plt.imshow(im_RGB) + plt.axis('off') + + # Digitize sandy pixels + plt.title('Digitize SAND pixels', fontweight='bold', fontsize=15) + pt = ginput(n=1000, timeout=100000, show_clicks=True) + + if len(pt) > 0: + pt = np.round(pt).astype(int) + im_sand = np.zeros((nrow,ncol)) + for k in range(len(pt)): + im_sand[pt[k,1],pt[k,0]] = 1 + im_RGB[pt[k,1],pt[k,0],0] = 1 + im_RGB[pt[k,1],pt[k,0],1] = 0 
+ im_RGB[pt[k,1],pt[k,0],2] = 0 + im_sand = im_sand.astype(bool) + features = SDS_classification.calculate_features(im_ms, cloud_mask, im_sand) + else: + im_sand = np.zeros((nrow,ncol)).astype(bool) + features = [] + training_data['sand'][filenames[i]] = {'pixels':im_sand,'features':features} + + # Digitize swash pixels + plt.title('Digitize SWASH pixels', fontweight='bold', fontsize=15) + plt.draw() + pt = ginput(n=1000, timeout=100000, show_clicks=True) + + if len(pt) > 0: + pt = np.round(pt).astype(int) + im_swash = np.zeros((nrow,ncol)) + for k in range(len(pt)): + im_swash[pt[k,1],pt[k,0]] = 1 + im_RGB[pt[k,1],pt[k,0],0] = 0 + im_RGB[pt[k,1],pt[k,0],1] = 1 + im_RGB[pt[k,1],pt[k,0],2] = 0 + im_swash = im_swash.astype(bool) + features = SDS_classification.calculate_features(im_ms, cloud_mask, im_swash) + else: + im_swash = np.zeros((nrow,ncol)).astype(bool) + features = [] + training_data['swash'][filenames[i]] = {'pixels':im_swash,'features':features} + + # Digitize rectangle containig water pixels + plt.title('Click 2 points to draw a rectange in the WATER', fontweight='bold', fontsize=15) + plt.draw() + pt = ginput(n=2, timeout=100000, show_clicks=True) + if len(pt) > 0: + pt = np.round(pt).astype(int) + idx_row = np.arange(np.min(pt[:,1]),np.max(pt[:,1])+1,1) + idx_col = np.arange(np.min(pt[:,0]),np.max(pt[:,0])+1,1) + xx, yy = np.meshgrid(idx_row,idx_col, indexing='ij') + rows = xx.reshape(xx.shape[0]*xx.shape[1]) + cols = yy.reshape(yy.shape[0]*yy.shape[1]) + im_water = np.zeros((nrow,ncol)).astype(bool) + for k in range(len(rows)): + im_water[rows[k],cols[k]] = 1 + im_RGB[rows[k],cols[k],0] = 0 + im_RGB[rows[k],cols[k],1] = 0 + im_RGB[rows[k],cols[k],2] = 1 + im_water = im_water.astype(bool) + features = SDS_classification.calculate_features(im_ms, cloud_mask, im_water) + else: + im_water = np.zeros((nrow,ncol)).astype(bool) + features = [] + training_data['water'][filenames[i]] = {'pixels':im_water,'features':features} + + # Digitize rectangle containig land pixels + plt.title('Click 2 points to draw a rectange in the LAND', fontweight='bold', fontsize=15) + plt.draw() + pt = ginput(n=2, timeout=100000, show_clicks=True) + plt.close() + if len(pt) > 0: + pt = np.round(pt).astype(int) + idx_row = np.arange(np.min(pt[:,1]),np.max(pt[:,1])+1,1) + idx_col = np.arange(np.min(pt[:,0]),np.max(pt[:,0])+1,1) + xx, yy = np.meshgrid(idx_row,idx_col, indexing='ij') + rows = xx.reshape(xx.shape[0]*xx.shape[1]) + cols = yy.reshape(yy.shape[0]*yy.shape[1]) + im_land = np.zeros((nrow,ncol)).astype(bool) + for k in range(len(rows)): + im_land[rows[k],cols[k]] = 1 + im_RGB[rows[k],cols[k],0] = 1 + im_RGB[rows[k],cols[k],1] = 1 + im_RGB[rows[k],cols[k],2] = 0 + im_land = im_land.astype(bool) + features = SDS_classification.calculate_features(im_ms, cloud_mask, im_land) + else: + im_land = np.zeros((nrow,ncol)).astype(bool) + features = [] + training_data['land'][filenames[i]] = {'pixels':im_land,'features':features} + + plt.figure() + plt.title('Classified image') + plt.imshow(im_RGB) + + # save training data for each site + filepath = os.path.join(os.getcwd(), 'data', inputs['sitename']) + with open(os.path.join(filepath, inputs['sitename'] + '_training_' + satname + '.pkl'), 'wb') as f: + pickle.dump(training_data, f) +#%% + +## load Landsat 5 images +#inputs = { +# 'polygon': polygon, +# 'dates': ['1987-01-01', '1988-01-01'], +# 'sat_list': ['L5'], +# 'sitename': site[:site.find('.')] +# } +#metadata = SDS_download.get_images(inputs) +# +## load Landsat 7 images +#inputs = { +# 'polygon': 
polygon, +# 'dates': ['2001-01-01', '2002-01-01'], +# 'sat_list': ['L7'], +# 'sitename': site[:site.find('.')] +# } +#metadata = SDS_download.get_images(inputs) +# +## load Landsat 8 images +#inputs = { +# 'polygon': polygon, +# 'dates': ['2014-01-01', '2015-01-01'], +# 'sat_list': ['L8'], +# 'sitename': site[:site.find('.')] +# } +#metadata = SDS_download.get_images(inputs) + + +#%% clean the Landsat collections + +#import ee +#from datetime import datetime, timedelta +#import pytz +#import copy +#ee.Initialize() +#site = sites[0] +#dates = ['2017-12-01', '2017-12-25'] +#polygon = SDS_tools.coords_from_kml(os.path.join(filepath_sites,site)) +## Landsat collection +#input_col = ee.ImageCollection('LANDSAT/LC08/C01/T1_RT_TOA') +## filter by location and dates +#flt_col = input_col.filterBounds(ee.Geometry.Polygon(polygon)).filterDate(inputs['dates'][0],inputs['dates'][1]) +## get all images in the filtered collection +#im_all = flt_col.getInfo().get('features') +#cloud_cover = [_['properties']['CLOUD_COVER'] for _ in im_all] +#if np.any([_ > 90 for _ in cloud_cover]): +# idx_delete = np.where([_ > 90 for _ in cloud_cover])[0] +# im_all_cloud = [x for k,x in enumerate(im_all) if k not in idx_delete] + + +#%% clean the S2 collection + +#import ee +#from datetime import datetime, timedelta +#import pytz +#import copy +#ee.Initialize() +## Sentinel2 collection +#input_col = ee.ImageCollection('COPERNICUS/S2') +## filter by location and dates +#flt_col = input_col.filterBounds(ee.Geometry.Polygon(polygon)).filterDate(inputs['dates'][0],inputs['dates'][1]) +## get all images in the filtered collection +#im_all = flt_col.getInfo().get('features') +# +## remove duplicates (there are many in S2 collection) +## timestamps +#timestamps = [datetime.fromtimestamp(_['properties']['system:time_start']/1000, tz=pytz.utc) for _ in im_all] +## utm zones +#utm_zones = np.array([int(_['bands'][0]['crs'][5:]) for _ in im_all]) +#utm_zone_selected = np.max(np.unique(utm_zones)) +#idx_all = np.arange(0,len(im_all),1) +#idx_covered = np.ones(len(im_all)).astype(bool) +#idx_delete = [] +#i = 0 +#while 1: +# same_time = np.abs([(timestamps[i]-_).total_seconds() for _ in timestamps]) < 60*60*24 +# idx_same_time = np.where(same_time)[0] +# same_utm = utm_zones == utm_zone_selected +# idx_temp = np.where([same_time[j] == True and same_utm[j] == False for j in idx_all])[0] +# idx_keep = idx_same_time[[_ not in idx_temp for _ in idx_same_time ]] +# if len(idx_keep) > 2: # if more than 2 images with same date and same utm, drop the last one +# idx_temp = np.append(idx_temp,idx_keep[-1]) +# for j in idx_temp: +# idx_delete.append(j) +# idx_covered[idx_same_time] = False +# if np.any(idx_covered): +# i = np.where(idx_covered)[0][0] +# else: +# break +#im_all_updated = [x for k,x in enumerate(im_all) if k not in idx_delete] +# +## remove very cloudy images (>90% cloud) +#cloud_cover = [_['properties']['CLOUDY_PIXEL_PERCENTAGE'] for _ in im_all_updated] +#if np.any([_ > 90 for _ in cloud_cover]): +# idx_delete = np.where([_ > 90 for _ in cloud_cover])[0] +# im_all_cloud = [x for k,x in enumerate(im_all_updated) if k not in idx_delete] + + diff --git a/shoreline_extraction.ipynb b/shoreline_extraction.ipynb index 8d99b85..c95ec70 100644 --- a/shoreline_extraction.ipynb +++ b/shoreline_extraction.ipynb @@ -135,7 +135,7 @@ " metadata = pickle.load(f)\n", " \n", "# [OPTIONAL] saves .jpg files of the preprocessed images (cloud mask and pansharpening/down-sampling) \n", - "#SDS_preprocess.preprocess_all_images(metadata, 
+"#SDS_preprocess.save_jpg(metadata, settings)\n",
 "\n",
 "# [OPTIONAL] to avoid false detections and identify obvious outliers there is the option to\n",
 "# create a reference shoreline position (manually clicking on a satellite image)\n",
diff --git a/main_spyder.py b/test_spyder_simple.py
similarity index 51%
rename from main_spyder.py
rename to test_spyder_simple.py
index 61920fc..d65c0da 100644
--- a/main_spyder.py
+++ b/test_spyder_simple.py
@@ -8,41 +8,44 @@ import pickle
 import warnings
 warnings.filterwarnings("ignore")
 import matplotlib.pyplot as plt
-import SDS_download, SDS_preprocess, SDS_shoreline
+import SDS_download, SDS_preprocess, SDS_shoreline, SDS_tools
 
 # define the area of interest (longitude, latitude)
-polygon = [[[151.301454, -33.700754],
-            [151.311453, -33.702075],
-            [151.307237, -33.739761],
-            [151.294220, -33.736329],
-            [151.301454, -33.700754]]]
+polygon = SDS_tools.coords_from_kml('NARRA.kml')
 
 # define dates of interest
-dates = ['2017-12-01', '2018-01-01']
+dates = ['2015-01-01', '2019-01-01']
 
 # define satellite missions
-sat_list = ['L5', 'L7', 'L8', 'S2']
+sat_list = ['S2']
 
 # give a name to the site
 sitename = 'NARRA'
 
+# put all the inputs into a dictionary
+inputs = {
+    'polygon': polygon,
+    'dates': dates,
+    'sat_list': sat_list,
+    'sitename': sitename
+    }
+
 # download satellite images (also saves metadata.pkl)
-#SDS_download.get_images(sitename, polygon, dates, sat_list)
+metadata = SDS_download.get_images(inputs)
 
-# load metadata structure (contains information on the downloaded satellite images and is created
-# after all images have been successfully downloaded)
+# if you have already downloaded the images, just load the metadata file
 filepath = os.path.join(os.getcwd(), 'data', sitename)
 with open(os.path.join(filepath, sitename + '_metadata' + '.pkl'), 'rb') as f:
-    metadata = pickle.load(f)
+    metadata = pickle.load(f)
 
-# parameters and settings
+#%%
+# settings needed to run the shoreline extraction
 settings = {
-    'sitename': sitename,
     # general parameters:
-    'cloud_thresh': 0.5,        # threshold on maximum cloud cover
-    'output_epsg': 28356,       # epsg code of the desired output spatial reference system
+    'cloud_thresh': 0.2,        # threshold on maximum cloud cover
+    'output_epsg': 28356,       # epsg code of the spatial reference system desired for the output
 
     # shoreline detection parameters:
     'min_beach_size': 20,       # minimum number of connected pixels for a beach
@@ -51,30 +54,34 @@ settings = {
    'max_dist_ref': 100,        # max distance (in meters) allowed from a reference shoreline
 
    # quality control:
-   'check_detection': True     # if True, shows each shoreline detection and lets the user
+   'check_detection': True,    # if True, shows each shoreline detection and lets the user
                                # decide which ones are correct and which ones are false due to
-                               # the presence of clouds
+                               # the presence of clouds
+   # also add the inputs
+   'inputs': inputs
    }
+
 # preprocess images (cloud masking, pansharpening/down-sampling)
-SDS_preprocess.preprocess_all_images(metadata, settings)
+#SDS_preprocess.save_jpg(metadata, settings)
 
-# create a reference shoreline (used to identify outliers and false detections)
-settings['refsl'] = SDS_preprocess.get_reference_sl(metadata, settings)
+# create a reference shoreline (helps to identify outliers and false detections)
+settings['refsl'] = SDS_preprocess.get_reference_sl_manual(metadata, settings)
+#settings['refsl'] = SDS_preprocess.get_reference_sl_Australia(settings)
 
 # extract shorelines from all images (also saves output.pkl)
-out = SDS_shoreline.extract_shorelines(metadata, settings)
+output = SDS_shoreline.extract_shorelines(metadata, settings)
 
 # plot shorelines
 plt.figure()
 plt.axis('equal')
 plt.xlabel('Eastings [m]')
 plt.ylabel('Northings [m]')
-for satname in out.keys():
+for satname in output.keys():
     if satname == 'meta':
         continue
-    for i in range(len(out[satname]['shoreline'])):
-        sl = out[satname]['shoreline'][i]
-        date = out[satname]['timestamp'][i]
-        plt.plot(sl[:, 0], sl[:, 1], '-', label=date.strftime('%d-%m-%Y'))
-plt.legend()
+    for i in range(len(output[satname]['shoreline'])):
+        sl = output[satname]['shoreline'][i]
+        date = output[satname]['timestamp'][i]
+        plt.plot(sl[:, 0], sl[:, 1], '.', label=date.strftime('%d-%m-%Y'))
+plt.legend()
\ No newline at end of file
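
For reference, the renamed test_spyder_simple.py drives the toolbox through two dictionaries: 'inputs' (what to download) and 'settings' (how to map the shoreline). The following is a minimal sketch of that workflow, not part of the patch itself: it only calls functions introduced or kept by this patch (SDS_tools.coords_from_kml, SDS_download.get_images, SDS_preprocess.get_reference_sl_manual, SDS_shoreline.extract_shorelines); the date range is illustrative, and the full settings dictionary contains more detection parameters than the subset shown here.

import SDS_download, SDS_preprocess, SDS_shoreline, SDS_tools

# area of interest, time period and satellite mission to query
inputs = {
    'polygon': SDS_tools.coords_from_kml('NARRA.kml'),  # lon/lat polygon read from the KML added in this patch
    'dates': ['2017-01-01', '2018-01-01'],              # illustrative date range
    'sat_list': ['S2'],                                 # Sentinel-2 only
    'sitename': 'NARRA'
    }

# download the images; the metadata is returned and also saved to data/NARRA/NARRA_metadata.pkl
metadata = SDS_download.get_images(inputs)

# subset of the settings used in test_spyder_simple.py (the full dictionary holds more fields)
settings = {
    'cloud_thresh': 0.2,      # threshold on maximum cloud cover
    'output_epsg': 28356,     # epsg code of the output spatial reference system
    'min_beach_size': 20,     # minimum number of connected pixels for a beach
    'max_dist_ref': 100,      # max distance (in meters) allowed from the reference shoreline
    'check_detection': True,  # show each detection so the user can accept or reject it
    'inputs': inputs          # the settings also carry the inputs
    }

# digitise a reference shoreline on one image, then map the shoreline on every image
settings['refsl'] = SDS_preprocess.get_reference_sl_manual(metadata, settings)
output = SDS_shoreline.extract_shorelines(metadata, settings)  # one entry per satellite mission, also saved to output.pkl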