{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
"# Evaluate prediction metrics\n",
"- This notebook checks how each of our storm impact prediction models performed in comparison to our observed storm impacts."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Setup notebook\n",
|
|
"Import our required packages and set default plotting options."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Enable autoreloading of our modules. \n",
|
|
"# Most of the code will be located in the /src/ folder, \n",
|
|
"# and then called from the notebook.\n",
|
|
"%matplotlib inline\n",
|
|
"%reload_ext autoreload\n",
|
|
"%autoreload"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from IPython.core.debugger import set_trace\n",
|
|
"\n",
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import os\n",
|
|
"import decimal\n",
|
|
"import plotly\n",
|
|
"import plotly.graph_objs as go\n",
|
|
"import plotly.plotly as py\n",
|
|
"import plotly.tools as tls\n",
|
|
"import plotly.figure_factory as ff\n",
|
|
"from plotly import tools\n",
|
|
"import plotly.io as pio\n",
|
|
"from scipy import stats\n",
|
|
"import math\n",
|
|
"import matplotlib\n",
|
|
"from matplotlib import cm\n",
|
|
"import colorlover as cl\n",
|
|
"from tqdm import tqdm_notebook\n",
|
|
"from ipywidgets import widgets, Output\n",
|
|
"from IPython.display import display, clear_output, Image, HTML\n",
|
|
"from sklearn.metrics import confusion_matrix, matthews_corrcoef\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from matplotlib.ticker import MultipleLocator\n",
|
|
"from matplotlib.lines import Line2D\n",
|
|
"from cycler import cycler\n",
|
|
"from scipy.interpolate import interp1d\n",
|
|
"from pandas.api.types import CategoricalDtype"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Matplot lib default settings\n",
|
|
"plt.rcParams[\"figure.figsize\"] = (10,6)\n",
"plt.rcParams['axes.grid'] = True\n",
"plt.rcParams['grid.alpha'] = 0.5\n",
"plt.rcParams['grid.color'] = \"grey\"\n",
"plt.rcParams['grid.linestyle'] = \"--\"\n",
|
|
"\n",
|
|
"# https://stackoverflow.com/a/20709149\n",
|
|
"matplotlib.rcParams['text.usetex'] = True\n",
|
|
"\n",
|
|
"matplotlib.rcParams['text.latex.preamble'] = [\n",
|
|
" r'\\usepackage{siunitx}', # i need upright \\micro symbols, but you need...\n",
|
|
" r'\\sisetup{detect-all}', # ...this to force siunitx to actually use your fonts\n",
|
|
" r'\\usepackage{helvet}', # set the normal font here\n",
|
|
" r'\\usepackage{amsmath}',\n",
|
|
" r'\\usepackage{sansmath}', # load up the sansmath so that math -> helvet\n",
|
|
" r'\\sansmath', # <- tricky! -- gotta actually tell tex to use!\n",
|
|
"] "
|
|
]
|
|
},
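{
"cell_type": "markdown",
"metadata": {},
"source": [
"The list form of `text.latex.preamble` above matches the older matplotlib this notebook was written against. If the notebook is run with a newer matplotlib (roughly 3.3 onwards, an assumption worth checking against your installed version), the preamble is expected as a single string instead; the cell below is a hedged sketch of that variant."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged alternative, assuming a newer matplotlib that expects\n",
"# 'text.latex.preamble' as a single string rather than a list.\n",
"matplotlib.rcParams['text.latex.preamble'] = '\\n'.join([\n",
"    r'\\usepackage{siunitx}',\n",
"    r'\\sisetup{detect-all}',\n",
"    r'\\usepackage{helvet}',\n",
"    r'\\usepackage{amsmath}',\n",
"    r'\\usepackage{sansmath}',\n",
"    r'\\sansmath',\n",
"])"
]
},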
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Import data\n",
|
|
"Import our data from the `./data/interim/` folder and load it into pandas dataframes. "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def df_from_csv(csv, index_col, data_folder='../data/interim'):\n",
|
|
" print('Importing {}'.format(csv))\n",
|
|
" return pd.read_csv(os.path.join(data_folder,csv), index_col=index_col)\n",
|
|
"\n",
|
|
"df_waves = df_from_csv('waves.csv', index_col=[0, 1])\n",
|
|
"df_tides = df_from_csv('tides.csv', index_col=[0, 1])\n",
|
|
"df_profiles = df_from_csv('profiles.csv', index_col=[0, 1, 2])\n",
|
|
"df_sites = df_from_csv('sites.csv', index_col=[0])\n",
|
|
"df_sites_waves = df_from_csv('sites_waves.csv', index_col=[0])\n",
|
|
"df_profile_features_crest_toes = df_from_csv('profile_features_crest_toes.csv', index_col=[0,1])\n",
|
|
"\n",
"# Note that the forecasted datasets should be in the same order for impacts and twls\n",
"# (see the index alignment check after this cell).\n",
|
|
"impacts = {\n",
|
|
" 'forecasted': {\n",
|
|
" 'postintertidal_slope_hol86': df_from_csv('impacts_forecasted_postintertidal_slope_hol86.csv', index_col=[0]),\n",
|
|
" 'postintertidal_slope_nie91': df_from_csv('impacts_forecasted_postintertidal_slope_nie91.csv', index_col=[0]),\n",
|
|
" 'postintertidal_slope_pow18': df_from_csv('impacts_forecasted_postintertidal_slope_pow18.csv', index_col=[0]),\n",
|
|
" 'postintertidal_slope_sto06': df_from_csv('impacts_forecasted_postintertidal_slope_sto06.csv', index_col=[0]),\n",
|
|
" 'postmean_slope_hol86': df_from_csv('impacts_forecasted_postmean_slope_hol86.csv', index_col=[0]),\n",
|
|
" 'postmean_slope_nie91': df_from_csv('impacts_forecasted_postmean_slope_nie91.csv', index_col=[0]),\n",
|
|
" 'postmean_slope_pow18': df_from_csv('impacts_forecasted_postmean_slope_pow18.csv', index_col=[0]),\n",
|
|
" 'postmean_slope_sto06': df_from_csv('impacts_forecasted_postmean_slope_sto06.csv', index_col=[0]),\n",
|
|
" 'preintertidal_slope_hol86': df_from_csv('impacts_forecasted_preintertidal_slope_hol86.csv', index_col=[0]),\n",
|
|
" 'preintertidal_slope_nie91': df_from_csv('impacts_forecasted_preintertidal_slope_nie91.csv', index_col=[0]),\n",
|
|
" 'preintertidal_slope_pow18': df_from_csv('impacts_forecasted_preintertidal_slope_pow18.csv', index_col=[0]),\n",
|
|
" 'preintertidal_slope_sto06': df_from_csv('impacts_forecasted_preintertidal_slope_sto06.csv', index_col=[0]),\n",
|
|
" 'premean_slope_hol86': df_from_csv('impacts_forecasted_premean_slope_hol86.csv', index_col=[0]),\n",
|
|
" 'premean_slope_nie91': df_from_csv('impacts_forecasted_premean_slope_nie91.csv', index_col=[0]),\n",
|
|
" 'premean_slope_pow18': df_from_csv('impacts_forecasted_premean_slope_pow18.csv', index_col=[0]),\n",
|
|
" 'premean_slope_sto06': df_from_csv('impacts_forecasted_premean_slope_sto06.csv', index_col=[0]),\n",
|
|
" },\n",
|
|
" 'observed': df_from_csv('impacts_observed.csv', index_col=[0])\n",
|
|
" }\n",
|
|
"\n",
|
|
"\n",
|
|
"twls = {\n",
|
|
" 'forecasted': {\n",
|
|
" 'postintertidal_slope_hol86': df_from_csv('twl_postintertidal_slope_hol86.csv', index_col=[0,1]),\n",
|
|
" 'postintertidal_slope_nie91': df_from_csv('twl_postintertidal_slope_nie91.csv', index_col=[0,1]),\n",
|
|
" 'postintertidal_slope_pow18': df_from_csv('twl_postintertidal_slope_pow18.csv', index_col=[0,1]),\n",
|
|
" 'postintertidal_slope_sto06': df_from_csv('twl_postintertidal_slope_sto06.csv', index_col=[0,1]),\n",
|
|
" 'postmean_slope_hol86': df_from_csv('twl_postmean_slope_hol86.csv', index_col=[0,1]),\n",
|
|
" 'postmean_slope_nie91': df_from_csv('twl_postmean_slope_nie91.csv', index_col=[0,1]),\n",
|
|
" 'postmean_slope_pow18': df_from_csv('twl_postmean_slope_pow18.csv', index_col=[0,1]),\n",
|
|
" 'postmean_slope_sto06': df_from_csv('twl_postmean_slope_sto06.csv', index_col=[0,1]),\n",
|
|
" 'preintertidal_slope_hol86': df_from_csv('twl_preintertidal_slope_hol86.csv', index_col=[0,1]),\n",
|
|
" 'preintertidal_slope_nie91': df_from_csv('twl_preintertidal_slope_nie91.csv', index_col=[0,1]),\n",
|
|
" 'preintertidal_slope_pow18': df_from_csv('twl_preintertidal_slope_pow18.csv', index_col=[0,1]),\n",
|
|
" 'preintertidal_slope_sto06': df_from_csv('twl_preintertidal_slope_sto06.csv', index_col=[0,1]),\n",
|
|
" 'premean_slope_hol86': df_from_csv('twl_premean_slope_hol86.csv', index_col=[0,1]),\n",
|
|
" 'premean_slope_nie91': df_from_csv('twl_premean_slope_nie91.csv', index_col=[0,1]),\n",
|
|
" 'premean_slope_pow18': df_from_csv('twl_premean_slope_pow18.csv', index_col=[0,1]),\n",
|
|
" 'premean_slope_sto06': df_from_csv('twl_premean_slope_sto06.csv', index_col=[0,1]),\n",
|
|
" }\n",
|
|
"}\n",
|
|
"print('Done!')"
|
|
]
|
|
},
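{
"cell_type": "markdown",
"metadata": {},
"source": [
"The comparisons later in this notebook line up observed and forecasted impacts element-wise, so every forecasted dataframe needs the same `site_id` ordering as the observed impacts. The cell below is a minimal sanity check of that assumption using `pandas.Index.equals`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity check (sketch): confirm each forecasted impacts dataframe uses\n",
"# the same site_id index, in the same order, as the observed impacts.\n",
"for model, df_for in impacts['forecasted'].items():\n",
"    if not impacts['observed'].index.equals(df_for.index):\n",
"        print('WARNING: site_id ordering differs for {}'.format(model))\n",
"print('Index alignment check complete')"
]
},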
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Generate longshore plots for each beach"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"code_folding": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Each site_id is a beach name plus a trailing four-character site number;\n",
"# strip the suffix to get the unique beach names.\n",
"beaches = list(\n",
|
|
" set([\n",
|
|
" x[:-4] for x in df_profiles.index.get_level_values('site_id').unique()\n",
|
|
" ]))\n",
|
|
"\n",
|
|
"for beach in beaches:\n",
|
|
" \n",
|
|
" df_obs_impacts = impacts['observed'].loc[impacts['observed'].index.str.\n",
|
|
" contains(beach)]\n",
|
|
"\n",
|
|
" # Get index for each site on the beach\n",
"    n = list(range(len(df_obs_impacts)))[::-1]\n",
"    n_sites = list(df_obs_impacts.index)[::-1]\n",
|
|
"\n",
|
|
" # Convert storm regimes to categorical datatype\n",
|
|
" cat_type = CategoricalDtype(\n",
|
|
" categories=['swash', 'collision', 'overwash', 'inundation'],\n",
|
|
" ordered=True)\n",
|
|
" df_obs_impacts.storm_regime = df_obs_impacts.storm_regime.astype(cat_type)\n",
|
|
"\n",
|
|
" # Create figure\n",
|
|
" \n",
|
|
" # Determine the height of the figure, based on the number of sites.\n",
|
|
" fig_height = max(6, 0.18 * len(n_sites))\n",
|
|
" f, (ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9) = plt.subplots(\n",
|
|
" 1,\n",
|
|
" 9,\n",
|
|
" sharey=True,\n",
|
|
" figsize=(18, fig_height),\n",
"        gridspec_kw={'width_ratios': [4, 4, 2, 2, 2, 2, 2, 2, 2]})\n",
|
|
"\n",
|
|
" # ax1: Impacts\n",
|
|
"\n",
|
|
" # Define colors for storm regime\n",
|
|
" cmap = {'swash': '#1a9850', 'collision': '#fee08b', 'overwash': '#d73027'}\n",
|
|
"\n",
|
|
" # Common marker style\n",
|
|
" marker_style = {\n",
|
|
" 's': 60,\n",
|
|
" 'linewidths': 0.7,\n",
|
|
" 'alpha': 1,\n",
|
|
" 'edgecolors': 'k',\n",
|
|
" 'marker': 'o',\n",
|
|
" }\n",
|
|
"\n",
|
|
" # Plot observed impacts\n",
|
|
" colors = [cmap.get(x) for x in df_obs_impacts.storm_regime]\n",
|
|
" colors = ['#aaaaaa' if c is None else c for c in colors]\n",
|
|
" ax1.scatter([0 for x in n], n, color=colors, **marker_style)\n",
|
|
"\n",
|
|
" # Plot model impacts\n",
|
|
" for i, model in enumerate(impacts['forecasted']):\n",
|
|
"\n",
|
|
" # Only get model results for this beach\n",
|
|
" df_model = impacts['forecasted'][model].loc[\n",
|
|
" impacts['forecasted'][model].index.str.contains(beach)]\n",
|
|
"\n",
|
|
" # Recast storm regimes as categorical data\n",
|
|
" df_model.storm_regime = df_model.storm_regime.astype(cat_type)\n",
|
|
"\n",
|
|
" # Assign colors\n",
|
|
" colors = [cmap.get(x) for x in df_model.storm_regime]\n",
|
|
" colors = ['#aaaaaa' if c is None else c for c in colors]\n",
|
|
"\n",
|
|
" # Only plot markers which are different to the observed storm regime. \n",
|
|
" # This makes it easier to find where model predictions differ\n",
|
|
" y_coords = []\n",
"        for for_impact, obs_impact in zip(df_model.storm_regime,\n",
"                                           df_obs_impacts.storm_regime):\n",
"            if obs_impact == for_impact:\n",
|
|
" y_coords.append(None)\n",
|
|
" else:\n",
|
|
" y_coords.append(i + 1)\n",
|
|
"\n",
|
|
" ax1.scatter(y_coords, n, color=colors, **marker_style)\n",
|
|
"\n",
|
|
" # Add model names to each impact on x axis\n",
|
|
" ax1.set_xticks(range(len(impacts['forecasted']) + 1))\n",
|
|
" ax1.set_xticklabels(['observed'] +\n",
|
|
" [x.replace('_', '\\_') for x in impacts['forecasted']])\n",
|
|
" ax1.xaxis.set_tick_params(rotation=90)\n",
|
|
"\n",
|
|
" # Add title\n",
|
|
" ax1.set_title('Storm regime')\n",
|
|
"\n",
|
|
" # Create custom legend\n",
|
|
" legend_elements = [\n",
|
|
" Line2D([0], [0],\n",
|
|
" marker='o',\n",
|
|
" color='w',\n",
|
|
" label='Swash',\n",
|
|
" markerfacecolor='#1a9850',\n",
|
|
" markersize=8,\n",
|
|
" markeredgewidth=1.0,\n",
|
|
" markeredgecolor='k'),\n",
|
|
" Line2D([0], [0],\n",
|
|
" marker='o',\n",
|
|
" color='w',\n",
|
|
" label='Collision',\n",
|
|
" markerfacecolor='#fee08b',\n",
|
|
" markersize=8,\n",
|
|
" markeredgewidth=1.0,\n",
|
|
" markeredgecolor='k'),\n",
|
|
" Line2D([0], [0],\n",
|
|
" marker='o',\n",
|
|
" color='w',\n",
|
|
" label='Overwash',\n",
|
|
" markerfacecolor='#d73027',\n",
|
|
" markersize=8,\n",
|
|
" markeredgewidth=1.0,\n",
|
|
" markeredgecolor='k'),\n",
|
|
" ]\n",
|
|
" ax1.legend(\n",
|
|
" handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 1.1))\n",
|
|
"\n",
|
|
" # Replace axis ticks with names of site ids\n",
|
|
" ytick_labels = ax1.get_yticks().tolist()\n",
|
|
" yticks = [\n",
"        n_sites[int(y)] if 0 <= y < len(n_sites) else ''\n",
|
|
" for y in ytick_labels\n",
|
|
" ]\n",
|
|
" yticks = [x.replace('_', '\\_') for x in yticks]\n",
|
|
" ax1.set_yticklabels(yticks)\n",
|
|
"\n",
|
|
" # ax2: elevations\n",
|
|
"\n",
|
|
" # Dune elevations\n",
|
|
" df_feats = df_profile_features_crest_toes.xs(['prestorm'],\n",
|
|
" level=['profile_type'])\n",
|
|
" df_feats = df_feats.loc[df_feats.index.str.contains(beach)]\n",
|
|
"\n",
|
|
" ax2.plot(df_feats.dune_crest_z, n, color='#fdae61')\n",
|
|
" ax2.plot(df_feats.dune_toe_z, n, color='#fdae61')\n",
|
|
" ax2.fill_betweenx(\n",
|
|
" n,\n",
|
|
" df_feats.dune_toe_z,\n",
|
|
" df_feats.dune_crest_z,\n",
|
|
" alpha=0.2,\n",
|
|
" color='#fdae61',\n",
|
|
" label='$D_{low}$ to $D_{high}$')\n",
|
|
"\n",
|
|
" model_colors = [\n",
|
|
" '#1f78b4',\n",
|
|
" '#33a02c',\n",
|
|
" '#e31a1c',\n",
|
|
" '#6a3d9a',\n",
|
|
" '#a6cee3',\n",
|
|
" '#b2df8a',\n",
|
|
" '#fb9a99',\n",
|
|
" '#cab2d6',\n",
|
|
" '#ffff99',\n",
|
|
" ]\n",
|
|
"\n",
|
|
" # Define colors to cycle through for our R_high\n",
|
|
" ax2.set_prop_cycle(cycler('color', model_colors))\n",
|
|
"\n",
|
|
" # For TWL elevations, Rhigh-Dlow and R2 axis, only plot a few models\n",
|
|
" models_to_plot = [\n",
|
|
" 'premean_slope_hol86',\n",
|
|
" 'premean_slope_sto06',\n",
|
|
" 'preintertidal_slope_hol86',\n",
|
|
" 'preintertidal_slope_sto06',\n",
|
|
" ]\n",
|
|
" models_linewidth = 0.8\n",
|
|
"\n",
|
|
" # Plot R_high values\n",
|
|
" for model in models_to_plot:\n",
|
|
"\n",
|
|
" # Only get model results for this beach\n",
|
|
" df_model = impacts['forecasted'][model].loc[\n",
|
|
" impacts['forecasted'][model].index.str.contains(beach)]\n",
|
|
"\n",
|
|
" # Recast storm regimes as categorical data\n",
|
|
" ax2.plot(\n",
|
|
" df_model.R_high,\n",
|
|
" n,\n",
|
|
" label=model.replace('_', '\\_'),\n",
|
|
" linewidth=models_linewidth)\n",
|
|
"\n",
|
|
" # Set title, legend and labels\n",
|
|
" ax2.set_title('TWL \\& dune\\nelevations')\n",
|
|
" ax2.legend(loc='lower center', bbox_to_anchor=(0.5, 1.1))\n",
|
|
" ax2.set_xlabel('Elevation (m AHD)')\n",
|
|
"# ax2.set_xlim([0, max(df_feats.dune_crest_z)])\n",
|
|
"\n",
|
|
" # ax3: Plot R_high - D_low\n",
|
|
"\n",
|
|
" # Define colors to cycle through for our R_high\n",
|
|
" ax3.set_prop_cycle(cycler('color', model_colors))\n",
|
|
"\n",
|
|
" # Plot R_high values\n",
|
|
" for model in models_to_plot:\n",
|
|
"\n",
|
|
" df_model = impacts['forecasted'][model].loc[\n",
|
|
" impacts['forecasted'][model].index.str.contains(beach)]\n",
|
|
" # R_high - D_low\n",
|
|
" ax3.plot(\n",
|
|
" df_model.R_high - df_feats.dune_toe_z,\n",
|
|
" n,\n",
|
|
" label=model.replace('_', '\\_'),\n",
|
|
" linewidth=models_linewidth)\n",
|
|
"\n",
|
|
" ax3.axvline(x=0, color='black', linestyle=':')\n",
|
|
" ax3.set_title('$R_{high}$ - $D_{low}$')\n",
|
|
" ax3.set_xlabel('Height (m)')\n",
|
|
"# ax3.set_xlim([-2, 2])\n",
|
|
"\n",
|
|
" # Define colors to cycle through for our R2\n",
|
|
" ax4.set_prop_cycle(cycler('color', model_colors))\n",
|
|
"\n",
|
|
" # R_high - D_low\n",
|
|
" for model in models_to_plot:\n",
|
|
" df_R2 = impacts['forecasted'][model].merge(\n",
|
|
" twls['forecasted'][model], on=['site_id', 'datetime'], how='left')\n",
|
|
" df_R2 = df_R2.loc[df_R2.index.str.contains(beach)]\n",
|
|
" ax4.plot(\n",
|
|
" df_R2.R2,\n",
|
|
" n,\n",
|
|
" label=model.replace('_', '\\_'),\n",
|
|
" linewidth=models_linewidth)\n",
|
|
"\n",
|
|
" ax4.set_title(r'$R_{2\\%}$')\n",
|
|
" ax4.set_xlabel('Height (m)')\n",
|
|
"# ax4.set_xlim([0, 10])\n",
|
|
"\n",
|
|
" # Beach slope\n",
|
|
" slope_colors = [\n",
|
|
" '#bebada',\n",
|
|
" '#bc80bd',\n",
|
|
" '#ffed6f',\n",
|
|
" '#fdb462',\n",
|
|
" ]\n",
|
|
" ax5.set_prop_cycle(cycler('color', slope_colors))\n",
|
|
" slope_models = {\n",
|
|
" 'prestorm mean': 'premean_slope_sto06',\n",
|
|
" 'poststorm mean': 'postmean_slope_sto06',\n",
|
|
" 'prestorm intertidal': 'preintertidal_slope_sto06',\n",
|
|
" 'poststorm intertidal': 'postintertidal_slope_sto06',\n",
|
|
" }\n",
|
|
"\n",
|
|
" for label in slope_models:\n",
|
|
" model = slope_models[label]\n",
|
|
" df_beta = impacts['forecasted'][model].merge(\n",
|
|
" twls['forecasted'][model], on=['site_id', 'datetime'], how='left')\n",
|
|
" df_beta = df_beta.loc[df_beta.index.str.contains(beach)]\n",
|
|
" ax5.plot(df_beta.beta, n, label=label, linewidth=models_linewidth)\n",
|
|
"\n",
|
|
" ax5.set_title(r'$\\beta$')\n",
|
|
" ax5.set_xlabel('Beach slope')\n",
|
|
" ax5.legend(loc='lower center', bbox_to_anchor=(0.5, 1.1))\n",
|
|
" # ax5.set_xlim([0, 0.15])\n",
|
|
"\n",
"    # Need to choose a model to extract environmental parameters at the time of maximum R_high\n",
|
|
" model = 'premean_slope_sto06'\n",
|
|
" df_beach = impacts['forecasted'][model].merge(\n",
|
|
" twls['forecasted'][model], on=['site_id', 'datetime'], how='left')\n",
|
|
" df_beach = df_beach.loc[df_beach.index.str.contains(beach)]\n",
|
|
"\n",
|
|
" # Wave height, wave period\n",
|
|
" ax6.plot(df_beach.Hs0, n, color='#999999')\n",
|
|
" ax6.set_title('$H_{s0}$')\n",
|
|
" ax6.set_xlabel('Sig. wave height (m)')\n",
|
|
" ax6.set_xlim([2, 6])\n",
|
|
"\n",
|
|
" ax7.plot(df_beach.Tp, n, color='#999999')\n",
|
|
" ax7.set_title('$T_{p}$')\n",
|
|
" ax7.set_xlabel('Peak wave period (s)')\n",
|
|
" ax7.set_xlim([8, 14])\n",
|
|
"\n",
|
|
" ax8.plot(df_beach.tide, n, color='#999999')\n",
|
|
" ax8.set_title('Tide \\& surge')\n",
|
|
" ax8.set_xlabel('Elevation (m AHD)')\n",
|
|
" ax8.set_xlim([1, 3])\n",
|
|
"\n",
|
|
" \n",
|
|
" # TODO Cumulative wave energy\n",
|
|
" # df_sites_waves\n",
|
|
" \n",
|
|
" plt.tight_layout()\n",
|
|
" f.subplots_adjust(top=0.88)\n",
|
|
" f.suptitle(beach.replace('_', '\\_'))\n",
|
|
"\n",
|
|
" # Set minor axis ticks on each plot\n",
|
|
" ax1.yaxis.set_minor_locator(MultipleLocator(1))\n",
|
|
" ax1.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
" ax2.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
" ax3.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
" ax4.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
" ax5.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
" ax6.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
" ax7.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
" ax8.yaxis.grid(True, which='minor', linestyle='--', alpha=0.1)\n",
|
|
"\n",
"    # Save the figure for this beach to file\n",
|
|
" plt.savefig('07_{}.png'.format(beach), dpi=600, bbox_inches='tight')\n",
|
|
"\n",
|
|
" plt.show()\n",
|
|
" plt.close()\n",
|
|
" print('Done: {}'.format(beach))\n",
|
|
" \n",
"    # NOTE: this break stops the loop after the first beach. Remove it to\n",
"    # generate longshore plots for every beach.\n",
"    break\n",
|
|
"print('Done!')"
|
|
]
|
|
},
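{
"cell_type": "markdown",
"metadata": {},
"source": [
"The longshore plots above highlight individual sites where a forecast differs from the observed regime. As a quick numerical companion, the hedged sketch below tallies, per model, how many sites were predicted correctly, under-predicted or over-predicted, using the ordered categorical codes (swash < collision < overwash < inundation). Sites with a missing regime get code -1 and are counted naively here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: compare forecasted vs observed regime codes for every model.\n",
"# Assumes the observed and forecasted dataframes share the same site ordering.\n",
"cat_type = CategoricalDtype(\n",
"    categories=['swash', 'collision', 'overwash', 'inundation'], ordered=True)\n",
"obs_codes = impacts['observed'].storm_regime.astype(cat_type).cat.codes.values\n",
"\n",
"for model in impacts['forecasted']:\n",
"    for_codes = impacts['forecasted'][model].storm_regime.astype(cat_type).cat.codes.values\n",
"    diff = for_codes - obs_codes\n",
"    print('{}: correct={}, under-predicted={}, over-predicted={}'.format(\n",
"        model, (diff == 0).sum(), (diff < 0).sum(), (diff > 0).sum()))"
]
},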
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Generate classification reports for each model\n",
|
|
"Use sklearn metrics to generate classification reports for each forecasting model."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sklearn.metrics\n",
|
|
"\n",
|
|
"# Get observed impacts\n",
|
|
"df_obs = impacts['observed']\n",
|
|
"\n",
|
|
"# Convert storm regimes to categorical datatype\n",
|
|
"cat_type = CategoricalDtype(\n",
|
|
" categories=['swash', 'collision', 'overwash', 'inundation'], ordered=True)\n",
|
|
"df_obs.storm_regime = df_obs.storm_regime.astype(cat_type)\n",
|
|
"\n",
|
|
"for model in impacts['forecasted']:\n",
|
|
" df_for = impacts['forecasted'][model]\n",
|
|
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
|
|
"\n",
|
|
" m = sklearn.metrics.classification_report(\n",
|
|
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
|
|
" df_for.storm_regime.astype(cat_type).cat.codes.values,\n",
|
|
" labels=[0, 1, 2, 3],\n",
|
|
" target_names=['swash', 'collision', 'overwash', 'inundation'])\n",
|
|
" print(model)\n",
|
|
" print(m)\n",
|
|
" print()\n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Check the Matthews correlation coefficient\n",
"# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html\n",
|
|
"\n",
|
|
"for model in impacts['forecasted']:\n",
|
|
" df_for = impacts['forecasted'][model]\n",
|
|
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
|
|
"\n",
|
|
" m = matthews_corrcoef(\n",
|
|
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
|
|
" df_for.storm_regime.astype(cat_type).cat.codes.values)\n",
|
|
" print('{}: {:.2f}'.format(model,m))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Check accuracy\n",
"# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html\n",
|
|
"\n",
|
|
"for model in impacts['forecasted']:\n",
|
|
" df_for = impacts['forecasted'][model]\n",
|
|
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
|
|
"\n",
|
|
" m = sklearn.metrics.accuracy_score(\n",
|
|
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
|
|
" df_for.storm_regime.astype(cat_type).cat.codes.values)\n",
|
|
" print('{}: {:.2f}'.format(model,m))"
|
|
]
|
|
},
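{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells above print accuracy, the Matthews correlation coefficient and a classification report model by model. The hedged sketch below simply collects accuracy and the Matthews correlation coefficient into one dataframe (it assumes the cells above have been run, so `df_obs`, `cat_type`, `sklearn.metrics` and `matthews_corrcoef` are available), which makes it easier to rank the forecasting models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: one summary row per forecasting model.\n",
"rows = []\n",
"obs_codes = df_obs.storm_regime.astype(cat_type).cat.codes.values\n",
"\n",
"for model in impacts['forecasted']:\n",
"    for_codes = impacts['forecasted'][model].storm_regime.astype(cat_type).cat.codes.values\n",
"    rows.append({\n",
"        'model': model,\n",
"        'accuracy': sklearn.metrics.accuracy_score(obs_codes, for_codes),\n",
"        'mcc': matthews_corrcoef(obs_codes, for_codes),\n",
"    })\n",
"\n",
"df_metrics = pd.DataFrame(rows).set_index('model').sort_values('mcc', ascending=False)\n",
"df_metrics"
]
},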
|
|
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.metrics import confusion_matrix\n",
|
|
"# Check confusion matrix\n",
|
|
"for model in impacts['forecasted']:\n",
|
|
" df_for = impacts['forecasted'][model]\n",
|
|
" df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
|
|
"\n",
|
|
" m = sklearn.metrics.confusion_matrix(\n",
|
|
" df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
|
|
" df_for.storm_regime.astype(cat_type).cat.codes.values,\n",
|
|
" labels=[0,1,2,3])\n",
|
|
" print('{}\\n{}'.format(model,m))"
|
|
]
|
|
},
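{
"cell_type": "markdown",
"metadata": {},
"source": [
"Raw counts can be hard to compare between regimes with very different numbers of sites. The hedged sketch below row-normalises the confusion matrix of a single model (`premean_slope_sto06` is just an example choice) so that each row shows what fraction of the observed sites in that regime was assigned to each predicted regime. Regimes that were never observed divide by zero and come out as NaN rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: confusion matrix of one model expressed as row percentages.\n",
"model = 'premean_slope_sto06'  # example choice; any key of impacts['forecasted'] works\n",
"regime_labels = ['swash', 'collision', 'overwash', 'inundation']\n",
"\n",
"cm = sklearn.metrics.confusion_matrix(\n",
"    df_obs.storm_regime.astype(cat_type).cat.codes.values,\n",
"    impacts['forecasted'][model].storm_regime.astype(cat_type).cat.codes.values,\n",
"    labels=[0, 1, 2, 3])\n",
"\n",
"# Rows with no observed sites divide by zero and show as NaN.\n",
"cm_pct = cm / cm.sum(axis=1, keepdims=True) * 100\n",
"pd.DataFrame(cm_pct, index=regime_labels, columns=regime_labels).round(1)"
]
},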
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Create confusion matrix figure\n",
|
|
"From https://github.com/wcipriano/pretty-print-confusion-matrix/blob/master/confusion_matrix_pretty_print.py"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# -*- coding: utf-8 -*-\n",
|
|
"\"\"\"\n",
|
|
"plot a pretty confusion matrix with seaborn\n",
|
|
"Created on Mon Jun 25 14:17:37 2018\n",
|
|
"@author: Wagner Cipriano - wagnerbhbr - gmail - CEFETMG / MMC\n",
|
|
"REFerences:\n",
|
|
" https://www.mathworks.com/help/nnet/ref/plotconfusion.html\n",
|
|
" https://stackoverflow.com/questions/28200786/how-to-plot-scikit-learn-classification-report\n",
|
|
" https://stackoverflow.com/questions/5821125/how-to-plot-confusion-matrix-with-string-axis-rather-than-integer-in-python\n",
|
|
" https://www.programcreek.com/python/example/96197/seaborn.heatmap\n",
|
|
" https://stackoverflow.com/questions/19233771/sklearn-plot-confusion-matrix-with-labels/31720054\n",
|
|
" http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py\n",
|
|
"\"\"\"\n",
|
|
"\n",
|
|
"#imports\n",
|
|
"from pandas import DataFrame\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import matplotlib.font_manager as fm\n",
|
|
"from matplotlib.collections import QuadMesh\n",
|
|
"import seaborn as sn\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_new_fig(fn, figsize=[9,9]):\n",
|
|
" \"\"\" Init graphics \"\"\"\n",
|
|
" fig1 = plt.figure(fn, figsize)\n",
|
|
" ax1 = fig1.gca() #Get Current Axis\n",
|
|
" ax1.cla() # clear existing plot\n",
|
|
" return fig1, ax1\n",
|
|
"#\n",
|
|
"\n",
|
|
"def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0):\n",
|
|
" \"\"\"\n",
|
|
" config cell text and colors\n",
"    and return text elements to add and to delete\n",
|
|
" @TODO: use fmt\n",
|
|
" \"\"\"\n",
|
|
" text_add = []; text_del = [];\n",
|
|
" cell_val = array_df[lin][col]\n",
|
|
" tot_all = array_df[-1][-1]\n",
|
|
" per = (float(cell_val) / tot_all) * 100\n",
|
|
" curr_column = array_df[:,col]\n",
|
|
" ccl = len(curr_column)\n",
|
|
"\n",
|
|
" #last line and/or last column\n",
|
|
" if(col == (ccl - 1)) or (lin == (ccl - 1)):\n",
|
|
" #tots and percents\n",
|
|
" if(cell_val != 0):\n",
|
|
" if(col == ccl - 1) and (lin == ccl - 1):\n",
|
|
" tot_rig = 0\n",
|
|
" for i in range(array_df.shape[0] - 1):\n",
|
|
" tot_rig += array_df[i][i]\n",
|
|
" per_ok = (float(tot_rig) / cell_val) * 100\n",
|
|
" elif(col == ccl - 1):\n",
|
|
" tot_rig = array_df[lin][lin]\n",
|
|
" per_ok = (float(tot_rig) / cell_val) * 100\n",
|
|
" elif(lin == ccl - 1):\n",
|
|
" tot_rig = array_df[col][col]\n",
|
|
" per_ok = (float(tot_rig) / cell_val) * 100\n",
|
|
" per_err = 100 - per_ok\n",
|
|
" else:\n",
|
|
" per_ok = per_err = 0\n",
|
|
"\n",
|
|
" per_ok_s = ['%.1f%%'%(per_ok), '100%'] [per_ok == 100]\n",
|
|
"\n",
|
|
" #text to DEL\n",
|
|
" text_del.append(oText)\n",
|
|
"\n",
|
|
" #text to ADD\n",
|
|
" font_prop = fm.FontProperties(weight='bold', size=fz)\n",
|
|
" text_kwargs = dict(color='w', ha=\"center\", va=\"center\", gid='sum', fontproperties=font_prop)\n",
|
|
" lis_txt = ['%d'%(cell_val), per_ok_s, '%.1f%%'%(per_err)]\n",
|
|
" lis_kwa = [text_kwargs]\n",
|
|
" dic = text_kwargs.copy(); dic['color'] = 'g'; lis_kwa.append(dic);\n",
|
|
" dic = text_kwargs.copy(); dic['color'] = 'r'; lis_kwa.append(dic);\n",
|
|
" lis_pos = [(oText._x, oText._y-0.3), (oText._x, oText._y), (oText._x, oText._y+0.3)]\n",
|
|
" for i in range(len(lis_txt)):\n",
|
|
" newText = dict(x=lis_pos[i][0], y=lis_pos[i][1], text=lis_txt[i], kw=lis_kwa[i])\n",
|
|
" #print 'lin: %s, col: %s, newText: %s' %(lin, col, newText)\n",
|
|
" text_add.append(newText)\n",
|
|
" #print '\\n'\n",
|
|
"\n",
|
|
" #set background color for sum cells (last line and last column)\n",
|
|
" carr = [0.27, 0.30, 0.27, 1.0]\n",
|
|
" if(col == ccl - 1) and (lin == ccl - 1):\n",
|
|
" carr = [0.17, 0.20, 0.17, 1.0]\n",
|
|
" facecolors[posi] = carr\n",
|
|
"\n",
|
|
" else:\n",
|
|
" if(per > 0):\n",
|
|
" txt = '%s\\n%.1f%%' %(cell_val, per)\n",
|
|
" else:\n",
|
|
" if(show_null_values == 0):\n",
|
|
" txt = ''\n",
|
|
" elif(show_null_values == 1):\n",
|
|
" txt = '0'\n",
|
|
" else:\n",
|
|
" txt = '0\\n0.0%'\n",
|
|
" oText.set_text(txt)\n",
|
|
"\n",
|
|
" #main diagonal\n",
|
|
" if(col == lin):\n",
|
|
" #set color of the textin the diagonal to white\n",
|
|
" oText.set_color('w')\n",
|
|
" # set background color in the diagonal to blue\n",
|
|
" facecolors[posi] = [0.35, 0.8, 0.55, 1.0]\n",
|
|
" else:\n",
|
|
" oText.set_color('r')\n",
|
|
"\n",
|
|
" return text_add, text_del\n",
|
|
"#\n",
|
|
"\n",
|
|
"def insert_totals(df_cm):\n",
|
|
" \"\"\" insert total column and line (the last ones) \"\"\"\n",
|
|
" sum_col = []\n",
|
|
" for c in df_cm.columns:\n",
|
|
" sum_col.append( df_cm[c].sum() )\n",
|
|
" sum_lin = []\n",
|
|
" for item_line in df_cm.iterrows():\n",
|
|
" sum_lin.append( item_line[1].sum() )\n",
|
|
" df_cm['sum_lin'] = sum_lin\n",
|
|
" sum_col.append(np.sum(sum_lin))\n",
|
|
" df_cm.loc['sum_col'] = sum_col\n",
|
|
" #print ('\\ndf_cm:\\n', df_cm, '\\n\\b\\n')\n",
|
|
"#\n",
|
|
"\n",
|
|
"def pretty_plot_confusion_matrix(df_cm, annot=True, cmap=\"Oranges\", fmt='.2f', fz=11,\n",
|
|
" lw=0.5, cbar=False, figsize=[8,8], show_null_values=0, pred_val_axis='y'):\n",
|
|
" \"\"\"\n",
|
|
" print conf matrix with default layout (like matlab)\n",
|
|
" params:\n",
|
|
" df_cm dataframe (pandas) without totals\n",
|
|
" annot print text in each cell\n",
|
|
" cmap Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... see:\n",
|
|
" fz fontsize\n",
|
|
" lw linewidth\n",
|
|
" pred_val_axis where to show the prediction values (x or y axis)\n",
"                      'col' or 'x': show predicted values in columns (x axis) instead of rows\n",
"                      'lin' or 'y': show predicted values in rows (y axis)\n",
|
|
" \"\"\"\n",
|
|
" if(pred_val_axis in ('col', 'x')):\n",
|
|
" xlbl = 'Predicted'\n",
|
|
" ylbl = 'Actual'\n",
|
|
" else:\n",
|
|
" xlbl = 'Actual'\n",
|
|
" ylbl = 'Predicted'\n",
|
|
" df_cm = df_cm.T\n",
|
|
"\n",
|
|
" # create \"Total\" column\n",
|
|
" insert_totals(df_cm)\n",
|
|
"\n",
"    # always draw into the same figure window\n",
|
|
" fig, ax1 = get_new_fig('Conf matrix default', figsize)\n",
|
|
"\n",
|
|
" #thanks for seaborn\n",
|
|
" ax = sn.heatmap(df_cm, annot=annot, annot_kws={\"size\": fz}, linewidths=lw, ax=ax1,\n",
|
|
" cbar=cbar, cmap=cmap, linecolor='w', fmt=fmt)\n",
|
|
"\n",
|
|
" #set ticklabels rotation\n",
|
|
" ax.set_xticklabels(ax.get_xticklabels(), rotation = 45, fontsize = 10)\n",
|
|
" ax.set_yticklabels(ax.get_yticklabels(), rotation = 25, fontsize = 10)\n",
|
|
"\n",
|
|
" # Turn off all the ticks\n",
|
|
" for t in ax.xaxis.get_major_ticks():\n",
|
|
" t.tick1On = False\n",
|
|
" t.tick2On = False\n",
|
|
" for t in ax.yaxis.get_major_ticks():\n",
|
|
" t.tick1On = False\n",
|
|
" t.tick2On = False\n",
|
|
"\n",
|
|
" #face colors list\n",
|
|
" quadmesh = ax.findobj(QuadMesh)[0]\n",
|
|
" facecolors = quadmesh.get_facecolors()\n",
|
|
"\n",
|
|
" #iter in text elements\n",
|
|
" array_df = np.array( df_cm.to_records(index=False).tolist() )\n",
|
|
" text_add = []; text_del = [];\n",
|
|
" posi = -1 #from left to right, bottom to top.\n",
|
|
" for t in ax.collections[0].axes.texts: #ax.texts:\n",
|
|
" pos = np.array( t.get_position()) - [0.5,0.5]\n",
|
|
" lin = int(pos[1]); col = int(pos[0]);\n",
|
|
" posi += 1\n",
|
|
" #print ('>>> pos: %s, posi: %s, val: %s, txt: %s' %(pos, posi, array_df[lin][col], t.get_text()))\n",
|
|
"\n",
|
|
" #set text\n",
|
|
" txt_res = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)\n",
|
|
"\n",
|
|
" text_add.extend(txt_res[0])\n",
|
|
" text_del.extend(txt_res[1])\n",
|
|
"\n",
|
|
" #remove the old ones\n",
|
|
" for item in text_del:\n",
|
|
" item.remove()\n",
|
|
" #append the new ones\n",
|
|
" for item in text_add:\n",
|
|
" ax.text(item['x'], item['y'], item['text'], **item['kw'])\n",
|
|
"\n",
|
|
" #titles and legends\n",
|
|
" ax.set_title('Confusion matrix')\n",
|
|
" ax.set_xlabel(xlbl)\n",
|
|
" ax.set_ylabel(ylbl)\n",
|
|
" plt.tight_layout() #set layout slim\n",
|
|
" plt.show()\n",
|
|
" return fig\n",
|
|
"#\n",
|
|
"\n",
|
|
"def plot_confusion_matrix_from_data(y_test, predictions, columns=None, annot=True, cmap=\"Oranges\",\n",
|
|
" fmt='.2f', fz=11, lw=0.5, cbar=False, figsize=[8,8], show_null_values=0, pred_val_axis='lin'):\n",
|
|
" \"\"\"\n",
"    plot a confusion matrix directly from y_test (actual values) and predictions (predic),\n",
"    without a precomputed confusion matrix\n",
|
|
" \"\"\"\n",
|
|
" from sklearn.metrics import confusion_matrix\n",
|
|
" from pandas import DataFrame\n",
|
|
"\n",
|
|
" #data\n",
|
|
" if(not columns):\n",
|
|
" #labels axis integer:\n",
|
|
" ##columns = range(1, len(np.unique(y_test))+1)\n",
|
|
" #labels axis string:\n",
|
|
" from string import ascii_uppercase\n",
|
|
" columns = ['class %s' %(i) for i in list(ascii_uppercase)[0:len(np.unique(y_test))]]\n",
|
|
"\n",
|
|
" confm = confusion_matrix(y_test, predictions)\n",
|
|
" df_cm = DataFrame(confm, index=columns, columns=columns)\n",
|
|
" pretty_plot_confusion_matrix(df_cm, fz=fz, cmap=cmap, figsize=figsize, show_null_values=show_null_values, pred_val_axis=pred_val_axis)\n",
|
|
"#\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"#\n",
|
|
"#TEST functions\n",
|
|
"#\n",
|
|
"def _test_cm():\n",
|
|
" #test function with confusion matrix done\n",
|
|
" array = np.array( [[13, 0, 1, 0, 2, 0],\n",
|
|
" [ 0, 50, 2, 0, 10, 0],\n",
|
|
" [ 0, 13, 16, 0, 0, 3],\n",
|
|
" [ 0, 0, 0, 13, 1, 0],\n",
|
|
" [ 0, 40, 0, 1, 15, 0],\n",
|
|
" [ 0, 0, 0, 0, 0, 20]])\n",
|
|
" #get pandas dataframe\n",
|
|
" df_cm = DataFrame(array, index=range(1,7), columns=range(1,7))\n",
|
|
" #colormap: see this and choose your more dear\n",
|
|
" cmap = 'PuRd'\n",
|
|
" pretty_plot_confusion_matrix(df_cm, cmap=cmap)\n",
|
|
"#\n",
|
|
"\n",
|
|
"def _test_data_class():\n",
|
|
" \"\"\" test function with y_test (actual values) and predictions (predic) \"\"\"\n",
|
|
" #data\n",
|
|
" y_test = np.array([1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5])\n",
|
|
" predic = np.array([1,2,4,3,5, 1,2,4,3,5, 1,2,3,4,4, 1,4,3,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,3,3,5, 1,2,3,3,5, 1,2,3,4,4, 1,2,3,4,1, 1,2,3,4,1, 1,2,3,4,1, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,4,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5])\n",
|
|
" \"\"\"\n",
|
|
" Examples to validate output (confusion matrix plot)\n",
|
|
" actual: 5 and prediction 1 >> 3\n",
|
|
" actual: 2 and prediction 4 >> 1\n",
|
|
" actual: 3 and prediction 4 >> 10\n",
|
|
" \"\"\"\n",
|
|
" columns = []\n",
|
|
" annot = True;\n",
|
|
" cmap = 'Oranges';\n",
|
|
" fmt = '.2f'\n",
|
|
" lw = 0.5\n",
|
|
" cbar = False\n",
|
|
" show_null_values = 2\n",
|
|
" pred_val_axis = 'y'\n",
|
|
" #size::\n",
|
|
" fz = 12;\n",
|
|
" figsize = [9,9];\n",
|
|
" if(len(y_test) > 10):\n",
|
|
" fz=9; figsize=[14,14];\n",
|
|
" plot_confusion_matrix_from_data(y_test, predic, columns,\n",
|
|
" annot, cmap, fmt, fz, lw, cbar, figsize, show_null_values, pred_val_axis)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# plot_confusion_matrix_from_data(y_test, predictions, columns=None, annot=True, cmap=\"Oranges\",\n",
|
|
"# fmt='.2f', fz=11, lw=0.5, cbar=False, figsize=[8,8], show_null_values=0, pred_val_axis='lin'):\n",
|
|
"\n",
|
|
"matplotlib.rcParams['text.usetex'] = False\n",
|
|
"\n",
|
|
"forecast_model = 'postintertidal_slope_sto06'\n",
|
|
"\n",
|
|
"df_for = impacts['forecasted'][forecast_model]\n",
|
|
"df_for.storm_regime = df_for.storm_regime.astype(cat_type)\n",
|
|
"observed_regimes = df_obs.storm_regime.astype(cat_type).cat.codes.values\n",
|
|
"forecasted_regimes = df_for.storm_regime.astype(cat_type).cat.codes.values\n",
|
|
"\n",
|
|
"\n",
"confm = confusion_matrix(observed_regimes, forecasted_regimes, labels=[0, 1, 2, 3])\n",
"labels = ['swash', 'collision', 'overwash', 'inundation']\n",
|
|
"df_cm = DataFrame(confm, index=labels, columns=labels)\n",
|
|
"\n",
|
|
"fig = pretty_plot_confusion_matrix(df_cm, annot=True, cmap=\"Oranges\", fmt='.1f', fz=13,\n",
|
|
" lw=0.1, cbar=False, figsize=[8,5], show_null_values=1, pred_val_axis='y')\n",
|
|
"\n",
"fig.savefig('11_confusion_matrix.png', dpi=600)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"hide_input": false,
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.6"
|
|
},
|
|
"toc": {
|
|
"base_numbering": 1,
|
|
"nav_menu": {},
|
|
"number_sections": true,
|
|
"sideBar": true,
|
|
"skip_h1_title": false,
|
|
"title_cell": "Table of Contents",
|
|
"title_sidebar": "Contents",
|
|
"toc_cell": false,
|
|
"toc_position": {
|
|
"height": "calc(100% - 180px)",
|
|
"left": "10px",
|
|
"top": "150px",
|
|
"width": "286.391px"
|
|
},
|
|
"toc_section_display": true,
|
|
"toc_window_display": true
|
|
},
|
|
"varInspector": {
|
|
"cols": {
|
|
"lenName": 16,
|
|
"lenType": 16,
|
|
"lenVar": 40
|
|
},
|
|
"kernels_config": {
|
|
"python": {
|
|
"delete_cmd_postfix": "",
|
|
"delete_cmd_prefix": "del ",
|
|
"library": "var_list.py",
|
|
"varRefreshCmd": "print(var_dic_list())"
|
|
},
|
|
"r": {
|
|
"delete_cmd_postfix": ") ",
|
|
"delete_cmd_prefix": "rm(",
|
|
"library": "var_list.r",
|
|
"varRefreshCmd": "cat(var_dic_list()) "
|
|
}
|
|
},
|
|
"types_to_exclude": [
|
|
"module",
|
|
"function",
|
|
"builtin_function_or_method",
|
|
"instance",
|
|
"_Feature"
|
|
],
|
|
"window_display": false
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|