{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate prediction metrics \n",
"- This notebook is used to check out each of our storm impact prediction models performed in comparison to our observed storm impacts."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup notebook\n",
"Import our required packages and set default plotting options."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Enable autoreloading of our modules. \n",
"# Most of the code will be located in the /src/ folder, \n",
"# and then called from the notebook.\n",
"%matplotlib inline\n",
"%reload_ext autoreload\n",
"%autoreload"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.core.debugger import set_trace\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"import decimal\n",
"import plotly\n",
"import plotly.graph_objs as go\n",
"import plotly.plotly as py\n",
"import plotly.tools as tls\n",
"import plotly.figure_factory as ff\n",
"from plotly import tools\n",
"import plotly.io as pio\n",
"from scipy import stats\n",
"import math\n",
"import matplotlib\n",
"from matplotlib import cm\n",
"import colorlover as cl\n",
"from tqdm import tqdm_notebook\n",
"from ipywidgets import widgets, Output\n",
"from IPython.display import display, clear_output, Image, HTML\n",
"from scipy import stats\n",
"from sklearn.metrics import confusion_matrix, matthews_corrcoef\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MultipleLocator\n",
"from matplotlib.lines import Line2D\n",
"from cycler import cycler\n",
"from scipy.interpolate import interp1d\n",
"from pandas.api.types import CategoricalDtype"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Matplot lib default settings\n",
"plt.rcParams[\"figure.figsize\"] = (10,6)\n",
"plt.rcParams['axes.grid']=True\n",
"plt.rcParams['grid.alpha'] = 0.5\n",
"plt.rcParams['grid.color'] = \"grey\"\n",
"plt.rcParams['grid.linestyle'] = \"--\"\n",
"plt.rcParams['axes.grid']=True\n",
"\n",
"# https://stackoverflow.com/a/20709149\n",
"matplotlib.rcParams['text.usetex'] = True\n",
"\n",
"matplotlib.rcParams['text.latex.preamble'] = [\n",
" r'\\usepackage{siunitx}', # i need upright \\micro symbols, but you need...\n",
" r'\\sisetup{detect-all}', # ...this to force siunitx to actually use your fonts\n",
" r'\\usepackage{helvet}', # set the normal font here\n",
" r'\\usepackage{amsmath}',\n",
" r'\\usepackage{sansmath}', # load up the sansmath so that math -> helvet\n",
" r'\\sansmath', # <- tricky! -- gotta actually tell tex to use!\n",
"] "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Matplot lib default settings\n",
"plt.rcParams[\"figure.figsize\"] = (10, 6)\n",
"plt.rcParams['axes.grid'] = True\n",
"plt.rcParams['grid.alpha'] = 0.3\n",
"plt.rcParams['grid.color'] = \"grey\"\n",
"plt.rcParams['grid.linestyle'] = \"--\"\n",
"plt.rcParams['grid.linewidth'] = 0.5\n",
"plt.rcParams['axes.grid'] = True\n",
"\n",
"# # https://stackoverflow.com/a/20709149\n",
"matplotlib.rcParams['text.usetex'] = True\n",
"matplotlib.rcParams['font.family'] = 'sans-serif'\n",
"\n",
"# # matplotlib.rcParams['text.latex.preamble'] = [\n",
"# # r'\\usepackage{siunitx}', # i need upright \\micro symbols, but you need...\n",
"# # r'\\sisetup{detect-all}', # ...this to force siunitx to actually use your fonts\n",
"# # r'\\usepackage{helvet}', # set the normal font here\n",
"# # r'\\usepackage{amsmath}',\n",
"# # r'\\usepackage{sansmath}', # load up the sansmath so that math -> helvet\n",
"# # r'\\sansmath', # <- tricky! -- gotta actually tell tex to use!\n",
"# # ]\n",
"\n",
"matplotlib.rcParams['text.latex.preamble'] = [\n",
" r'\\usepackage{siunitx}', # i need upright \\micro symbols, but you need...\n",
" r'\\sisetup{detect-all}', # ...this to force siunitx to actually use your fonts\n",
" r'\\usepackage[default]{sourcesanspro}',\n",
" r'\\usepackage{amsmath}',\n",
" r'\\usepackage{sansmath}', # load up the sansmath so that math -> helvet\n",
" r'\\sansmath', # <- tricky! -- gotta actually tell tex to use!\n",
"]\n",
"\n",
"# import matplotlib as mpl\n",
"# mpl.use(\"pgf\")\n",
"# pgf_with_custom_preamble = {\n",
"# \"font.family\":\"sans-serif\", # use serif/main font for text elements\n",
"# \"text.usetex\":True, # use inline math for ticks\n",
"# \"pgf.rcfonts\":False, # don't setup fonts from rc parameters\n",
"# \"pgf.preamble\": [\n",
"# r'\\usepackage{siunitx}', # i need upright \\micro symbols, but you need...\n",
"# r'\\sisetup{detect-all}', # ...this to force siunitx to actually use your fonts\n",
"# r'\\usepackage[default]{sourcesanspro}',\n",
"# r'\\usepackage{amsmath}',\n",
"# r'\\usepackage[mathrm=sym]{unicode-math}',\n",
"# r'\\setmathfont{Fira Math}',\n",
"# ]\n",
"# }\n",
"# mpl.rcParams.update(pgf_with_custom_preamble)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import data\n",
"Import our data from the `./data/interim/` folder and load it into pandas dataframes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def df_from_csv(csv, index_col, data_folder='../data/interim'):\n",
" print('Importing {}'.format(csv))\n",
" return pd.read_csv(os.path.join(data_folder,csv), index_col=index_col)\n",
"\n",
"df_waves = df_from_csv('waves.csv', index_col=[0, 1])\n",
"df_tides = df_from_csv('tides.csv', index_col=[0, 1])\n",
"df_profiles = df_from_csv('profiles.csv', index_col=[0, 1, 2])\n",
"df_sites = df_from_csv('sites.csv', index_col=[0])\n",
"df_sites_waves = df_from_csv('sites_waves.csv', index_col=[0])\n",
"df_profile_features_crest_toes = df_from_csv('profile_features_crest_toes.csv', index_col=[0,1])\n",
"\n",
"# Note that the forecasted data sets should be in the same order for impacts and twls\n",
"impacts = {\n",
" 'forecasted': {\n",
" 'postintertidal_slope_hol86': df_from_csv('impacts_forecasted_postintertidal_slope_hol86.csv', index_col=[0]),\n",
" 'postintertidal_slope_nie91': df_from_csv('impacts_forecasted_postintertidal_slope_nie91.csv', index_col=[0]),\n",
" 'postintertidal_slope_pow18': df_from_csv('impacts_forecasted_postintertidal_slope_pow18.csv', index_col=[0]),\n",
" 'postintertidal_slope_sto06': df_from_csv('impacts_forecasted_postintertidal_slope_sto06.csv', index_col=[0]),\n",
" 'postmean_slope_hol86': df_from_csv('impacts_forecasted_postmean_slope_hol86.csv', index_col=[0]),\n",
" 'postmean_slope_nie91': df_from_csv('impacts_forecasted_postmean_slope_nie91.csv', index_col=[0]),\n",
" 'postmean_slope_pow18': df_from_csv('impacts_forecasted_postmean_slope_pow18.csv', index_col=[0]),\n",
" 'postmean_slope_sto06': df_from_csv('impacts_forecasted_postmean_slope_sto06.csv', index_col=[0]),\n",
" 'preintertidal_slope_hol86': df_from_csv('impacts_forecasted_preintertidal_slope_hol86.csv', index_col=[0]),\n",
" 'preintertidal_slope_nie91': df_from_csv('impacts_forecasted_preintertidal_slope_nie91.csv', index_col=[0]),\n",
" 'preintertidal_slope_pow18': df_from_csv('impacts_forecasted_preintertidal_slope_pow18.csv', index_col=[0]),\n",
" 'preintertidal_slope_sto06': df_from_csv('impacts_forecasted_preintertidal_slope_sto06.csv', index_col=[0]),\n",
" 'premean_slope_hol86': df_from_csv('impacts_forecasted_premean_slope_hol86.csv', index_col=[0]),\n",
" 'premean_slope_nie91': df_from_csv('impacts_forecasted_premean_slope_nie91.csv', index_col=[0]),\n",
" 'premean_slope_pow18': df_from_csv('impacts_forecasted_premean_slope_pow18.csv', index_col=[0]),\n",
" 'premean_slope_sto06': df_from_csv('impacts_forecasted_premean_slope_sto06.csv', index_col=[0]),\n",
" },\n",
" 'observed': df_from_csv('impacts_observed.csv', index_col=[0])\n",
" }\n",
"\n",
"\n",
"twls = {\n",
" 'forecasted': {\n",
" 'postintertidal_slope_hol86': df_from_csv('twl_postintertidal_slope_hol86.csv', index_col=[0,1]),\n",
" 'postintertidal_slope_nie91': df_from_csv('twl_postintertidal_slope_nie91.csv', index_col=[0,1]),\n",
" 'postintertidal_slope_pow18': df_from_csv('twl_postintertidal_slope_pow18.csv', index_col=[0,1]),\n",
" 'postintertidal_slope_sto06': df_from_csv('twl_postintertidal_slope_sto06.csv', index_col=[0,1]),\n",
" 'postmean_slope_hol86': df_from_csv('twl_postmean_slope_hol86.csv', index_col=[0,1]),\n",
" 'postmean_slope_nie91': df_from_csv('twl_postmean_slope_nie91.csv', index_col=[0,1]),\n",
" 'postmean_slope_pow18': df_from_csv('twl_postmean_slope_pow18.csv', index_col=[0,1]),\n",
" 'postmean_slope_sto06': df_from_csv('twl_postmean_slope_sto06.csv', index_col=[0,1]),\n",
" 'preintertidal_slope_hol86': df_from_csv('twl_preintertidal_slope_hol86.csv', index_col=[0,1]),\n",
" 'preintertidal_slope_nie91': df_from_csv('twl_preintertidal_slope_nie91.csv', index_col=[0,1]),\n",
" 'preintertidal_slope_pow18': df_from_csv('twl_preintertidal_slope_pow18.csv', index_col=[0,1]),\n",
" 'preintertidal_slope_sto06': df_from_csv('twl_preintertidal_slope_sto06.csv', index_col=[0,1]),\n",
" 'premean_slope_hol86': df_from_csv('twl_premean_slope_hol86.csv', index_col=[0,1]),\n",
" 'premean_slope_nie91': df_from_csv('twl_premean_slope_nie91.csv', index_col=[0,1]),\n",
" 'premean_slope_pow18': df_from_csv('twl_premean_slope_pow18.csv', index_col=[0,1]),\n",
" 'premean_slope_sto06': df_from_csv('twl_premean_slope_sto06.csv', index_col=[0,1]),\n",
" }\n",
"}\n",
"print('Done!')"
]
},
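{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sixteen forecast variants above follow a consistent naming pattern (four beach-slope definitions combined with four runup models), so the same dictionaries could equivalently be built with a short loop. A minimal sketch is shown below; it assumes the file names loaded above are exhaustive and it is not used by the rest of this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: loop-based construction of the impacts/twls dictionaries.\n",
"# Assumes the naming convention above (four slope types x four runup models).\n",
"slopes = ['preintertidal_slope', 'premean_slope', 'postintertidal_slope', 'postmean_slope']\n",
"runup_models = ['hol86', 'nie91', 'pow18', 'sto06']\n",
"model_names = ['{}_{}'.format(s, m) for s in slopes for m in runup_models]\n",
"\n",
"impacts_alt = {\n",
"    'forecasted': {\n",
"        name: df_from_csv('impacts_forecasted_{}.csv'.format(name), index_col=[0])\n",
"        for name in model_names\n",
"    },\n",
"    'observed': df_from_csv('impacts_observed.csv', index_col=[0])\n",
"}\n",
"\n",
"twls_alt = {\n",
"    'forecasted': {\n",
"        name: df_from_csv('twl_{}.csv'.format(name), index_col=[0, 1])\n",
"        for name in model_names\n",
"    }\n",
"}"
]
},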
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate longshore plots for each beach"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"beaches = list(\n",
" set([\n",
" x[:-4] for x in df_profiles.index.get_level_values('site_id').unique()\n",
" ]))\n",
"\n",
"for beach in beaches:\n",
" \n",
" df_obs_impacts = impacts['observed'].loc[impacts['observed'].index.str.\n",
" contains(beach)]\n",
"\n",
" # Get index for each site on the beach\n",
" n = [x for x in range(len(df_obs_impacts))][::-1]\n",
" n_sites = [x for x in df_obs_impacts.index][::-1]\n",
"\n",
" # Convert storm regimes to categorical datatype\n",
" cat_type = CategoricalDtype(\n",
" categories=['swash', 'collision', 'overwash', 'inundation'],\n",
" ordered=True)\n",
" df_obs_impacts.storm_regime = df_obs_impacts.storm_regime.astype(cat_type)\n",
"\n",
" # Create figure\n",
" \n",
" # Determine the height of the figure, based on the number of sites.\n",
" fig_height = max(6, 0.18 * len(n_sites))\n",
" f, (ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9) = plt.subplots(\n",
" 1,\n",
" 9,\n",
" sharey=True,\n",
" figsize=(18, fig_height),\n",
" gridspec_kw={'width_ratios': [4, 4, 2, 2, 2, 2, 2, 2,2]})\n",
"\n",
" # ax1: Impacts\n",
"\n",
" # Define colors for storm regime\n",
" cmap = {'swash': '#1a9850', 'collision': '#fee08b', 'overwash': '#d73027'}\n",
"\n",
" # Common marker style\n",
" marker_style = {\n",
" 's': 60,\n",
" 'linewidths': 0.7,\n",
" 'alpha': 1,\n",
" 'edgecolors': 'k',\n",
" 'marker': 'o',\n",
" }\n",
"\n",
" # Plot observed impacts\n",
" colors = [cmap.get(x) for x in df_obs_impacts.storm_regime]\n",
" colors = ['#aaaaaa' if c is None else c for c in colors]\n",
" ax1.scatter([0 for x in n], n, color=colors, **marker_style)\n",
"\n",
" # Plot model impacts\n",
" for i, model in enumerate(impacts['forecasted']):\n",
"\n",
" # Only get model results for this beach\n",
" df_model = impacts['forecasted'][model].loc[\n",
" impacts['forecasted'][model].index.str.contains(beach)]\n",
"\n",
" # Recast storm regimes as categorical data\n",
" df_model.storm_regime = df_model.storm_regime.astype(cat_type)\n",
"\n",
" # Assign colors\n",
" colors = [cmap.get(x) for x in df_model.storm_regime]\n",
" colors = ['#aaaaaa' if c is None else c for c in colors]\n",
"\n",
" # Only plot markers which are different to the observed storm regime. \n",
" # This makes it easier to find where model predictions differ\n",
" y_coords = []\n",
" for obs_impact, for_impact in zip(df_model.storm_regime,\n",
" df_obs_impacts.storm_regime):\n",
" if obs_impact == for_impact:\n",
" y_coords.append(None)\n",
" else:\n",
" y_coords.append(i + 1)\n",
"\n",
" ax1.scatter(y_coords, n, color=colors, **marker_style)\n",
"\n",
" # Add model names to each impact on x axis\n",
" ax1.set_xticks(range(len(impacts['forecasted']) + 1))\n",
" ax1.set_xticklabels(['observed'] +\n",
" [x.replace('_', '\\_') for x in impacts['forecasted']])\n",
" ax1.xaxis.set_tick_params(rotation=90)\n",
"\n",
" # Add title\n",
" ax1.set_title('Storm regime')\n",
"\n",
" # Create custom legend\n",
" legend_elements = [\n",
" Line2D([0], [0],\n",
" marker='o',\n",
" color='w',\n",
" label='Swash',\n",
" markerfacecolor='#1a9850',\n",
" markersize=8,\n",
" markeredgewidth=1.0,\n",
" markeredgecolor='k'),\n",
" Line2D([0], [0],\n",
" marker='o',\n",
" color='w',\n",
" label='Collision',\n",
" markerfacecolor='#fee08b',\n",
" markersize=8,\n",
" markeredgewidth=1.0,\n",
" markeredgecolor='k'),\n",
" Line2D([0], [0],\n",
" marker='o',\n",
" color='w',\n",
" label='Overwash',\n",
" markerfacecolor='#d73027',\n",
" markersize=8,\n",
" markeredgewidth=1.0,\n",
" markeredgecolor='k'),\n",
" ]\n",
" ax1.legend(\n",
" handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 1.1))\n",
"\n",
" # Replace axis ticks with names of site ids\n",
" ytick_labels = ax1.get_yticks().tolist()\n",
" yticks = [\n",
" n_sites[int(y)] if all([y >= 0, y < len(n_sites)]) else ''\n",
" for y in ytick_labels\n",
" ]\n",
" yticks = [x.replace('_', '\\_') for x in yticks]\n",
" ax1.set_yticklabels(yticks)\n",
"\n",
" # ax2: elevations\n",
"\n",
" # Dune elevations\n",
" df_feats = df_profile_features_crest_toes.xs(['prestorm'],\n",
" level=['profile_type'])\n",
" df_feats = df_feats.loc[df_feats.index.str.contains(beach)]\n",
"\n",
" ax2.plot(df_feats.dune_crest_z, n, color='#fdae61')\n",
" ax2.plot(df_feats.dune_toe_z, n, color='#fdae61')\n",
" ax2.fill_betweenx(\n",
" n,\n",
" df_feats.dune_toe_z,\n",
" df_feats.dune_crest_z,\n",
" alpha=0.2,\n",
" color='#fdae61',\n",
" label='$D_{low}$ to $D_{high}$')\n",
"\n",
" model_colors = [\n",
" '#1f78b4',\n",
" '#33a02c',\n",
" '#e31a1c',\n",
" '#6a3d9a',\n",
" '#a6cee3',\n",
" '#b2df8a',\n",
" '#fb9a99',\n",
" '#cab2d6',\n",
" '#ffff99',\n",
" ]\n",
"\n",
" # Define colors to cycle through for our R_high\n",
" ax2.set_prop_cycle(cycler('color', model_colors))\n",
"\n",
" # For TWL elevations, Rhigh-Dlow and R2 axis, only plot a few models\n",
" models_to_plot = [\n",
" 'premean_slope_hol86',\n",
" 'premean_slope_sto06',\n",
" 'preintertidal_slope_hol86',\n",
" 'preintertidal_slope_sto06',\n",
" ]\n",
" models_linewidth = 0.8\n",
"\n",
" # Plot R_high values\n",
" for model in models_to_plot:\n",
"\n",
" # Only get model results for this beach\n",
" df_model = impacts['forecasted'][model].loc[\n",
" impacts['forecasted'][model].index.str.contains(beach)]\n",
"\n",
" # Recast storm regimes as categorical data\n",
" ax2.plot(\n",
" df_model.R_high,\n",
" n,\n",
" label=model.replace('_', '\\_'),\n",
" linewidth=models_linewidth)\n",
"\n",
" # Set title, legend and labels\n",
" ax2.set_title('TWL \\& dune\\nelevations')\n",
" ax2.legend(loc='lower center', bbox_to_anchor=(0.5, 1.1))\n",
" ax2.set_xlabel('Elevation (m AHD)')\n",
"# ax2.set_xlim([0, max(df_feats.dune_crest_z)])\n",
"\n",
" # ax3: Plot R_high - D_low\n",
"\n",
" # Define colors to cycle through for our R_high\n",
" ax3.set_prop_cycle(cycler('color', model_colors))\n",
"\n",
" # Plot R_high values\n",
" for model in models_to_plot:\n",
"\n",
" df_model = impacts['forecasted'][model].loc[\n",
" impacts['forecasted'][model].index.str.contains(beach)]\n",
" # R_high - D_low\n",
" ax3.plot(\n",
" df_model.R_high - df_feats.dune_toe_z,\n",
" n,\n",
" label=model.replace('_', '\\_'),\n",
" linewidth=models_linewidth)\n",
"\n",
" ax3.axvline(x=0, color='black', linestyle=':')\n",
" ax3.set_title('$R_{high}$ - $D_{low}$')\n",
" ax3.set_xlabel('Height (m)')\n",
"# ax3.set_xlim([-2, 2])\n",
"\n",
" # Define colors to cycle through for our R2\n",
" ax4.set_prop_cycle(cycler('color', model_colors))\n",
"\n",
" # R_high - D_low\n",
" for model in models_to_plot:\n",
" df_R2 = impacts['forecasted'][model].merge(\n",
" twls['forecasted'][model], on=['site_id', 'datetime'], how='left')\n",
" df_R2 = df_R2.loc[df_R2.index.str.contains(beach)]\n",
" ax4.plot(\n",
" df_R2.R2,\n",
" n,\n",
" label=model.replace('_', '\\_'),\n",
" linewidth=models_linewidth)\n",
"\n",
" ax4.set_title(r'$R_{2\\%}$')\n",
" ax4.set_xlabel('Height (m)')\n",
"# ax4.set_xlim([0, 10])\n",
"\n",
" # Beach slope\n",
" slope_colors = [\n",
" '#bebada',\n",
" '#bc80bd',\n",
" '#ffed6f',\n",
" '#fdb462',\n",
" ]\n",
" ax5.set_prop_cycle(cycler('color', slope_colors))\n",
" slope_models = {\n",
" 'prestorm mean': 'premean_slope_sto06',\n",
" 'poststorm mean': 'postmean_slope_sto06',\n",
" 'prestorm intertidal': 'preintertidal_slope_sto06',\n",
" 'poststorm intertidal': 'postintertidal_slope_sto06',\n",
" }\n",
"\n",
" for label in slope_models:\n",
" model = slope_models[label]\n",
" df_beta = impacts['forecasted'][model].merge(\n",
" twls['forecasted'][model], on=['site_id', 'datetime'], how='left')\n",
" df_beta = df_beta.loc[df_beta.index.str.contains(beach)]\n",
" ax5.plot(df_beta.beta, n, label=label, linewidth=models_linewidth)\n",
"\n",
" ax5.set_title(r'$\\beta$')\n",
" ax5.set_xlabel('Beach slope')\n",
" ax5.legend(loc='lower center', bbox_to_anchor=(0.5, 1.1))\n",
" # ax5.set_xlim([0, 0.15])\n",
"\n",
" # Need to chose a model to extract environmental parameters at maximum R_high time\n",
" model = 'premean_slope_sto06'\n",
" df_beach = impacts['forecasted'][model].merge(\n",
" twls['forecasted'][model], on=['site_id', 'datetime'], how='left')\n",
" df_beach = df_beach.loc[df_beach.index.str.contains(beach)]\n",
"\n",
" # Wave height, wave period\n",
" ax6.plot(df_beach.Hs0, n, color='#999999')\n",
" ax6.set_title('$H_{s0}$')\n",
" ax6.set_xlabel('Sig. wave height (m)')\n",
" ax6.set_xlim([2, 6])\n",
"\n",
" ax7.plot(df_beach.Tp, n, color='#999999')\n",
" ax7.set_title('$T_{p}$')\n",
" ax7.set_xlabel('Peak wave period (s)')\n",
" ax7.set_xlim([8, 14])\n",
"\n",
" ax8.plot(df_beach.tide, n, color='#999999')\n",
" ax8.set_title('Tide \\& surge')\n",
" ax8.set_xlabel('Elevation (m AHD)')\n",
" ax8.set_xlim([1, 3])\n",
"\n",
" \n",
" # TODO Cumulative wave energy\n",
" # df_sites_waves\n",
" \n",
" plt.tight_layout()\n",
" f.subplots_adjust(top=0.88)\n",
" f.suptitle(beach.replace('_', '\\_'))\n",
"\n",
" # Set minor axis ticks on each plot\n",
" ax1.yaxis.set_minor_locator(MultipleLocator(1))\n",
" ax1.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
" ax2.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
" ax3.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
" ax4.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
" ax5.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
" ax6.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
" ax7.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
" ax8.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
"\n",
" # # Print to figure\n",
" plt.savefig('07_{}.png'.format(beach), dpi=600, bbox_inches='tight')\n",
"\n",
" plt.show()\n",
" plt.close()\n",
" print('Done: {}'.format(beach))\n",
" \n",
" break\n",
"print('Done!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate classification reports for each model\n",
"Use sklearn metrics to generate classification reports for each forecasting model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sklearn.metrics\n",
"\n",
"# Get observed impacts\n",
"df_obs = impacts['observed']\n",
"\n",
"# Convert storm regimes to categorical datatype\n",
"cat_type = CategoricalDtype(\n",
" categories=['swash', 'collision', 'overwash', 'inundation'], ordered=True)\n",
"df_obs.storm_regime = df_obs.storm_regime.astype(cat_type)\n",
"\n",
"for model in impacts['forecasted']:\n",
" df_for = impacts['forecasted'][model]\n",
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
"\n",
" m = sklearn.metrics.classification_report(\n",
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
" df_for.storm_regime.astype(cat_type).cat.codes.values,\n",
" labels=[0, 1, 2, 3],\n",
" target_names=['swash', 'collision', 'overwash', 'inundation'])\n",
" print(model)\n",
" print(m)\n",
" print()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check matthews coefficient\n",
"# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html\n",
"\n",
"for model in impacts['forecasted']:\n",
" df_for = impacts['forecasted'][model]\n",
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
"\n",
" m = matthews_corrcoef(\n",
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
" df_for.storm_regime.astype(cat_type).cat.codes.values)\n",
" print('{}: {:.2f}'.format(model,m))"
]
},
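{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference: the Matthews correlation coefficient ranges from $-1$ (complete disagreement) through $0$ (no better than chance) to $+1$ (perfect agreement), and unlike plain accuracy it accounts for imbalance between the four storm regime classes. In the two-class case it reduces to\n",
"\n",
"$$MCC = \\frac{TP \\cdot TN - FP \\cdot FN}{\\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}}$$\n",
"\n",
"and `sklearn.metrics.matthews_corrcoef` applies the multiclass generalisation of this to the four storm regimes used here."
]
},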
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check accuracy\n",
"# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html\n",
"\n",
"for model in impacts['forecasted']:\n",
" df_for = impacts['forecasted'][model]\n",
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
"\n",
" m = sklearn.metrics.accuracy_score(\n",
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
" df_for.storm_regime.astype(cat_type).cat.codes.values)\n",
" print('{}: {:.2f}'.format(model,m))"
]
},
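{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compare the sixteen forecast variants at a glance, the accuracy and Matthews correlation coefficient calculated above can also be collected into a single table. A minimal sketch, reusing `df_obs`, `impacts` and `cat_type` from the cells above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: collect accuracy and MCC for every forecast variant into one dataframe\n",
"rows = []\n",
"for model in impacts['forecasted']:\n",
"    df_for = impacts['forecasted'][model]\n",
"    obs = df_obs.storm_regime.astype(cat_type).cat.codes.values\n",
"    pred = df_for.storm_regime.astype(cat_type).cat.codes.values\n",
"    rows.append({\n",
"        'model': model,\n",
"        'accuracy': sklearn.metrics.accuracy_score(obs, pred),\n",
"        'mcc': matthews_corrcoef(obs, pred)\n",
"    })\n",
"\n",
"df_scores = pd.DataFrame(rows).set_index('model').sort_values('mcc', ascending=False)\n",
"df_scores"
]
},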
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"# Check confusion matrix\n",
"for model in impacts['forecasted']:\n",
" df_for = impacts['forecasted'][model]\n",
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
"\n",
" m = sklearn.metrics.confusion_matrix(\n",
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
" df_for.storm_regime.astype(cat_type).cat.codes.values,\n",
" labels=[0,1,2,3])\n",
" print('{}\\n{}'.format(model,m))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create confusion matrix figure\n",
"From https://github.com/wcipriano/pretty-print-confusion-matrix/blob/master/confusion_matrix_pretty_print.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"# -*- coding: utf-8 -*-\n",
"\"\"\"\n",
"plot a pretty confusion matrix with seaborn\n",
"Created on Mon Jun 25 14:17:37 2018\n",
"@author: Wagner Cipriano - wagnerbhbr - gmail - CEFETMG / MMC\n",
"REFerences:\n",
" https://www.mathworks.com/help/nnet/ref/plotconfusion.html\n",
" https://stackoverflow.com/questions/28200786/how-to-plot-scikit-learn-classification-report\n",
" https://stackoverflow.com/questions/5821125/how-to-plot-confusion-matrix-with-string-axis-rather-than-integer-in-python\n",
" https://www.programcreek.com/python/example/96197/seaborn.heatmap\n",
" https://stackoverflow.com/questions/19233771/sklearn-plot-confusion-matrix-with-labels/31720054\n",
" http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py\n",
"\"\"\"\n",
"\n",
"#imports\n",
"from pandas import DataFrame\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.font_manager as fm\n",
"from matplotlib.collections import QuadMesh\n",
"import seaborn as sn\n",
"\n",
"\n",
"def get_new_fig(fn, figsize=[9,9]):\n",
" \"\"\" Init graphics \"\"\"\n",
" fig1 = plt.figure(fn, figsize)\n",
" ax1 = fig1.gca() #Get Current Axis\n",
" ax1.cla() # clear existing plot\n",
" return fig1, ax1\n",
"#\n",
"\n",
"def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0, show_pcts=True):\n",
" \"\"\"\n",
" config cell text and colors\n",
" and return text elements to add and to dell\n",
" @TODO: use fmt\n",
" \"\"\"\n",
" text_add = []; text_del = [];\n",
" cell_val = array_df[lin][col]\n",
" tot_all = array_df[-1][-1]\n",
" per = (float(cell_val) / tot_all) * 100\n",
" curr_column = array_df[:,col]\n",
" ccl = len(curr_column)\n",
"\n",
" #last line and/or last column\n",
" if(col == (ccl - 1)) or (lin == (ccl - 1)):\n",
" #tots and percents\n",
" if(cell_val != 0):\n",
" if(col == ccl - 1) and (lin == ccl - 1):\n",
" tot_rig = 0\n",
" for i in range(array_df.shape[0] - 1):\n",
" tot_rig += array_df[i][i]\n",
" per_ok = (float(tot_rig) / cell_val) * 100\n",
" elif(col == ccl - 1):\n",
" tot_rig = array_df[lin][lin]\n",
" per_ok = (float(tot_rig) / cell_val) * 100\n",
" elif(lin == ccl - 1):\n",
" tot_rig = array_df[col][col]\n",
" per_ok = (float(tot_rig) / cell_val) * 100\n",
" per_err = 100 - per_ok\n",
" else:\n",
" per_ok = per_err = 0\n",
"\n",
" per_ok_s = ['(\\\\small{{{:.1f}\\%}})'.format(per_ok), '100%'] [per_ok == 100]\n",
" \n",
" #text to DEL\n",
" text_del.append(oText)\n",
"\n",
" #text to ADD\n",
" font_prop = fm.FontProperties(weight='bold', size=fz)\n",
" text_kwargs = dict(color='k', ha=\"center\", va=\"center\", gid='sum', fontproperties=font_prop)\n",
" lis_txt = ['%d'%(cell_val), per_ok_s, '(\\\\small{{{:.1f}\\%}})'.format(per_err)]\n",
" lis_kwa = [text_kwargs]\n",
" dic = text_kwargs.copy(); dic['color'] = 'g'; lis_kwa.append(dic);\n",
" dic = text_kwargs.copy(); dic['color'] = 'r'; lis_kwa.append(dic);\n",
" \n",
" if show_pcts:\n",
" lis_pos = [(oText._x, oText._y-0.3), (oText._x, oText._y), (oText._x, oText._y+0.3)]\n",
" else:\n",
" lis_pos = [(oText._x, oText._y)]\n",
" \n",
" for i in range(len(lis_pos)):\n",
" newText = dict(x=lis_pos[i][0], y=lis_pos[i][1], text=lis_txt[i], kw=lis_kwa[i])\n",
" text_add.append(newText)\n",
" \n",
"\n",
" #set background color for sum cells (last line and last column)\n",
" carr = [0.9, 0.9, 0.9, 1.0]\n",
" if(col == ccl - 1) and (lin == ccl - 1):\n",
" carr = [0.9, 0.9, 0.9, 1.0]\n",
" facecolors[posi] = carr\n",
"\n",
" else:\n",
" if(per > 0):\n",
"# txt = '%s\\n%.1f\\%%' %(cell_val, per)\n",
" \n",
" if show_pcts:\n",
" txt = '{}\\n\\\\small{{({:.1f}\\%)}}'.format(cell_val,per)\n",
" else:\n",
" txt = '{}'.format(cell_val)\n",
" else:\n",
" if(show_null_values == 0):\n",
" txt = ''\n",
" elif(show_null_values == 1):\n",
" txt = '0'\n",
" else:\n",
" if show_pcts:\n",
" txt = '0\\n0.0\\%'\n",
" else:\n",
" txt = '0'\n",
" oText.set_text(txt)\n",
"\n",
" #main diagonal\n",
" if(col == lin):\n",
" #set color of the textin the diagonal to white\n",
"# oText.set_color('w')\n",
" oText.set_color('k')\n",
" # set background color in the diagonal to green\n",
" facecolors[posi] = [0.35, 0.8, 0.55, 1.0]\n",
" else:\n",
" oText.set_color('r')\n",
"\n",
" return text_add, text_del\n",
"#\n",
"\n",
"def insert_totals(df_cm):\n",
" \"\"\" insert total column and line (the last ones) \"\"\"\n",
" sum_col = []\n",
" for c in df_cm.columns:\n",
" sum_col.append( df_cm[c].sum() )\n",
" sum_lin = []\n",
" for item_line in df_cm.iterrows():\n",
" sum_lin.append( item_line[1].sum() )\n",
" df_cm[r'$\\sum$ Row'] = sum_lin\n",
" sum_col.append(np.sum(sum_lin))\n",
" df_cm.loc[r'$\\sum$ Col'] = sum_col\n",
"\n",
"\n",
"def pretty_plot_confusion_matrix(df_cm, annot=True, cmap=\"Oranges\", fmt='.2f', fz=11,\n",
" lw=0.5, cbar=False, figsize=[8,8], show_null_values=0, pred_val_axis='y',title='Confusion matrix',show_pcts=True):\n",
" \"\"\"\n",
" print conf matrix with default layout (like matlab)\n",
" params:\n",
" df_cm dataframe (pandas) without totals\n",
" annot print text in each cell\n",
" cmap Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:\n",
" fz fontsize\n",
" lw linewidth\n",
" pred_val_axis where to show the prediction values (x or y axis)\n",
" 'col' or 'x': show predicted values in columns (x axis) instead lines\n",
" 'lin' or 'y': show predicted values in lines (y axis)\n",
" \"\"\"\n",
" if(pred_val_axis in ('col', 'x')):\n",
" xlbl = 'Hindcast'\n",
" ylbl = 'Observed'\n",
" else:\n",
" xlbl = 'Observed'\n",
" ylbl = 'Hindcast'\n",
" df_cm = df_cm.T\n",
"\n",
" # create \"Total\" column\n",
" insert_totals(df_cm)\n",
"\n",
" #this is for print allways in the same window\n",
" fig, ax1 = get_new_fig('Conf matrix default', figsize)\n",
"\n",
" #thanks for seaborn\n",
" ax = sn.heatmap(df_cm, annot=annot, annot_kws={\"size\": fz}, linewidths=lw, ax=ax1,\n",
" cbar=cbar, cmap=cmap, linecolor='w', fmt=fmt)\n",
"\n",
" #set ticklabels rotation\n",
"# ax.set_xticklabels(ax.get_xticklabels(),fontsize = tick_label_fontsize)\n",
" ax.set_yticklabels(ax.get_yticklabels(), horizontalalignment='right',verticalalignment='center')\n",
"\n",
" # Turn off all the ticks\n",
" for t in ax.xaxis.get_major_ticks():\n",
" t.tick1On = False\n",
" t.tick2On = False\n",
" for t in ax.yaxis.get_major_ticks():\n",
" t.tick1On = False\n",
" t.tick2On = False\n",
"\n",
" #face colors list\n",
" quadmesh = ax.findobj(QuadMesh)[0]\n",
" facecolors = quadmesh.get_facecolors()\n",
"\n",
" #iter in text elements\n",
" array_df = np.array( df_cm.to_records(index=False).tolist() )\n",
" text_add = []; text_del = [];\n",
" posi = -1 #from left to right, bottom to top.\n",
" for t in ax.collections[0].axes.texts: #ax.texts:\n",
" pos = np.array( t.get_position()) - [0.5,0.5]\n",
" lin = int(pos[1]); col = int(pos[0]);\n",
" posi += 1\n",
" #print ('>>> pos: %s, posi: %s, val: %s, txt: %s' %(pos, posi, array_df[lin][col], t.get_text()))\n",
"\n",
" #set text\n",
" txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values, show_pcts)\n",
"\n",
" text_add.extend(txt_res[0])\n",
" text_del.extend(txt_res[1])\n",
"\n",
" #remove the old ones\n",
" for item in text_del:\n",
" item.remove()\n",
" #append the new ones\n",
" for item in text_add:\n",
" ax.text(item['x'], item['y'], item['text'], **item['kw'])\n",
" \n",
" #titles and legends\n",
" ax.set_title(title)\n",
" ax.set_ylabel(r'{} storm regimes'.format(ylbl),fontsize=12)\n",
" ax.set_xlabel(r'{} storm regimes'.format(xlbl),fontsize=12)\n",
"\n",
"# ax.xaxis.set_label_coords(0.5, -0.1)\n",
"# ax.yaxis.set_label_coords(-0.05, 0.5)\n",
" \n",
"# plt.tight_layout() #set layout slim\n",
" \n",
" return fig\n",
"#\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Plot for Coasts & Ports\n",
"\n",
"matplotlib.rcParams['text.usetex'] = True\n",
"\n",
"forecast_model = 'premean_slope_sto06'\n",
"\n",
"df_for = impacts['forecasted'][forecast_model]\n",
"df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
"observed_regimes = df_obs.storm_regime.astype(cat_type).cat.codes.values\n",
"forecasted_regimes = df_for.storm_regime.astype(cat_type).cat.codes.values\n",
"\n",
"\n",
"confm = confusion_matrix(observed_regimes, forecasted_regimes,labels=[0,1,2,3])\n",
"labels=['Swash','Collision','Overwash','Inundation']\n",
"df_cm = DataFrame(confm, index=labels, columns=labels)\n",
"\n",
"fig = pretty_plot_confusion_matrix(df_cm, annot=True, cmap=\"OrRd\", fmt='.1f', fz=8,\n",
" lw=1, cbar=False, figsize=[3,1.5], show_null_values=1, pred_val_axis='y', \n",
" title = r'Storm impact regime ' + r'confusion matrix',show_pcts=False)\n",
"\n",
"ax = fig.axes[0]\n",
"ax.set_ylabel(ax.get_ylabel(),fontsize=8)\n",
"ax.set_xlabel(ax.get_xlabel(),fontsize=8)\n",
"ax.set_title(ax.get_title(),fontsize=8)\n",
"ax.set_yticklabels(ax.get_yticklabels(), fontsize=8)\n",
"ax.set_xticklabels(ax.get_xticklabels(), fontsize=8)\n",
"\n",
"fig.savefig('07_c&p_confusion_matrix',dpi=600,bbox_inches = \"tight\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pd.options.display.max_columns = 500\n",
"print(\"Couldn't get forecast for: {} sites\".format(df_for.storm_regime.isna().sum()))\n",
"print(\"Couldn't get observations for: {} sites\".format(df_obs.storm_regime.isna().sum()))\n",
"\n",
"# df_obs[df_obs.storm_regime.isna()]"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "286.391px"
},
"toc_section_display": true,
"toc_window_display": true
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}