# -*- coding: utf-8 -*-
"""
Created on Thu Mar 1 11:20:35 2018

@author: z5030440

This script contains the functions needed for satellite derived shoreline
(SDS) extraction.
"""

# Initial settings
import numpy as np
import matplotlib.pyplot as plt
import pdb
import ee

# other modules
from osgeo import gdal, ogr, osr
import tempfile
from urllib.request import urlretrieve
import zipfile

# image processing modules
import skimage.filters as filters
import skimage.exposure as exposure
import skimage.transform as transform
import sklearn.decomposition as decomposition
import skimage.measure as measure
import skimage.morphology as morphology

# machine learning modules
from sklearn.cluster import KMeans
from sklearn.neural_network import MLPClassifier
try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib was removed in scikit-learn 0.23; joblib is
    # installed as a scikit-learn dependency and can be imported directly.
    import joblib

# import own modules
from functions.utils import *


# Download from ee server function
def download_tif(image, polygon, bandsId):
    """
    Downloads a GeoTIFF (selected region and bands) from the Earth Engine
    server and extracts it into a new temporary directory.

    Arguments:
    -----------
        image: ee.Image()
            image to download
        polygon: list
            coordinates of the points delimiting the region to download
        bandsId: list
            band descriptions (dictionaries) to include in the download

    Returns:
    -----------
        path: str
            path of the extracted 'data.tif' file
    """
    url = ee.data.makeDownloadUrl(ee.data.getDownloadId({
        'image': image.serialize(),
        'region': polygon,
        'bands': bandsId,
        'filePerBand': 'false',
        'name': 'data',
    }))
    local_zip, headers = urlretrieve(url)
    with zipfile.ZipFile(local_zip) as local_zipfile:
        return local_zipfile.extract('data.tif', tempfile.mkdtemp())


def load_image(image, polygon, bandsId):
    """
    Loads an ee.Image() as a np.array. The ee.Image() is retrieved from the
    EE database. The geographic area and bands to select can be specified.

    KV WRL 2018

    Arguments:
    -----------
        image: ee.Image()
            image object from the EE database
        polygon: list
            coordinates of the points creating a polygon. Each point is a
            list with 2 values
        bandsId: list
            bands to select, each band is a dictionary in the list containing
            the following keys: crs, crs_transform, data_type and id.
            NOTE: you have to remove the key dimensions, otherwise the entire
            image is retrieved.

    Returns:
    -----------
        image_array : np.ndarray
            An array containing the image, always 3D (rows, columns, bands),
            even when a single band is requested
        georef : np.ndarray
            6 element vector containing the crs_parameters as returned by
            GDAL GetGeoTransform:
            [X_ul_corner Xscale Xshear Y_ul_corner Yshear Yscale]
    """
    local_tif_filename = download_tif(image, polygon, bandsId)
    dataset = gdal.Open(local_tif_filename, gdal.GA_ReadOnly)
    georef = np.array(dataset.GetGeoTransform())
    # GDAL raster bands are 1-indexed
    bands = [dataset.GetRasterBand(i + 1).ReadAsArray()
             for i in range(dataset.RasterCount)]
    return np.stack(bands, 2), georef


