Create new notebook for comparing models
parent
3bfb13e9d6
commit
d4995266c9
@@ -0,0 +1,313 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run comparison\n",
"Compare different runs by looking at their R_high values and storm regimes."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Enable autoreloading of our modules. \n",
"# Most of the code will be located in the /src/ folder, \n",
"# and then called from the notebook.\n",
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.core.debugger import set_trace\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"import decimal\n",
"import plotly\n",
"import plotly.graph_objs as go\n",
"import plotly.plotly as py\n",
"import plotly.tools as tls\n",
"import plotly.figure_factory as ff\n",
"from plotly import tools\n",
"import plotly.io as pio\n",
"from scipy import stats\n",
"import math\n",
"import matplotlib\n",
"from matplotlib import cm\n",
"import colorlover as cl\n",
"from tqdm import tqdm_notebook\n",
"from ipywidgets import widgets, Output\n",
"from IPython.display import display, clear_output, Image, HTML\n",
"from sklearn.metrics import confusion_matrix\n",
"import matplotlib.pyplot as plt\n",
"from scipy.interpolate import interp1d\n",
"from pandas.api.types import CategoricalDtype\n",
"from scipy.interpolate import UnivariateSpline\n",
"from shapely.geometry import Point, LineString"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Matplotlib default settings\n",
"plt.rcParams[\"figure.figsize\"] = (10, 6)\n",
"plt.rcParams['axes.grid'] = True\n",
"plt.rcParams['grid.alpha'] = 0.5\n",
"plt.rcParams['grid.color'] = \"grey\"\n",
"plt.rcParams['grid.linestyle'] = \"--\"\n",
"\n",
"# https://stackoverflow.com/a/20709149\n",
"matplotlib.rcParams['text.usetex'] = True\n",
"\n",
"matplotlib.rcParams['text.latex.preamble'] = [\n",
"    r'\\usepackage{siunitx}',   # I need upright \\micro symbols, but you need...\n",
"    r'\\sisetup{detect-all}',   # ...this to force siunitx to actually use your fonts\n",
"    r'\\usepackage{helvet}',    # set the normal font here\n",
"    r'\\usepackage{amsmath}',\n",
"    r'\\usepackage{sansmath}',  # load up the sansmath so that math -> helvet\n",
"    r'\\sansmath',              # <- tricky! you have to actually tell tex to use it\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def df_from_csv(csv, index_col, data_folder='../data/interim'):\n",
"    print('Importing {}'.format(csv))\n",
"    return pd.read_csv(os.path.join(data_folder, csv), index_col=index_col)\n",
"\n",
"df_waves = df_from_csv('waves.csv', index_col=[0, 1])\n",
"df_tides = df_from_csv('tides.csv', index_col=[0, 1])\n",
"df_profiles = df_from_csv('profiles.csv', index_col=[0, 1, 2])\n",
"df_sites = df_from_csv('sites.csv', index_col=[0])\n",
"df_profile_features_crest_toes = df_from_csv('profile_features_crest_toes.csv', index_col=[0])\n",
"\n",
"# Note that the forecasted data sets should be in the same order for impacts and twls\n",
"impacts = {\n",
"    'forecasted': {\n",
"        'foreshore_slope_sto06': df_from_csv('impacts_forecasted_foreshore_slope_sto06.csv', index_col=[0]),\n",
"        'mean_slope_sto06': df_from_csv('impacts_forecasted_mean_slope_sto06.csv', index_col=[0]),\n",
"        'mean_slope_nie91': df_from_csv('impacts_forecasted_mean_slope_nie91.csv', index_col=[0]),\n",
"        'mean_slope_hol86': df_from_csv('impacts_forecasted_mean_slope_hol86.csv', index_col=[0]),\n",
"    },\n",
"    'observed': df_from_csv('impacts_observed.csv', index_col=[0])\n",
"}\n",
"\n",
"twls = {\n",
"    'forecasted': {\n",
"        'foreshore_slope_sto06': df_from_csv('twl_foreshore_slope_sto06.csv', index_col=[0, 1]),\n",
"        'mean_slope_sto06': df_from_csv('twl_mean_slope_sto06.csv', index_col=[0, 1]),\n",
"        'mean_slope_nie91': df_from_csv('twl_mean_slope_nie91.csv', index_col=[0, 1]),\n",
"        'mean_slope_hol86': df_from_csv('twl_mean_slope_hol86.csv', index_col=[0, 1]),\n",
"    }\n",
"}\n",
"print('Done!')"
]
},
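{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of what was just loaded can be useful here. The cell below is a minimal sketch that only prints the shape of each dataframe defined above; it assumes all of the imports completed without error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: print the shape of each dataframe loaded above.\n",
"for name, df in [('waves', df_waves), ('tides', df_tides), ('profiles', df_profiles),\n",
"                 ('sites', df_sites), ('profile_features_crest_toes', df_profile_features_crest_toes)]:\n",
"    print('{}: {} rows x {} cols'.format(name, df.shape[0], df.shape[1]))"
]
},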
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get prediction accuracy\n",
"Use [scikit-learn](https://scikit-learn.org/stable/modules/model_evaluation.html#classification-metrics) model evaluation metrics to compare the forecasted storm regimes against the observed regimes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pprint\n",
"pp = pprint.PrettyPrinter(indent=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import sklearn.metrics\n",
"\n",
"# Encode the storm regime values as categorical integers so we can compare them\n",
"cat_type = CategoricalDtype(\n",
"    categories=[\"swash\", \"collision\", \"overwash\", \"inundation\"], ordered=True)\n",
"correct_regime = impacts['observed'].storm_regime.astype(\n",
"    cat_type).cat.codes.values\n",
"\n",
"# Define our forecast model names\n",
"models = [model for model in impacts['forecasted']]\n",
"\n",
"# Define the metrics we want to calculate for each forecast model\n",
"metrics = [\n",
"    'accuracy_score', 'balanced_accuracy_score', 'confusion_matrix',\n",
"    'classification_report', 'f1_score', 'fbeta_score', 'precision_score', 'recall_score'\n",
"]\n",
"\n",
"# Store results in a nested dictionary by metric\n",
"performance = {metric: {} for metric in metrics}\n",
"\n",
"for model, metric in itertools.product(models, metrics):\n",
"\n",
"    # Get predicted storm regimes\n",
"    df_pred = impacts['forecasted'][model]\n",
"    predicted_regime = df_pred.storm_regime.astype(cat_type).cat.codes.values\n",
"\n",
"    if metric == 'accuracy_score':\n",
"        m = sklearn.metrics.accuracy_score(correct_regime, predicted_regime)\n",
"\n",
"    if metric == 'balanced_accuracy_score':\n",
"        m = sklearn.metrics.balanced_accuracy_score(correct_regime,\n",
"                                                    predicted_regime)\n",
"\n",
"    if metric == 'confusion_matrix':\n",
"        m = sklearn.metrics.confusion_matrix(\n",
"            correct_regime, predicted_regime, labels=[0, 1, 2, 3])\n",
"\n",
"    if metric == 'f1_score':\n",
"        m = sklearn.metrics.f1_score(correct_regime, predicted_regime, average='weighted')\n",
"\n",
"    if metric == 'fbeta_score':\n",
"        m = sklearn.metrics.fbeta_score(correct_regime, predicted_regime, average='weighted', beta=1)\n",
"\n",
"    if metric == 'precision_score':\n",
"        m = sklearn.metrics.precision_score(correct_regime, predicted_regime, average='weighted')\n",
"\n",
"    if metric == 'recall_score':\n",
"        m = sklearn.metrics.recall_score(correct_regime, predicted_regime, average='weighted')\n",
"\n",
"    if metric == 'classification_report':\n",
"        # m = sklearn.metrics.classification_report(\n",
"        #     correct_regime,\n",
"        #     predicted_regime,\n",
"        #     labels=[0, 1, 2, 3],\n",
"        #     target_names=['swash', 'collision', 'overwash', 'inundation'])\n",
"        # print(m)\n",
"        continue\n",
"\n",
"    # Store metric in results dictionary\n",
"    performance[metric][model] = m\n",
"\n",
"pp.pprint(performance)"
]
},
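{
"cell_type": "markdown",
"metadata": {},
"source": [
"The raw confusion matrices above are hard to read because rows and columns are just the category codes 0-3. The cell below is a minimal sketch that relabels them with the storm regime names; it assumes `performance['confusion_matrix']` was populated by the previous cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: show each model's confusion matrix with labelled rows/columns.\n",
"# Rows are the observed regimes, columns are the predicted regimes.\n",
"regime_labels = ['swash', 'collision', 'overwash', 'inundation']\n",
"for model, cm in performance['confusion_matrix'].items():\n",
"    print(model)\n",
"    display(pd.DataFrame(cm, index=regime_labels, columns=regime_labels))"
]
},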
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predicted_regime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scatter plot matrix\n",
" - Use [Altair](https://altair-viz.github.io/getting_started/installation.html) for interactivity?\n",
" - Or maybe [Holoviews](https://towardsdatascience.com/pyviz-simplifying-the-data-visualisation-process-in-python-1b6d2cb728f1)?\n",
" - A plain pandas/matplotlib sketch follows in the cell below as a starting point."
]
},
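{
"cell_type": "markdown",
"metadata": {},
"source": [
"As noted above, this is only a minimal pandas/matplotlib sketch rather than an Altair or Holoviews version. It assumes each forecasted TWL dataframe contains an `R_high` column and that the dataframes share the same index, so the per-model series can be aligned side by side."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: scatter plot matrix of R_high across the forecast models.\n",
"# Assumes each twls['forecasted'][model] dataframe has an 'R_high' column\n",
"# on a common index, so the columns line up when combined.\n",
"df_r_high = pd.DataFrame({\n",
"    model: twls['forecasted'][model]['R_high'] for model in twls['forecasted']\n",
"})\n",
"pd.plotting.scatter_matrix(df_r_high, figsize=(10, 10), alpha=0.3, diagonal='kde')\n",
"plt.show()"
]
}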
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}