# Source code for h3.models.simple_models

import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn import metrics

from h3.constants import RANDOM_STATE
from h3.utils.directories import get_pickle_dir


def logistic_reg(x_train, y_train, x_test, y_test):
    """Fit a logistic regression and plot its normalised confusion matrix.

    A default ``LogisticRegression`` is fitted on the training split and
    evaluated on the test split. The confusion matrix is divided by its
    grand total, so every cell is a share of all test samples, and shown
    as a heatmap via ``plt.show()``.

    Args:
        x_train: Training features, passed to ``LogisticRegression.fit``.
        y_train: Training labels, passed to ``LogisticRegression.fit``.
        x_test: Test features used for prediction and scoring.
        y_test: True labels for ``x_test``.

    Returns:
        tuple: ``(model, accuracy, importance)`` where ``model`` is the
        fitted estimator, ``accuracy`` is ``model.score(x_test, y_test)``
        and ``importance`` is the first row of ``model.coef_``.
    """
    model = LogisticRegression()
    model.fit(x_train, y_train)

    predictions = model.predict(x_test)
    # Both values used to be computed and then thrown away; they are now
    # returned so callers can actually inspect them.
    accuracy = model.score(x_test, y_test)
    importance = model.coef_[0]

    cm = metrics.confusion_matrix(y_test, predictions)
    sns.heatmap(
        cm / np.sum(cm),  # normalise by the total number of test samples
        annot=True,
        fmt=".2%",
        linewidths=.5,
        square=True,
        cmap="Blues_r",
    )
    plt.ylabel(r"Actual label")
    plt.xlabel(r"Predicted label")
    plt.show()

    return model, accuracy, importance
def main(target_column=None):
    """Load the pickled points dataset and run the logistic-regression baseline.

    The frame is read from the project's pickle directory, split into
    features and labels, divided 75/25 into train and test sets with
    ``RANDOM_STATE`` as seed, and handed to ``logistic_reg``.

    Args:
        target_column: Name of the column holding the class labels. All
            remaining columns are used as features. It must be supplied:
            nothing in this module says which column is the label.

    Raises:
        ValueError: If ``target_column`` is missing or is not a column of
            the loaded frame.
    """
    # TODO: Rename this file to shorter name
    filename = "df_points_posthurr_flood_risk_storm_surge_soil_properties.pkl"
    filepath = os.path.join(get_pickle_dir(), filename)  # TODO: change this path

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # Assumes the pickle holds a pandas DataFrame (file name starts with
    # "df_") -- TODO confirm.
    df = pd.read_pickle(filepath)

    if target_column is None or target_column not in df.columns:
        raise ValueError(
            f"target_column={target_column!r} is not a column of {filename}; "
            f"available columns: {list(df.columns)}"
        )

    # Previously `x` and `y` were never defined, so the split below raised
    # NameError before any model could be trained.
    # NOTE(review): the remaining columns are presumably numeric and free of
    # NaNs, as LogisticRegression requires -- verify against the dataset.
    y = df[target_column]
    x = df.drop(columns=[target_column])

    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.25, random_state=RANDOM_STATE
    )
    logistic_reg(x_train, y_train, x_test, y_test)
if __name__ == "__main__":
    # Run the baseline only when executed as a script, not on import.
    main()