def create_cloud_mask(im_qa, satname, plot_bool):
    """
    Creates a cloud mask from the image containing the QA band information.

    KV WRL 2018

    Arguments:
    -----------
        im_qa: np.ndarray
            Image containing the QA band
        satname: string
            short name for the satellite; only 'L8' and 'L7' are supported
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        cloud_mask : np.ndarray of booleans
            A boolean array with True where the clouds are present

    Raises:
    -----------
        ValueError
            if satname is not one of the supported satellites
    """
    # QA values flagged as cloud for each satellite
    if satname == 'L8':
        cloud_values = [2800, 2804, 2808, 2812, 6896, 6900, 6904, 6908]
    elif satname == 'L7':
        cloud_values = [752, 756, 760, 764]
    else:
        raise ValueError(
            "unsupported satellite name for cloud mask: {!r} "
            "(expected 'L8' or 'L7')".format(satname))
    cloud_mask = np.isin(im_qa, cloud_values)

    # remove isolated cloud pixels (there are some in the swash and they cause problems)
    if cloud_mask.any():
        # assign the result instead of using in_place=True, which is
        # deprecated/removed in recent scikit-image versions
        cloud_mask = morphology.remove_small_objects(cloud_mask, min_size=10, connectivity=1)

    if plot_bool:
        plt.figure()
        plt.imshow(cloud_mask, cmap='gray')
        plt.draw()

    #cloud_shadow_values = [2976, 2980, 2984, 2988, 3008, 3012, 3016, 3020]
    #cloud_shadow_mask = np.isin(im_qa, cloud_shadow_values)

    return cloud_mask


def read_eeimage(im, polygon, sat_name, plot_bool):
    """
    Reads an ee.Image() object and returns the panchromatic band, the
    multispectral bands (B, G, R, NIR, SWIR), a cloud mask, the crs
    parameters and some image metadata.
    All image outputs are at the panchromatic resolution (15m), with bilinear
    interpolation for the multispectral bands.

    KV WRL 2018

    Arguments:
    -----------
        im: ee.Image()
            Image to read from the Google Earth Engine database
        polygon: list
            coordinates of the points delimiting the region to load
        sat_name: string
            short name for the satellite, passed to create_cloud_mask
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        im_pan: np.ndarray (2D)
            The panchromatic band (15m)
        im_ms: np.ndarray (3D)
            The multispectral bands interpolated at 15m
        im_cloud: np.ndarray (2D)
            The cloud mask at 15m (also flags -inf and NaN pixels, i.e.
            pixels outside the image)
        crs: dict
            'crs_15m' and 'crs_30m' geotransforms and the integer 'epsg_code'
        meta: dict
            timestamp, acquisition date, geometric RMSE, GCP model, image
            quality and sun azimuth/elevation read from the image properties
    """
    im_dic = im.getInfo()
    # save metadata
    im_meta = im_dic.get('properties')
    meta = {'timestamp': im_meta['system:time_start'],
            'date_acquired': im_meta['DATE_ACQUIRED'],
            'geom_rmse_model': im_meta['GEOMETRIC_RMSE_MODEL'],
            'gcp_model': im_meta['GROUND_CONTROL_POINTS_MODEL'],
            'quality': im_meta['IMAGE_QUALITY_OLI'],
            'sun_azimuth': im_meta['SUN_AZIMUTH'],
            'sun_elevation': im_meta['SUN_ELEVATION']}

    im_bands = im_dic.get('bands')
    # delete dimensions key from dictionary, otherwise the entire image is extracted
    for band in im_bands:
        del band['dimensions']

    # load panchromatic band
    pan_band = [im_bands[7]]
    im_pan, crs_pan = load_image(im, polygon, pan_band)
    im_pan = im_pan[:, :, 0]

    # load the multispectral bands (B2,B3,B4,B5,B6) = (blue,green,red,nir,swir1)
    ms_bands = im_bands[1:6]
    im_ms_30m, crs_ms = load_image(im, polygon, ms_bands)

    # create cloud mask and bring it to the pan resolution (nearest neighbour)
    qa_band = [im_bands[11]]
    im_qa, crs_qa = load_image(im, polygon, qa_band)
    im_qa = im_qa[:, :, 0]
    im_cloud = create_cloud_mask(im_qa, sat_name, plot_bool)
    im_cloud = transform.resize(im_cloud, (im_pan.shape[0], im_pan.shape[1]),
                                order=0, preserve_range=True,
                                mode='constant').astype('bool_')

    # resize the image using bilinear interpolation (order 1)
    im_ms = transform.resize(im_ms_30m, (im_pan.shape[0], im_pan.shape[1]),
                             order=1, preserve_range=True, mode='constant')

    # check if -inf or NaN values (means out of image) and add to cloud mask
    im_inf = np.isneginf(im_ms[:, :, 0])
    im_nan = np.isnan(im_ms[:, :, 0])
    im_cloud = np.logical_or(np.logical_or(im_cloud, im_inf), im_nan)

    # get the crs parameters for the image at 15m and 30m resolution;
    # the band 'crs' string is of the form 'EPSG:<code>'
    crs = {'crs_15m': crs_pan, 'crs_30m': crs_ms,
           'epsg_code': int(pan_band[0]['crs'][5:])}

    if plot_bool:
        # if there are -inf in the image, set them to 0 before plotting
        idx = np.isneginf(im_ms_30m[:, :, 0])
        if idx.any():
            im_ms_30m[idx, :] = 0

        plt.figure()

        plt.subplot(221)
        plt.imshow(im_pan, cmap='gray')
        plt.title('PANCHROMATIC')

        plt.subplot(222)
        plt.imshow(im_ms_30m[:, :, [2, 1, 0]])
        plt.title('RGB')

        plt.subplot(223)
        plt.imshow(im_ms_30m[:, :, 3], cmap='gray')
        plt.title('NIR')

        plt.subplot(224)
        plt.imshow(im_ms_30m[:, :, 4], cmap='gray')
        plt.title('SWIR')

        plt.show()

    return im_pan, im_ms, im_cloud, crs, meta


