{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train a new classifier for CoastSat\n",
"\n",
"In this notebook the CoastSat classifier is trained with satellite images from new sites. Training a site-specific classifier can improve the accuracy of the shoreline detection if you are experiencing issues with the default classifier."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Initial settings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"run_control": {
"marked": false
}
},
"outputs": [],
"source": [
"# load modules\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os, sys\n",
"import numpy as np\n",
"import pickle\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# sklearn modules\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neural_network import MLPClassifier\n",
"from sklearn.model_selection import cross_val_score\n",
"import joblib  # replaces the deprecated sklearn.externals.joblib (removed in scikit-learn 0.23)\n",
"\n",
"# coastsat modules\n",
"sys.path.insert(0, os.pardir)\n",
"from coastsat import SDS_download, SDS_preprocess, SDS_shoreline, SDS_tools, SDS_classify\n",
"\n",
"# plotting params\n",
"plt.rcParams['font.size'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"plt.rcParams['axes.titlesize'] = 12\n",
"plt.rcParams['axes.labelsize'] = 12\n",
"\n",
"# filepaths\n",
"filepath_images = os.path.join(os.getcwd(), 'data')\n",
"filepath_train = os.path.join(os.getcwd(), 'training_data')\n",
"filepath_models = os.path.join(os.getcwd(), 'models')\n",
"\n",
"# settings\n",
"settings = {'filepath_train':filepath_train, # folder where the labelled images will be stored\n",
"            'cloud_thresh':0.9, # maximum fraction of cloudy pixels accepted in an image\n",
"            'cloud_mask_issue':True, # set to True if problems with the default cloud mask\n",
"            'inputs':{'filepath':filepath_images}, # folder where the images are stored\n",
"            'labels':{'sand':1,'white-water':2,'water':3,'other land features':4}, # labels for the classifier\n",
"            'colors':{'sand':[1, 0.65, 0],'white-water':[1,0,1],'water':[0.1,0.1,0.7],'other land features':[0.8,0.8,0.1]},\n",
"            'tolerance':0.01, # pixel intensity tolerance when using flood fill to label sandy pixels\n",
"                              # (set to 0 to select one pixel at a time)\n",
"           }\n",
"\n",
"# read kml files for the training sites\n",
"filepath_sites = os.path.join(os.getcwd(), 'training_sites')\n",
"train_sites = os.listdir(filepath_sites)\n",
"print('Sites for training:\\n%s\\n'%train_sites)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Download images\n",
"\n",
"For each site on which you want to train the classifier, save a .kml file with the region of interest in the folder *\\training_sites*. The polygon should have 5 vertices listed clockwise, with the first and last points identical (it can be created with Google My Maps).\n",
"\n",
"You only need a few images (~10) per site to train the classifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"# download images at the sites\n",
"dates = ['2019-01-01', '2019-07-01']\n",
"sat_list = ['L8']  # list of satellite missions to download\n",
"for site in train_sites:\n",
"    polygon = SDS_tools.polygon_from_kml(os.path.join(filepath_sites,site))\n",
"    sitename = site[:site.find('.')]  # strip the .kml extension\n",
"    inputs = {'polygon':polygon, 'dates':dates, 'sat_list':sat_list,\n",
"              'sitename':sitename, 'filepath':filepath_images}\n",
"    print(sitename)\n",
"    metadata = SDS_download.retrieve_images(inputs)"
]
},
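{
"cell_type": "markdown",
"metadata": {},
"source": [
"[OPTIONAL] Sanity-check that each .kml file was read correctly (the first and last vertex of each polygon should be identical). This is a minimal sketch that reuses *SDS_tools.polygon_from_kml* from the cell above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the polygon read from each kml file (first and last vertex should match)\n",
"for site in train_sites:\n",
"    polygon = SDS_tools.polygon_from_kml(os.path.join(filepath_sites, site))\n",
"    print('%s:'%site)\n",
"    print(polygon)"
]
},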
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Label images\n",
"\n",
"Label the images into 4 classes: sand, white-water, water and other land features.\n",
"\n",
"The labelled images are saved in the *filepath_train* folder and can be visualised afterwards for quality control. If you make a mistake, don't worry, it can be fixed later by deleting the labelled image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"run_control": {
"marked": true
}
},
"outputs": [],
"source": [
"# label the images with an interactive annotator\n",
"%matplotlib qt\n",
"for site in train_sites:\n",
"    settings['inputs']['sitename'] = site[:site.find('.')]\n",
"    # load metadata\n",
"    metadata = SDS_download.get_metadata(settings['inputs'])\n",
"    # label images\n",
"    SDS_classify.label_images(metadata,settings)"
]
},
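{
"cell_type": "markdown",
"metadata": {},
"source": [
"[OPTIONAL] To see which labelled images have been saved (e.g. to find a file to delete and redo), list the contents of the training-data folder. A minimal sketch; the exact filenames depend on how *SDS_classify.label_images* names its outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# list the files saved in the training-data folder for quality control\n",
"for f in sorted(os.listdir(filepath_train)):\n",
"    print(f)"
]
},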
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Train Classifier\n",
"\n",
"A Multilayer Perceptron is trained with *scikit-learn*. To train the classifier, the training data needs to be loaded first.\n",
"\n",
"You can use the data that was labelled here and/or the original CoastSat training data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load labelled images\n",
"features = SDS_classify.load_labels(train_sites, settings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# you can also load the original CoastSat training data (and optionally merge it with your labelled data)\n",
"with open(os.path.join(settings['filepath_train'], 'CoastSat_training_set_L8.pkl'), 'rb') as f:\n",
"    features_original = pickle.load(f)\n",
"for key in features_original.keys():\n",
"    print('%s : %d pixels'%(key,len(features_original[key])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run this section to combine the original training data with your labelled data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"# add the white-water data from the original training data\n",
"features['white-water'] = np.append(features['white-water'], features_original['white-water'], axis=0)\n",
"# or merge all the classes\n",
"# for key in features.keys():\n",
"#     features[key] = np.append(features[key], features_original[key], axis=0)\n",
"# features = features_original\n",
"for key in features.keys():\n",
"    print('%s : %d pixels'%(key,len(features[key])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[OPTIONAL] As the classes do not have the same number of pixels, it is good practice to subsample the very large classes (in this case 'water' and 'other land features'):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# randomly subsample the water and land classes\n",
"# as the most important class is 'sand', the number of samples should be close to the number of sand pixels\n",
"n_samples = 5000\n",
"for key in ['water', 'other land features']:\n",
"    features[key] = features[key][np.random.choice(features[key].shape[0], n_samples, replace=False),:]\n",
"# print the classes again\n",
"for key in features.keys():\n",
"    print('%s : %d pixels'%(key,len(features[key])))"
]
},
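{
"cell_type": "markdown",
"metadata": {},
"source": [
"[OPTIONAL] The subsampling above draws from NumPy's global random state, so the selected pixels change between runs. For a reproducible subsample, a seeded generator can be used instead; the sketch below performs the same operation (run it *instead of* the previous cell, and note it assumes NumPy >= 1.17 for *default_rng*):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reproducible subsampling with a seeded random generator\n",
"rng = np.random.default_rng(0)\n",
"for key in ['water', 'other land features']:\n",
"    n_keep = min(n_samples, features[key].shape[0])\n",
"    features[key] = features[key][rng.choice(features[key].shape[0], n_keep, replace=False),:]\n",
"for key in features.keys():\n",
"    print('%s : %d pixels'%(key,len(features[key])))"
]
},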
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the labelled data is ready, format it into X, a matrix of features, and y, a vector of labels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"run_control": {
"marked": true
}
},
"outputs": [],
"source": [
"# format into X (features) and y (labels)\n",
"classes = ['sand','white-water','water','other land features']\n",
"labels = [1,2,3,0]\n",
"X,y = SDS_classify.format_training_data(features, classes, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Divide the dataset into train and test sets: train on 70% of the data and evaluate on the remaining 30%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"run_control": {
"marked": true
}
},
"outputs": [],
"source": [
"# divide into train and test sets and evaluate the classifier\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True, random_state=0)\n",
"classifier = MLPClassifier(hidden_layer_sizes=(100,50), solver='adam')\n",
"classifier.fit(X_train,y_train)\n",
"print('Accuracy: %0.4f' % classifier.score(X_test,y_test))"
]
},
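{
"cell_type": "markdown",
"metadata": {},
"source": [
"[OPTIONAL] Beyond overall accuracy, a per-class breakdown helps to spot which classes are being confused. Below is a minimal sketch using *sklearn.metrics.classification_report* on the same test split; the class names follow the label mapping defined above (0 = other land features, 1 = sand, 2 = white-water, 3 = water)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# per-class precision, recall and F1-score on the test set\n",
"from sklearn.metrics import classification_report\n",
"y_pred = classifier.predict(X_test)\n",
"print(classification_report(y_test, y_pred,\n",
"      target_names=['other land features','sand','white-water','water']))"
]
},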
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[OPTIONAL] A more robust evaluation is 10-fold cross-validation (may take a few minutes to run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": [],
"run_control": {
"marked": true
}
},
"outputs": [],
"source": [
"# cross-validation\n",
"scores = cross_val_score(classifier, X, y, cv=10)\n",
"print('Accuracy: %0.4f (+/- %0.4f)' % (scores.mean(), scores.std() * 2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot a confusion matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"# plot confusion matrix\n",
"%matplotlib inline\n",
"y_pred = classifier.predict(X_test)\n",
"SDS_classify.plot_confusion_matrix(y_test, y_pred,\n",
"                                   classes=['other land features','sand','white-water','water'],\n",
"                                   normalize=False);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When satisfied with the accuracy and confusion matrix, train the model using ALL the training data and save it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# train with all the data and save the final classifier\n",
"classifier = MLPClassifier(hidden_layer_sizes=(100,50), solver='adam')\n",
"classifier.fit(X,y)\n",
"joblib.dump(classifier, os.path.join(filepath_models, 'NN_4classes_Landsat_test.pkl'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Evaluate the classifier\n",
"\n",
"Load a classifier that you have trained (specify the classifier's filename) and evaluate it on the satellite images.\n",
"\n",
"This section will save the output of the classification for each site in a directory named *\\evaluation*."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load and evaluate a classifier\n",
"%matplotlib qt\n",
"classifier = joblib.load(os.path.join(filepath_models, 'NN_4classes_Landsat_test.pkl'))\n",
"settings['output_epsg'] = 3857\n",
"settings['min_beach_area'] = 4500\n",
"settings['buffer_size'] = 200\n",
"settings['min_length_sl'] = 200\n",
"settings['cloud_thresh'] = 0.5\n",
"# visualise the classified images\n",
"for site in train_sites:\n",
"    settings['inputs']['sitename'] = site[:site.find('.')]\n",
"    # load metadata\n",
"    metadata = SDS_download.get_metadata(settings['inputs'])\n",
"    # plot the classified images\n",
"    SDS_classify.evaluate_classifier(classifier,metadata,settings)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}