forked from kilianv/CoastSat_WRL
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
#==========================================================#
|
|
# Train Neural Network
|
|
#==========================================================#
|
|
|
|
# Initial settings
|
|
import os
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import ee
|
|
import pdb
|
|
import time
|
|
import pandas as pd
|
|
# other modules
|
|
from osgeo import gdal, ogr, osr
|
|
import pickle
|
|
import matplotlib.cm as cm
|
|
from pylab import ginput
|
|
|
|
# image processing modules
|
|
import skimage.filters as filters
|
|
import skimage.exposure as exposure
|
|
import skimage.transform as transform
|
|
import sklearn.decomposition as decomposition
|
|
import skimage.measure as measure
|
|
import skimage.morphology as morphology
|
|
from scipy import ndimage
|
|
import random
|
|
|
|
# machine learning modules
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.neural_network import MLPClassifier
|
|
from sklearn.preprocessing import StandardScaler, Normalizer
|
|
from sklearn.externals import joblib
|
|
|
|
# import own modules
|
|
import functions.utils as utils
|
|
import functions.sds as sds
|
|
|
|
# global runtime settings for this script
np.seterr(all='ignore')  # ignore numpy floating-point warnings (division by 0, NaNs)
plt.rcParams.update({
    'axes.grid': True,                # draw a grid on every axes
    'figure.max_open_warning': 100,   # only warn after 100 open figures
})
# NOTE(review): no other `ee` call is visible in this script; this presumably
# requires Earth Engine credentials - consider dropping it if unneeded here.
ee.Initialize()
#%%

# load the per-class training samples of one site: one pickle per class, read
# from <cwd>/sand_classification/<sitename>_<class>.pkl
sitename = 'NARRA'
_class_dir = os.path.join(os.getcwd(), 'sand_classification')


def _load_class_samples(suffix):
    """Unpickle and return the object stored in <_class_dir>/<sitename><suffix>."""
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # load training files produced locally or from a trusted source.
    with open(os.path.join(_class_dir, sitename + suffix), 'rb') as f:
        return pickle.load(f)


# files are opened in the order sand, swash, water, other
train_sand, train_swash, train_water, train_other = [
    _load_class_samples(suffix)
    for suffix in ('_sand.pkl', '_swash.pkl', '_water.pkl', '_other.pkl')
]
# Cap the water and 'other' classes at n_subsample rows each (sand and swash are
# used in full), presumably to keep the classes roughly balanced.
# Rows are drawn WITHOUT replacement: np.random.choice defaults to replace=True,
# which duplicates rows, and duplicates landing on both sides of the train/test
# split inflate the test score. A fixed seed makes the subsample reproducible,
# matching random_state=0 used for the split.
n_subsample = 1500
_rng = np.random.RandomState(0)


def _subsample_rows(samples, n, rng):
    """Return up to n distinct rows of samples, picked at random without replacement."""
    n_rows = samples.shape[0]
    idx = rng.choice(n_rows, min(n, n_rows), replace=False)
    return samples[idx, :]


train_water = _subsample_rows(train_water, n_subsample, _rng)
train_other = _subsample_rows(train_other, n_subsample, _rng)
# Stack all classes into one float array: feature columns [0, n_features) and
# the class label in the last column (1 = sand, 2 = swash, 3 = water, 0 = other).
n_features = train_sand.shape[1]
n_sand, n_swash, n_water, n_other = (
    len(train_sand), len(train_swash), len(train_water), len(train_other))

training_data = np.zeros((n_sand + n_swash + n_water + n_other, n_features + 1))
_row0 = 0
for _samples, _count, _label in ((train_sand, n_sand, 1),
                                 (train_swash, n_swash, 2),
                                 (train_water, n_water, 3),
                                 (train_other, n_other, 0)):
    _row1 = _row0 + _count
    training_data[_row0:_row1, :n_features] = _samples
    training_data[_row0:_row1, n_features] = _label
    _row0 = _row1

# feature matrix and label vector
X = training_data[:, :n_features]
y = training_data[:, -1]
# divide in train (70%) and test (30%) with a fixed seed for reproducibility
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# standardize data
# NOTE(review): currently disabled. If re-enabled, the fitted scaler must be
# saved together with the model and applied to new data before clf.predict,
# and the final refit below must use scaled features too.
#scaler = StandardScaler()
#scaler.fit(X_train)
#X_train = scaler.transform(X_train)
#X_test = scaler.transform(X_test)

# Fit the network on the TRAINING split only, then evaluate on the held-out
# test split. Fitting on (X, y) would include the test rows in training and make
# the test score meaningless.
clf = MLPClassifier()
clf.fit(X_train, y_train)
test_score = clf.score(X_test, y_test)  # mean accuracy on the held-out rows
print('MLP accuracy on held-out test set: %.4f' % test_score)

# Refit on all labelled samples so the saved model uses every training pixel.
# MLPClassifier.fit re-initialises the weights (warm_start=False by default);
# test_score above is the approximate held-out estimate of this model's accuracy.
clf.fit(X, y)
# persist the trained classifier into the sand_classification folder
# NOTE(review): sklearn.externals.joblib was removed in newer scikit-learn
# releases; the standalone `joblib` package provides the same `dump`.
model_path = os.path.join(os.getcwd(), 'sand_classification', 'NN1.pkl')
joblib.dump(clf, model_path)