def rescale_image_intensity(im, cloud_mask, prob_high, plot_bool):
    """
    Rescales the intensity of an image (multispectral or single band) by
    applying a cloud mask and clipping the prob_high upper percentile.
    This function allows to stretch the contrast of an image.

    KV WRL 2018

    Arguments:
    -----------
        im: np.ndarray
            Image to rescale, can be 3D (multispectral) or 2D (single band)
        cloud_mask: np.ndarray
            2D cloud mask with True where cloud pixels are
        prob_high: float
            percentile (0-100) used as the upper bound of the input range
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        im_adj: np.ndarray
            The rescaled image (NaN where cloud_mask is True)
    """
    # lower bound of the input intensity range (an intensity value, not a percentile)
    prc_low = 0
    vec_mask = cloud_mask.reshape(im.shape[0] * im.shape[1])

    if plot_bool:
        plt.figure()

    if len(im.shape) > 2:
        n_bands = im.shape[2]
        vec = im.reshape(im.shape[0] * im.shape[1], n_bands)
        vec_adj = np.full((len(vec_mask), n_bands), np.nan)
        # integer subplot grid with enough cells for one histogram per band
        n_cols = 2
        n_rows = int(np.ceil(n_bands / n_cols))

        for i in range(n_bands):
            prc_high = np.percentile(vec[~vec_mask, i], prob_high)
            vec_rescaled = exposure.rescale_intensity(vec[~vec_mask, i], in_range=(prc_low, prc_high))
            vec_adj[~vec_mask, i] = vec_rescaled

            if plot_bool:
                plt.subplot(n_rows, n_cols, i + 1)
                plt.hist(vec[~vec_mask, i], bins=200, label='original')
                plt.hist(vec_rescaled, bins=200, alpha=0.5, label='rescaled')
                plt.legend()
                plt.title('Band' + str(i + 1))

        if plot_bool:
            plt.show()

        im_adj = vec_adj.reshape(im.shape[0], im.shape[1], n_bands)

        if plot_bool:
            plt.figure()
            ax1 = plt.subplot(121)
            plt.imshow(im[:, :, [2, 1, 0]])
            plt.axis('off')
            plt.title('Original')
            plt.subplot(122, sharex=ax1, sharey=ax1)
            plt.imshow(im_adj[:, :, [2, 1, 0]])
            plt.axis('off')
            plt.title('Rescaled')
            plt.show()

    else:
        vec = im.reshape(im.shape[0] * im.shape[1])
        vec_adj = np.full(len(vec_mask), np.nan)
        prc_high = np.percentile(vec[~vec_mask], prob_high)
        vec_rescaled = exposure.rescale_intensity(vec[~vec_mask], in_range=(prc_low, prc_high))
        vec_adj[~vec_mask] = vec_rescaled

        if plot_bool:
            plt.hist(vec[~vec_mask], bins=200, label='original')
            plt.hist(vec_rescaled, bins=200, alpha=0.5, label='rescaled')
            plt.legend()
            plt.title('Single band')
            plt.show()

        im_adj = vec_adj.reshape(im.shape[0], im.shape[1])

        if plot_bool:
            plt.figure()
            ax1 = plt.subplot(121)
            plt.imshow(im, cmap='gray')
            plt.axis('off')
            plt.title('Original')
            plt.subplot(122, sharex=ax1, sharey=ax1)
            plt.imshow(im_adj, cmap='gray')
            plt.axis('off')
            plt.title('Rescaled')
            plt.show()

    return im_adj


