From d4995266c9f768ad16140d5f55a8da4dea31d21d Mon Sep 17 00:00:00 2001 From: Chris Leaman Date: Tue, 12 Feb 2019 16:22:05 +1100 Subject: [PATCH] Create new notebook for comparing models --- notebooks/09_run_comparison.ipynb | 313 ++++++++++++++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 notebooks/09_run_comparison.ipynb diff --git a/notebooks/09_run_comparison.ipynb b/notebooks/09_run_comparison.ipynb new file mode 100644 index 0000000..2072cbe --- /dev/null +++ b/notebooks/09_run_comparison.ipynb @@ -0,0 +1,313 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run comparison\n", + "Create a comparison between different runs by looking at the different R_high values and storm regimes." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Enable autoreloading of our modules. 
\n", + "# Most of the code will be located in the /src/ folder, \n", + "# and then called from the notebook.\n", + "%matplotlib inline\n", + "%reload_ext autoreload\n", + "%autoreload" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.core.debugger import set_trace\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import os\n", + "import decimal\n", + "import plotly\n", + "import plotly.graph_objs as go\n", + "import plotly.plotly as py\n", + "import plotly.tools as tls\n", + "import plotly.figure_factory as ff\n", + "from plotly import tools\n", + "import plotly.io as pio\n", + "from scipy import stats\n", + "import math\n", + "import matplotlib\n", + "from matplotlib import cm\n", + "import colorlover as cl\n", + "from tqdm import tqdm_notebook\n", + "from ipywidgets import widgets, Output\n", + "from IPython.display import display, clear_output, Image, HTML\n", + "from scipy import stats\n", + "from sklearn.metrics import confusion_matrix\n", + "import matplotlib.pyplot as plt\n", + "from scipy.interpolate import interp1d\n", + "from pandas.api.types import CategoricalDtype\n", + "from scipy.interpolate import UnivariateSpline\n", + "from shapely.geometry import Point, LineString" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Matplot lib default settings\n", + "plt.rcParams[\"figure.figsize\"] = (10,6)\n", + "plt.rcParams['axes.grid']=True\n", + "plt.rcParams['grid.alpha'] = 0.5\n", + "plt.rcParams['grid.color'] = \"grey\"\n", + "plt.rcParams['grid.linestyle'] = \"--\"\n", + "plt.rcParams['axes.grid']=True\n", + "\n", + "# https://stackoverflow.com/a/20709149\n", + "matplotlib.rcParams['text.usetex'] = True\n", + "\n", + "matplotlib.rcParams['text.latex.preamble'] = [\n", + " r'\\usepackage{siunitx}', # i need upright \\micro symbols, but you need...\n", + " r'\\sisetup{detect-all}', # ...this to 
def df_from_csv(csv, index_col, data_folder='../data/interim'):
    """Read one CSV file from the data folder into a DataFrame.

    Args:
        csv: File name of the CSV, joined onto ``data_folder``.
        index_col: Passed straight through to ``pd.read_csv`` as ``index_col``.
        data_folder: Folder that contains the CSV file.

    Returns:
        The DataFrame returned by ``pd.read_csv``.
    """
    # Announce the file first so progress is visible even if the read fails
    message = 'Importing {}'.format(csv)
    print(message)

    csv_path = os.path.join(data_folder, csv)
    frame = pd.read_csv(csv_path, index_col=index_col)
    return frame
import sklearn.metrics

# Encode the storm regime values as ordered categorical integers so we can
# compare them. Codes follow the category order below (swash=0 ... inundation=3).
# NOTE(review): pandas assigns code -1 to values that are missing or not one of
# these four categories; such rows still count towards the scores below.
# Confirm whether the impact tables can contain them.
cat_type = CategoricalDtype(
    categories=["swash", "collision", "overwash", "inundation"], ordered=True)
correct_regime = impacts['observed'].storm_regime.astype(
    cat_type).cat.codes.values

# Define our forecast model names
models = list(impacts['forecasted'])

# Define the metrics we want to calculate for each forecast model
metrics = [
    'accuracy_score', 'balanced_accuracy_score', 'confusion_matrix',
    'classification_report', 'f1_score', 'fbeta_score', 'precision_score', 'recall_score'
]

# Map each metric name to a callable taking (observed codes, predicted codes).
# 'classification_report' is deliberately absent: it is listed in `metrics`
# but not calculated, so its entry in `performance` stays an empty dict.
metric_functions = {
    'accuracy_score':
    sklearn.metrics.accuracy_score,
    'balanced_accuracy_score':
    sklearn.metrics.balanced_accuracy_score,
    'confusion_matrix':
    lambda observed, predicted: sklearn.metrics.confusion_matrix(
        observed, predicted, labels=[0, 1, 2, 3]),
    'f1_score':
    lambda observed, predicted: sklearn.metrics.f1_score(
        observed, predicted, average='weighted'),
    'fbeta_score':
    lambda observed, predicted: sklearn.metrics.fbeta_score(
        observed, predicted, average='weighted', beta=1),
    'precision_score':
    lambda observed, predicted: sklearn.metrics.precision_score(
        observed, predicted, average='weighted'),
    'recall_score':
    lambda observed, predicted: sklearn.metrics.recall_score(
        observed, predicted, average='weighted'),
}

# Store results in a nested dictionary: performance[metric][model]
performance = {metric: {} for metric in metrics}

# Plain nested loops (models outer, metrics inner) so no extra import is
# needed and the predicted regimes are encoded once per model.
for model in models:

    # Get predicted storm regimes for this model.
    # NOTE(review): the comparison with `correct_regime` is positional, so it
    # assumes the observed and forecasted tables list the same sites in the
    # same order - confirm against the files in the data folder.
    predicted_regime = impacts['forecasted'][model].storm_regime.astype(
        cat_type).cat.codes.values

    for metric in metrics:
        metric_function = metric_functions.get(metric)
        if metric_function is None:
            # Metric is listed but has no implementation; leave it empty
            continue

        # Store metric in results dictionary
        performance[metric][model] = metric_function(correct_regime,
                                                     predicted_regime)

pp.pprint(performance)
+ ] + } + ], + "metadata": { + "hide_input": false, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}