Create new notebook for comparing models

develop
Chris Leaman 6 years ago
parent 3bfb13e9d6
commit d4995266c9

@ -0,0 +1,313 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run comparison\n",
"Create a comparison between different runs by looking at the different R_high values and storm regimes."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Enable autoreloading of our modules. \n",
"# Most of the code will be located in the /src/ folder, \n",
"# and then called from the notebook.\n",
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.core.debugger import set_trace\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"import decimal\n",
"import plotly\n",
"import plotly.graph_objs as go\n",
"import plotly.plotly as py\n",
"import plotly.tools as tls\n",
"import plotly.figure_factory as ff\n",
"from plotly import tools\n",
"import plotly.io as pio\n",
"from scipy import stats\n",
"import math\n",
"import matplotlib\n",
"from matplotlib import cm\n",
"import colorlover as cl\n",
"from tqdm import tqdm_notebook\n",
"from ipywidgets import widgets, Output\n",
"from IPython.display import display, clear_output, Image, HTML\n",
"from scipy import stats\n",
"from sklearn.metrics import confusion_matrix\n",
"import matplotlib.pyplot as plt\n",
"from scipy.interpolate import interp1d\n",
"from pandas.api.types import CategoricalDtype\n",
"from scipy.interpolate import UnivariateSpline\n",
"from shapely.geometry import Point, LineString"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Matplot lib default settings\n",
"plt.rcParams[\"figure.figsize\"] = (10,6)\n",
"plt.rcParams['axes.grid']=True\n",
"plt.rcParams['grid.alpha'] = 0.5\n",
"plt.rcParams['grid.color'] = \"grey\"\n",
"plt.rcParams['grid.linestyle'] = \"--\"\n",
"plt.rcParams['axes.grid']=True\n",
"\n",
"# https://stackoverflow.com/a/20709149\n",
"matplotlib.rcParams['text.usetex'] = True\n",
"\n",
"matplotlib.rcParams['text.latex.preamble'] = [\n",
" r'\\usepackage{siunitx}', # i need upright \\micro symbols, but you need...\n",
" r'\\sisetup{detect-all}', # ...this to force siunitx to actually use your fonts\n",
" r'\\usepackage{helvet}', # set the normal font here\n",
" r'\\usepackage{amsmath}',\n",
" r'\\usepackage{sansmath}', # load up the sansmath so that math -> helvet\n",
" r'\\sansmath', # <- tricky! -- gotta actually tell tex to use!\n",
"] "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def df_from_csv(csv, index_col, data_folder='../data/interim'):\n",
" print('Importing {}'.format(csv))\n",
" return pd.read_csv(os.path.join(data_folder,csv), index_col=index_col)\n",
"\n",
"df_waves = df_from_csv('waves.csv', index_col=[0, 1])\n",
"df_tides = df_from_csv('tides.csv', index_col=[0, 1])\n",
"df_profiles = df_from_csv('profiles.csv', index_col=[0, 1, 2])\n",
"df_sites = df_from_csv('sites.csv', index_col=[0])\n",
"df_profile_features_crest_toes = df_from_csv('profile_features_crest_toes.csv', index_col=[0])\n",
"\n",
"# Note that the forecasted data sets should be in the same order for impacts and twls\n",
"impacts = {\n",
" 'forecasted': {\n",
" 'foreshore_slope_sto06': df_from_csv('impacts_forecasted_foreshore_slope_sto06.csv', index_col=[0]),\n",
" 'mean_slope_sto06': df_from_csv('impacts_forecasted_mean_slope_sto06.csv', index_col=[0]),\n",
" 'mean_slope_nie91': df_from_csv('impacts_forecasted_mean_slope_nie91.csv', index_col=[0]),\n",
" 'mean_slope_hol86': df_from_csv('impacts_forecasted_mean_slope_hol86.csv', index_col=[0]),\n",
" },\n",
" 'observed': df_from_csv('impacts_observed.csv', index_col=[0])\n",
" }\n",
"\n",
"\n",
"twls = {\n",
" 'forecasted': {\n",
" 'foreshore_slope_sto06': df_from_csv('twl_foreshore_slope_sto06.csv', index_col=[0, 1]),\n",
" 'mean_slope_sto06':df_from_csv('twl_mean_slope_sto06.csv', index_col=[0, 1]),\n",
" 'mean_slope_nie91':df_from_csv('twl_mean_slope_nie91.csv', index_col=[0, 1]),\n",
" 'mean_slope_hol86':df_from_csv('twl_mean_slope_hol86.csv', index_col=[0, 1]),\n",
" }\n",
"}\n",
"print('Done!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get prediction accuracy\n",
"Use [scikit-learn](https://scikit-learn.org/stable/modules/model_evaluation.html#classification-metrics) model evaluation metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pprint\n",
"pp = pprint.PrettyPrinter(indent=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sklearn.metrics\n",
"\n",
"# Encode the storm regimes values as categorical intgers so we can compare them\n",
"cat_type = CategoricalDtype(\n",
" categories=[\"swash\", \"collision\", \"overwash\", \"inundation\"], ordered=True)\n",
"correct_regime = impacts['observed'].storm_regime.astype(\n",
" cat_type).cat.codes.values\n",
"\n",
"# Define our forecast model names\n",
"models = [model for model in impacts['forecasted']]\n",
"\n",
"# Define the metric we want to calculate for each forecast model\n",
"metrics = [\n",
" 'accuracy_score', 'balanced_accuracy_score', 'confusion_matrix',\n",
" 'classification_report', 'f1_score', 'fbeta_score', 'precision_score', 'recall_score'\n",
"]\n",
"\n",
"# Store results in a nested dictionary by metric\n",
"performance = {metric: {} for metric in metrics}\n",
"\n",
"for model, metric in itertools.product(models, metrics):\n",
"\n",
" # Get predicted storm regims\n",
" df_pred = impacts['forecasted'][model]\n",
" predicted_regime = df_pred.storm_regime.astype(cat_type).cat.codes.values\n",
"\n",
" if metric == 'accuracy_score':\n",
" m = sklearn.metrics.accuracy_score(correct_regime, predicted_regime)\n",
"\n",
" if metric == 'balanced_accuracy_score':\n",
" m = sklearn.metrics.balanced_accuracy_score(correct_regime,\n",
" predicted_regime)\n",
"\n",
" if metric == 'confusion_matrix':\n",
" m = sklearn.metrics.confusion_matrix(\n",
" correct_regime, predicted_regime, labels=[0, 1, 2, 3])\n",
" \n",
" if metric == 'f1_score':\n",
" m = sklearn.metrics.f1_score(correct_regime, predicted_regime, average='weighted')\n",
" \n",
" if metric == 'fbeta_score':\n",
" m = sklearn.metrics.fbeta_score(correct_regime, predicted_regime, average='weighted', beta=1)\n",
" \n",
" if metric == 'precision_score':\n",
" m = sklearn.metrics.precision_score(correct_regime, predicted_regime, average='weighted')\n",
" \n",
" if metric == 'recall_score':\n",
" m = sklearn.metrics.recall_score(correct_regime, predicted_regime, average='weighted')\n",
"# m=1\n",
" \n",
" if metric == 'classification_report':\n",
"# m = sklearn.metrics.classification_report(\n",
"# correct_regime,\n",
"# predicted_regime,\n",
"# labels=[0, 1, 2, 3],\n",
"# target_names=['swash', 'collision', 'overwash', 'inundation'])\n",
"# print(m)\n",
" continue\n",
"\n",
" # Store metric in results dictionary\n",
" performance[metric][model] = m\n",
"\n",
"pp.pprint(performance)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predicted_regime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scatter plot matirx\n",
" - Use [Altair](https://altair-viz.github.io/getting_started/installation.html) for interactivity?\n",
" - Or maybe [Holoviews](https://towardsdatascience.com/pyviz-simplifying-the-data-visualisation-process-in-python-1b6d2cb728f1)?"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading…
Cancel
Save