def hist_match(source, template):
    """
    Adjust the pixel values of a grayscale image such that its histogram
    matches that of a target image.

    Arguments:
    -----------
        source: np.ndarray
            Image to transform; the histogram is computed over the flattened
            array
        template: np.ndarray
            Template image; can have different dimensions to source

    Returns:
    -----------
        matched: np.ndarray
            The transformed output image, with the shape of source
    """
    oldshape = source.shape
    source = source.ravel()
    template = template.ravel()

    # get the set of unique pixel values and their corresponding indices and
    # counts
    s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                            return_counts=True)
    t_values, t_counts = np.unique(template, return_counts=True)

    # take the cumsum of the counts and normalize by the number of pixels to
    # get the empirical cumulative distribution functions for the source and
    # template images (maps pixel value --> quantile)
    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]

    # interpolate linearly to find the pixel values in the template image
    # that correspond most closely to the quantiles in the source image
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)

    return interp_t_values[bin_idx].reshape(oldshape)


def pansharpen(im_ms, im_pan, cloud_mask, plot_bool):
    """
    Pansharpens a multispectral image (3D), using the panchromatic band (2D)
    and a cloud mask. The first principal component of the (non-cloudy)
    multispectral pixels is replaced by the histogram-matched pan band.

    KV WRL 2018

    Arguments:
    -----------
        im_ms: np.ndarray
            Multispectral image to pansharpen (3D)
        im_pan: np.ndarray
            Panchromatic band (2D)
        cloud_mask: np.ndarray
            2D cloud mask with True where cloud pixels are
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        im_ms_ps: np.ndarray
            Pansharpened multispectral image (3D), NaN where cloud_mask is True
    """
    # reshape image into vector and apply cloud mask
    vec = im_ms.reshape(im_ms.shape[0] * im_ms.shape[1], im_ms.shape[2])
    vec_mask = cloud_mask.reshape(im_ms.shape[0] * im_ms.shape[1])
    vec = vec[~vec_mask, :]
    # apply PCA to all the multispectral bands
    pca = decomposition.PCA()
    vec_pcs = pca.fit_transform(vec)
    # replace 1st PC with pan band (after matching histograms)
    vec_pan = im_pan.reshape(im_pan.shape[0] * im_pan.shape[1])
    vec_pan = vec_pan[~vec_mask]
    vec_pcs[:, 0] = hist_match(vec_pan, vec_pcs[:, 0])
    vec_ms_ps = pca.inverse_transform(vec_pcs)

    # reshape vector into image
    vec_ms_ps_full = np.full((len(vec_mask), im_ms.shape[2]), np.nan)
    vec_ms_ps_full[~vec_mask, :] = vec_ms_ps
    im_ms_ps = vec_ms_ps_full.reshape(im_ms.shape[0], im_ms.shape[1], im_ms.shape[2])

    if plot_bool:
        plt.figure()
        ax1 = plt.subplot(121)
        plt.imshow(rescale_image_intensity(im_ms[:, :, [2, 1, 0]], cloud_mask, 100, False))
        plt.axis('off')
        plt.title('Original')
        plt.subplot(122, sharex=ax1, sharey=ax1)
        plt.imshow(rescale_image_intensity(im_ms_ps[:, :, [2, 1, 0]], cloud_mask, 100, False))
        plt.axis('off')
        plt.title('Pansharpened')
        plt.show()

    return im_ms_ps


def nd_index(im1, im2, cloud_mask, plot_bool):
    """
    Computes the normalised difference index (im1 - im2) / (im1 + im2) on 2
    images (2D), given a cloud mask (2D).

    KV WRL 2018

    Arguments:
    -----------
        im1, im2: np.ndarray
            Images (2D) with which to calculate the ND index
        cloud_mask: np.ndarray
            2D cloud mask with True where cloud pixels are
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        im_nd: np.ndarray
            Image (2D) containing the ND index (NaN where cloud_mask is True)
    """
    vec_mask = cloud_mask.reshape(im1.shape[0] * im1.shape[1])
    vec_nd = np.full(len(vec_mask), np.nan)
    vec1 = im1.reshape(im1.shape[0] * im1.shape[1])
    vec2 = im2.reshape(im2.shape[0] * im2.shape[1])
    temp = np.divide(vec1[~vec_mask] - vec2[~vec_mask],
                     vec1[~vec_mask] + vec2[~vec_mask])
    vec_nd[~vec_mask] = temp
    im_nd = vec_nd.reshape(im1.shape[0], im1.shape[1])

    if plot_bool:
        plt.figure()
        plt.imshow(im_nd, cmap='seismic')
        plt.colorbar()
        plt.title('Normalised index')
        plt.show()

    return im_nd


def find_wl_contours(im_ndwi, cloud_mask, min_contour_points, plot_bool):
    """
    Finds the water line contours on a water index image (2D): the threshold
    is computed with Otsu's method on the non-cloudy pixels and the contours
    are extracted at that level with the Marching Squares algorithm.

    KV WRL 2018

    Arguments:
    -----------
        im_ndwi: np.ndarray
            Image (2D) with the NDWI (water index)
        cloud_mask: np.ndarray
            2D cloud mask with True where cloud pixels are
        min_contour_points: int
            contours must have strictly more points than this to be kept
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        contours_wl: list of np.arrays
            contains the (row,column) coordinates of the contour lines
    """
    # reshape image to vector
    vec_ndwi = im_ndwi.reshape(im_ndwi.shape[0] * im_ndwi.shape[1])
    vec_mask = cloud_mask.reshape(cloud_mask.shape[0] * cloud_mask.shape[1])
    vec = vec_ndwi[~vec_mask]
    # apply otsu's threshold
    t_otsu = filters.threshold_otsu(vec)
    # use Marching Squares algorithm to detect contours on ndwi image
    contours = measure.find_contours(im_ndwi, t_otsu)

    # filter water lines
    contours_wl = []
    for contour in contours:
        # remove contour points that are around clouds (nan values)
        contour = contour[~np.isnan(contour).any(axis=1)]
        # remove contours that have only few points (less than min_contour_points)
        if contour.shape[0] > min_contour_points:
            contours_wl.append(contour)

    if plot_bool:
        # plot otsu's histogram segmentation
        plt.figure()
        vals = plt.hist(vec, bins=200)
        plt.plot([t_otsu, t_otsu], [0, np.max(vals[0])], 'r-', label='Otsu threshold')
        plt.legend()
        plt.show()

        # plot the water line contours on top of water index
        plt.figure()
        plt.imshow(im_ndwi, cmap='seismic')
        plt.colorbar()
        for contour in contours_wl:
            plt.plot(contour[:, 1], contour[:, 0], linewidth=3, color='k')
        plt.axis('image')
        plt.title('Detected water lines')
        plt.show()

    return contours_wl


def convert_pix2world(points, crs_vec):
    """
    Converts pixel coordinates (row,columns) to world projected coordinates
    performing an affine transformation.

    KV WRL 2018

    Arguments:
    -----------
        points: np.ndarray or list of np.ndarray
            array with 2 columns (rows first and columns second)
        crs_vec: np.ndarray
            vector of 6 elements [Xtr, Xscale, Xshear, Ytr, Yshear, Yscale]

    Returns:
    -----------
        points_converted: np.ndarray or list of np.ndarray
            converted coordinates, first columns with X and second column with Y

    Raises:
    -----------
        TypeError
            if points is neither a np.ndarray nor a list of np.ndarray
    """
    # make affine transformation matrix
    aff_mat = np.array([[crs_vec[1], crs_vec[2], crs_vec[0]],
                        [crs_vec[4], crs_vec[5], crs_vec[3]],
                        [0, 0, 1]])
    # create affine transformation
    tform = transform.AffineTransform(aff_mat)

    # the transform expects (x, y) = (column, row), hence the column swap
    if isinstance(points, list):
        points_converted = [tform(arr[:, [1, 0]]) for arr in points]
    elif isinstance(points, np.ndarray):
        points_converted = tform(points[:, [1, 0]])
    else:
        raise TypeError('invalid input type: expected np.ndarray or list of '
                        'np.ndarray, got {}'.format(type(points).__name__))

    return points_converted


def convert_epsg(points, epsg_in, epsg_out):
    """
    Converts from one spatial reference to another using the epsg codes.

    KV WRL 2018

    Arguments:
    -----------
        points: np.ndarray or list of np.ndarray
            array with one point per row, coordinates in (x, y) order
            (easting/longitude first)
        epsg_in: int
            epsg code of the spatial reference in which the input is
        epsg_out: int
            epsg code of the spatial reference in which the output will be

    Returns:
    -----------
        points_converted: np.ndarray or list of np.ndarray
            converted coordinates, one row per point as returned by
            osr.CoordinateTransformation.TransformPoints

    Raises:
    -----------
        TypeError
            if points is neither a np.ndarray nor a list of np.ndarray
    """
    # define input and output spatial references
    inSpatialRef = osr.SpatialReference()
    inSpatialRef.ImportFromEPSG(epsg_in)
    outSpatialRef = osr.SpatialReference()
    outSpatialRef.ImportFromEPSG(epsg_out)
    # GDAL >= 3 uses the authority axis order (lat/lon for EPSG:4326) by
    # default; force the traditional (x, y) order to keep GDAL 2 behaviour
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        inSpatialRef.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        outSpatialRef.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    # create a coordinates transform
    coordTransform = osr.CoordinateTransformation(inSpatialRef, outSpatialRef)

    # transform points
    if isinstance(points, list):
        points_converted = [np.array(coordTransform.TransformPoints(arr))
                            for arr in points]
    elif isinstance(points, np.ndarray):
        points_converted = np.array(coordTransform.TransformPoints(points))
    else:
        raise TypeError('invalid input type: expected np.ndarray or list of '
                        'np.ndarray, got {}'.format(type(points).__name__))

    return points_converted


def classify_sand_unsupervised(im_ms_ps, im_pan, cloud_mask, wl_pix, buffer_size, min_beach_size, plot_bool):
    """
    Classifies sand pixels using an unsupervised algorithm (Kmeans).
    Set buffer size to False if you want to classify the entire image,
    otherwise buffer size defines the buffer around the shoreline in which
    pixels are considered for classification.
    This classification is not robust and is only used to train a supervised
    algorithm.

    KV WRL 2018

    Arguments:
    -----------
        im_ms_ps: np.ndarray
            Pansharpened RGB + downsampled NIR and SWIR
        im_pan:
            Panchromatic band
        cloud_mask: np.ndarray
            2D cloud mask with True where cloud pixels are
        wl_pix: list of np.ndarray
            list of arrays containing the pixel coordinates of the water line
        buffer_size: int or False
            radius of the disk used to create a buffer around the water line;
            when False, the entire image is considered for kmeans
        min_beach_size: int
            minimum number of connected pixels belonging to a single beach
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        im_sand: np.ndarray
            2D binary image containing True where sand pixels are located
    """
    # reshape the 2D images into vectors
    vec_ms_ps = im_ms_ps.reshape(im_ms_ps.shape[0] * im_ms_ps.shape[1], im_ms_ps.shape[2])
    vec_pan = im_pan.reshape(im_pan.shape[0] * im_pan.shape[1])
    vec_mask = cloud_mask.reshape(im_ms_ps.shape[0] * im_ms_ps.shape[1])
    # add B,G,R,NIR and pan bands to the vector of features
    vec_features = np.zeros((vec_ms_ps.shape[0], 5))
    vec_features[:, [0, 1, 2, 3]] = vec_ms_ps[:, [0, 1, 2, 3]]
    vec_features[:, 4] = vec_pan

    if buffer_size:
        # create binary image with ones where the detected water lines are
        im_buffer = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1]))
        for contour in wl_pix:
            pix = np.round(contour).astype(int)
            im_buffer[pix[:, 0], pix[:, 1]] = 1
        # perform a dilation on the binary image
        se = morphology.disk(buffer_size)
        im_buffer = morphology.binary_dilation(im_buffer, se)
        vec_buffer = (im_buffer == 1).reshape(im_ms_ps.shape[0] * im_ms_ps.shape[1])
    else:
        vec_buffer = np.ones((vec_pan.shape[0]))

    # add cloud mask to buffer
    vec_buffer = np.logical_and(vec_buffer, ~vec_mask)

    # perform kmeans (6 clusters)
    kmeans = KMeans(n_clusters=6, random_state=0).fit(vec_features[vec_buffer, :])
    labels = np.full(len(vec_mask), np.nan)
    labels[vec_buffer] = kmeans.labels_
    im_labels = labels.reshape(im_ms_ps.shape[0], im_ms_ps.shape[1])

    # find the class with maximum reflection in the B,G,R,Pan
    im_sand = im_labels == np.argmax(np.mean(kmeans.cluster_centers_[:, [0, 1, 2, 4]], axis=1))
    im_sand = morphology.remove_small_objects(im_sand, min_size=min_beach_size, connectivity=2)
    im_sand = morphology.binary_erosion(im_sand, morphology.disk(1))
#    im_sand = morphology.binary_dilation(im_sand, morphology.disk(1))

    if plot_bool:
        im = np.copy(rescale_image_intensity(im_ms_ps[:, :, [2, 1, 0]], cloud_mask, 100, False))
        # paint sand pixels blue
        im[im_sand, 0] = 0
        im[im_sand, 1] = 0
        im[im_sand, 2] = 1
        plt.figure()
        plt.imshow(im)
        plt.axis('image')
        plt.title('Sand classification')
        plt.show()

    return im_sand


def classify_image_NN(im_ms_ps, im_pan, cloud_mask, plot_bool):
    """
    Classifies every pixel in the image in one of 4 classes:
        - sand                                          --> label = 1
        - whitewater (breaking waves and swash)         --> label = 2
        - water                                         --> label = 3
        - other (vegetation, buildings, rocks...)       --> label = 0

    The classifier is a Neural Network, trained with 7000 pixels for the class
    SAND and 1500 pixels for each of the other classes. This is because the
    class of interest for my application is SAND and I wanted to minimize the
    classification error for that class.

    KV WRL 2018

    Arguments:
    -----------
        im_ms_ps: np.ndarray
            Pansharpened RGB + downsampled NIR and SWIR
        im_pan:
            Panchromatic band
        cloud_mask: np.ndarray
            2D cloud mask with True where cloud pixels are
        plot_bool: boolean
            True if plot is wanted

    Returns:
    -----------
        im_classif: np.ndarray
            2D image with the label of each pixel; cloudy pixels and pixels
            with NaN features are set to 0
        im_labels: np.ndarray
            3D boolean image with one layer per class of interest:
            [:,:,0] sand (small objects removed), [:,:,1] whitewater,
            [:,:,2] water
    """
    # load classifier (path is relative to the current working directory).
    # NOTE: joblib.load unpickles the file; only load trusted classifier files.
    clf = joblib.load('functions/NeuralNet_classif.pkl')

    # calculate features
    n_features = 10
    im_features = np.zeros((im_ms_ps.shape[0], im_ms_ps.shape[1], n_features))
    im_features[:, :, [0, 1, 2, 3, 4]] = im_ms_ps
    im_features[:, :, 5] = im_pan
    im_features[:, :, 6] = nd_index(im_ms_ps[:, :, 3], im_ms_ps[:, :, 1], cloud_mask, False)  # ND(NIR-G)
    im_features[:, :, 7] = nd_index(im_ms_ps[:, :, 3], im_ms_ps[:, :, 2], cloud_mask, False)  # ND(NIR-R)
    im_features[:, :, 8] = nd_index(im_ms_ps[:, :, 0], im_ms_ps[:, :, 2], cloud_mask, False)  # ND(B-R)
    im_features[:, :, 9] = nd_index(im_ms_ps[:, :, 4], im_ms_ps[:, :, 1], cloud_mask, False)  # ND(SWIR-G)

    # remove NaNs and clouds
    vec_features = im_features.reshape((im_ms_ps.shape[0] * im_ms_ps.shape[1], n_features))
    vec_cloud = cloud_mask.reshape(cloud_mask.shape[0] * cloud_mask.shape[1])
    vec_nan = np.any(np.isnan(vec_features), axis=1)
    vec_mask = np.logical_or(vec_cloud, vec_nan)
    vec_features = vec_features[~vec_mask, :]

    # predict with NN classifier
    labels = clf.predict(vec_features)

    # recompose image (masked pixels keep label 0)
    vec_classif = np.zeros((cloud_mask.shape[0] * cloud_mask.shape[1]))
    vec_classif[~vec_mask] = labels
    im_classif = vec_classif.reshape((im_ms_ps.shape[0], im_ms_ps.shape[1]))

    # labels
    im_sand = im_classif == 1
    im_sand = morphology.remove_small_objects(im_sand, min_size=20, connectivity=2)
    im_swash = im_classif == 2
    im_water = im_classif == 3
    im_labels = np.stack((im_sand, im_swash, im_water), axis=-1)

    if plot_bool:
        im_display = rescale_image_intensity(im_ms_ps[:, :, [2, 1, 0]], cloud_mask, 100, False)
        im = np.copy(im_display)
        # one RGB colour per class: sand, whitewater, water
        colours = np.array([[1, 128/255, 0/255], [204/255, 1, 1], [0, 0, 204/255]])
        for k in range(0, im_labels.shape[2]):
            im[im_labels[:, :, k], 0] = colours[k, 0]
            im[im_labels[:, :, k], 1] = colours[k, 1]
            im[im_labels[:, :, k], 2] = colours[k, 2]

        plt.figure()
        ax1 = plt.subplot(121)
        plt.imshow(im_display)
        plt.axis('off')
        plt.title('Image')
        plt.subplot(122, sharex=ax1, sharey=ax1)
        plt.imshow(im)
        plt.axis('off')
        plt.title('NN classifier')
        # maximise the window only when the backend supports it (e.g. Qt);
        # other backends have no window.showMaximized and would crash here
        mng = plt.get_current_fig_manager()
        try:
            mng.window.showMaximized()
        except AttributeError:
            pass
        plt.tight_layout()
        plt.draw()

    return im_classif, im_labels