add visualization for training curves, pr, roc
This commit is contained in:
parent
d0418b9efa
commit
b35f23e518
3
.gitignore
vendored
3
.gitignore
vendored
@ -101,3 +101,6 @@ ENV/
|
|||||||
*.csv
|
*.csv
|
||||||
*.csv.gz
|
*.csv.gz
|
||||||
*.csv.tar.*
|
*.csv.tar.*
|
||||||
|
*.h5
|
||||||
|
*.npy
|
||||||
|
*.png
|
||||||
|
10
Makefile
10
Makefile
@ -1,5 +1,11 @@
|
|||||||
|
run:
|
||||||
|
python3 main.py --modes train --batch 128 --model results/test --train data/rk_mini.csv.gz --epochs 10
|
||||||
|
|
||||||
test:
|
test:
|
||||||
python3 main.py --modes train --epochs 1 --batch 128 --train data/rk_mini.csv.gz
|
python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
|
fancy:
|
||||||
|
python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
hyper:
|
hyper:
|
||||||
python3 main.py --modes hyperband --epochs 1 --batch 64 --train data/rk_data.csv.gz
|
python3 main.py --modes hyperband --batch 64 --train data/rk_data.csv.gz
|
||||||
|
24
arguments.py
24
arguments.py
@ -7,20 +7,14 @@ parser.add_argument("--modes", action="store", dest="modes", nargs="+",
|
|||||||
default=[])
|
default=[])
|
||||||
|
|
||||||
parser.add_argument("--train", action="store", dest="train_data",
|
parser.add_argument("--train", action="store", dest="train_data",
|
||||||
default="data/full_dataset.csv.tar.bz2")
|
default="data/full_dataset.csv.tar.gz")
|
||||||
|
|
||||||
parser.add_argument("--test", action="store", dest="test_data",
|
parser.add_argument("--test", action="store", dest="test_data",
|
||||||
default="data/full_future_dataset.csv.tar.bz2")
|
default="data/full_future_dataset.csv.tar.gz")
|
||||||
|
|
||||||
# parser.add_argument("--h5data", action="store", dest="h5data",
|
parser.add_argument("--model", action="store", dest="model_path",
|
||||||
# default="")
|
default="results/model_x")
|
||||||
#
|
|
||||||
parser.add_argument("--models", action="store", dest="model_path",
|
|
||||||
default="models/models_x")
|
|
||||||
|
|
||||||
# parser.add_argument("--pred", action="store", dest="pred",
|
|
||||||
# default="")
|
|
||||||
#
|
|
||||||
parser.add_argument("--type", action="store", dest="model_type",
|
parser.add_argument("--type", action="store", dest="model_type",
|
||||||
default="paul")
|
default="paul")
|
||||||
|
|
||||||
@ -66,13 +60,17 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
|||||||
#
|
#
|
||||||
# parser.add_argument("--tmp", action="store_true", dest="tmp")
|
# parser.add_argument("--tmp", action="store_true", dest="tmp")
|
||||||
#
|
#
|
||||||
# parser.add_argument("--test", action="store_true", dest="test")
|
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
|
||||||
|
parser.add_argument("--gpu", action="store_true", dest="gpu")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse():
|
def parse():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||||
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||||
args.train_log = os.path.join(args.model_path, "train.log")
|
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
||||||
args.h5data = args.train_data + ".h5"
|
args.train_h5data = args.train_data + ".h5"
|
||||||
|
args.test_h5data = args.test_data + ".h5"
|
||||||
|
args.future_prediction = os.path.join(args.model_path, "future_predict.npy")
|
||||||
return args
|
return args
|
||||||
|
18
dataset.py
18
dataset.py
@ -199,3 +199,21 @@ def get_flow_per_user(df):
|
|||||||
users = df['user_hash'].unique().tolist()
|
users = df['user_hash'].unique().tolist()
|
||||||
for user in users:
|
for user in users:
|
||||||
yield df.loc[df.user_hash == user]
|
yield df.loc[df.user_hash == user]
|
||||||
|
|
||||||
|
|
||||||
|
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||||
|
char_dict = get_character_dict()
|
||||||
|
logger.info(f"check for h5data {h5data}")
|
||||||
|
try:
|
||||||
|
open(h5data, "r")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.info("h5 data not found - load csv file")
|
||||||
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
|
logger.info("create training dataset")
|
||||||
|
domain_tr, flow_tr, client_tr, server_tr = create_dataset_from_flows(user_flow_df, char_dict,
|
||||||
|
max_len=domain_length,
|
||||||
|
window_size=window_size)
|
||||||
|
logger.info("store training dataset as h5 file")
|
||||||
|
store_h5dataset(h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||||
|
logger.info("load h5 dataset")
|
||||||
|
return load_h5dataset(h5data)
|
||||||
|
132
main.py
132
main.py
@ -3,6 +3,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import tensorflow as tf
|
||||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
||||||
from keras.models import load_model
|
from keras.models import load_model
|
||||||
|
|
||||||
@ -10,8 +12,10 @@ import arguments
|
|||||||
import dataset
|
import dataset
|
||||||
import hyperband
|
import hyperband
|
||||||
import models
|
import models
|
||||||
|
|
||||||
# create logger
|
# create logger
|
||||||
|
import visualize
|
||||||
|
from dataset import load_or_generate_h5data
|
||||||
|
|
||||||
logger = logging.getLogger('logger')
|
logger = logging.getLogger('logger')
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
@ -40,13 +44,15 @@ ch.setFormatter(formatter)
|
|||||||
# add ch to logger
|
# add ch to logger
|
||||||
logger.addHandler(ch)
|
logger.addHandler(ch)
|
||||||
|
|
||||||
|
print = logger.info
|
||||||
|
|
||||||
args = arguments.parse()
|
args = arguments.parse()
|
||||||
|
|
||||||
|
if args.gpu:
|
||||||
# config = tf.ConfigProto(log_device_placement=True)
|
config = tf.ConfigProto(log_device_placement=True)
|
||||||
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
|
config.gpu_options.per_process_gpu_memory_fraction = 0.5
|
||||||
# config.gpu_options.allow_growth = True
|
config.gpu_options.allow_growth = True
|
||||||
# session = tf.Session(config=config)
|
session = tf.Session(config=config)
|
||||||
|
|
||||||
|
|
||||||
def exists_or_make_path(p):
|
def exists_or_make_path(p):
|
||||||
@ -56,32 +62,13 @@ def exists_or_make_path(p):
|
|||||||
|
|
||||||
def main_paul_best():
|
def main_paul_best():
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.h5data, args.train_data,
|
pauls_best_params = models.pauls_networks.best_config
|
||||||
args.domain_length, args.window)
|
pauls_best_params["vocab_size"] = len(char_dict) + 1
|
||||||
|
main_train(pauls_best_params)
|
||||||
param = models.pauls_networks.best_config
|
|
||||||
param["vocab_size"] = len(char_dict) + 1
|
|
||||||
|
|
||||||
embedding, model = models.get_models_by_params(param)
|
|
||||||
|
|
||||||
model.compile(optimizer='adam',
|
|
||||||
loss='categorical_crossentropy',
|
|
||||||
metrics=['accuracy'])
|
|
||||||
|
|
||||||
model.fit([domain_tr, flow_tr],
|
|
||||||
[client_tr, server_tr],
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
epochs=args.epochs,
|
|
||||||
shuffle=True,
|
|
||||||
validation_split=0.2)
|
|
||||||
|
|
||||||
embedding.save(args.embedding_model)
|
|
||||||
model.save(args.clf_model)
|
|
||||||
|
|
||||||
|
|
||||||
def main_hyperband():
|
def main_hyperband():
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
# static params
|
# static params
|
||||||
@ -105,7 +92,7 @@ def main_hyperband():
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.h5data, args.train_data,
|
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
hp = hyperband.Hyperband(params,
|
hp = hyperband.Hyperband(params,
|
||||||
[domain_tr, flow_tr],
|
[domain_tr, flow_tr],
|
||||||
@ -114,33 +101,15 @@ def main_hyperband():
|
|||||||
json.dump(results, open("hyperband.json"))
|
json.dump(results, open("hyperband.json"))
|
||||||
|
|
||||||
|
|
||||||
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
def main_train(param=None):
|
||||||
char_dict = dataset.get_character_dict()
|
|
||||||
logger.info(f"check for h5data {h5data}")
|
|
||||||
try:
|
|
||||||
open(h5data, "r")
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.info("h5 data not found - load csv file")
|
|
||||||
user_flow_df = dataset.get_user_flow_data(train_data)
|
|
||||||
logger.info("create training dataset")
|
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
|
||||||
max_len=domain_length,
|
|
||||||
window_size=window_size)
|
|
||||||
logger.info("store training dataset as h5 file")
|
|
||||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
|
||||||
logger.info("load h5 dataset")
|
|
||||||
return dataset.load_h5dataset(h5data)
|
|
||||||
|
|
||||||
|
|
||||||
def main_train():
|
|
||||||
exists_or_make_path(args.model_path)
|
exists_or_make_path(args.model_path)
|
||||||
|
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.h5data, args.train_data,
|
domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
|
|
||||||
# parameter
|
# parameter
|
||||||
param = {
|
p = {
|
||||||
"type": "paul",
|
"type": "paul",
|
||||||
"batch_size": 64,
|
"batch_size": 64,
|
||||||
"window_size": args.window,
|
"window_size": args.window,
|
||||||
@ -160,29 +129,34 @@ def main_train():
|
|||||||
'kernels_main': 3,
|
'kernels_main': 3,
|
||||||
'input_length': 40
|
'input_length': 40
|
||||||
}
|
}
|
||||||
|
if not param:
|
||||||
|
param = p
|
||||||
|
|
||||||
embedding, model = models.get_models_by_params(param)
|
embedding, model = models.get_models_by_params(param)
|
||||||
embedding.summary()
|
embedding.summary()
|
||||||
model.summary()
|
model.summary()
|
||||||
logger.info("define callbacks")
|
logger.info("define callbacks")
|
||||||
cp = ModelCheckpoint(filepath=args.clf_model,
|
callbacks = []
|
||||||
|
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||||
monitor='val_loss',
|
monitor='val_loss',
|
||||||
verbose=False,
|
verbose=False,
|
||||||
save_best_only=True)
|
save_best_only=True))
|
||||||
csv = CSVLogger(args.train_log)
|
callbacks.append(CSVLogger(args.train_log))
|
||||||
early = EarlyStopping(monitor='val_loss',
|
if args.stop_early:
|
||||||
|
callbacks.append(EarlyStopping(monitor='val_loss',
|
||||||
patience=5,
|
patience=5,
|
||||||
verbose=False)
|
verbose=False))
|
||||||
logger.info("compile model")
|
logger.info("compile model")
|
||||||
|
custom_metrics = models.get_metric_functions()
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
loss='categorical_crossentropy',
|
loss='categorical_crossentropy',
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'] + custom_metrics)
|
||||||
logger.info("start training")
|
logger.info("start training")
|
||||||
model.fit([domain_tr, flow_tr],
|
model.fit([domain_tr, flow_tr],
|
||||||
[client_tr, server_tr],
|
[client_tr, server_tr],
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
callbacks=[cp, csv, early],
|
callbacks=callbacks,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
validation_split=0.2)
|
validation_split=0.2)
|
||||||
logger.info("save embedding")
|
logger.info("save embedding")
|
||||||
@ -190,42 +164,46 @@ def main_train():
|
|||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
clf = load_model(args.clf_model)
|
clf = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
|
stats = clf.evaluate([domain_val, flow_val],
|
||||||
[client_val, server_val],
|
[client_val, server_val],
|
||||||
batch_size=args.batch_size)
|
batch_size=args.batch_size)
|
||||||
logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
# logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
||||||
|
logger.info(stats)
|
||||||
y_pred = clf.predict([domain_val, flow_val],
|
y_pred = clf.predict([domain_val, flow_val],
|
||||||
batch_size=args.batch_size)
|
batch_size=args.batch_size)
|
||||||
np.save(os.path.join(args.model_path, "future_predict.npy"), y_pred)
|
np.save(args.future_prediction, y_pred)
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
mask = dataset.load_mask_eval(args.data, args.test_image)
|
_, _, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||||
y_pred_path = args.model_path + "pred.npy"
|
args.domain_length, args.window)
|
||||||
logger.info("plot model")
|
logger.info("plot model")
|
||||||
model = load_model(args.model_path + "model.h5",
|
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
custom_objects=evaluation.get_metrics())
|
|
||||||
visualize.plot_model(model, args.model_path + "model.png")
|
visualize.plot_model(model, args.model_path + "model.png")
|
||||||
logger.info("plot training curve")
|
logger.info("plot training curve")
|
||||||
logs = pd.read_csv(args.model_path + "train.log")
|
logs = pd.read_csv(args.train_log)
|
||||||
visualize.plot_training_curve(logs, "{}/train.png".format(args.model_path))
|
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
|
||||||
pred = np.load(y_pred_path)
|
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
|
||||||
|
|
||||||
|
client_pred, server_pred = np.load(args.future_prediction)
|
||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_precision_recall(mask, pred, "{}/prc.png".format(args.model_path))
|
visualize.plot_precision_recall(client_val.value, client_pred, "{}/client_prc.png".format(args.model_path))
|
||||||
visualize.plot_precision_recall_curves(mask, pred, "{}/prc2.png".format(args.model_path))
|
visualize.plot_precision_recall(server_val.value, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||||
|
visualize.plot_precision_recall_curves(client_val.value, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||||
|
visualize.plot_precision_recall_curves(server_val.value, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_roc_curve(mask, pred, "{}/roc.png".format(args.model_path))
|
visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path))
|
||||||
logger.info("store prediction image")
|
visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||||
visualize.save_image_as(pred, "{}/pred.png".format(args.model_path))
|
|
||||||
|
|
||||||
|
|
||||||
def main_score():
|
def main_score():
|
||||||
mask = dataset.load_mask_eval(args.data, args.test_image)
|
# mask = dataset.load_mask_eval(args.data, args.test_image)
|
||||||
pred = np.load(args.pred)
|
# pred = np.load(args.pred)
|
||||||
visualize.score_model(mask, pred)
|
# visualize.score_model(mask, pred)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
import keras.backend as K
|
||||||
|
|
||||||
|
import dataset
|
||||||
from . import pauls_networks
|
from . import pauls_networks
|
||||||
from . import renes_networks
|
from . import renes_networks
|
||||||
|
|
||||||
@ -6,7 +9,7 @@ def get_models_by_params(params: dict):
|
|||||||
# decomposing param section
|
# decomposing param section
|
||||||
# mainly embedding model
|
# mainly embedding model
|
||||||
network_type = params.get("type")
|
network_type = params.get("type")
|
||||||
vocab_size = params.get("vocab_size")
|
vocab_size = len(dataset.get_character_dict()) + 1
|
||||||
embedding_size = params.get("embedding_size")
|
embedding_size = params.get("embedding_size")
|
||||||
input_length = params.get("input_length")
|
input_length = params.get("input_length")
|
||||||
filter_embedding = params.get("filter_embedding")
|
filter_embedding = params.get("filter_embedding")
|
||||||
@ -30,3 +33,51 @@ def get_models_by_params(params: dict):
|
|||||||
filter_main, kernel_main, dense_dim, embedding_model)
|
filter_main, kernel_main, dense_dim, embedding_model)
|
||||||
|
|
||||||
return embedding_model, predict_model
|
return embedding_model, predict_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_metrics():
|
||||||
|
return dict([
|
||||||
|
("precision", precision),
|
||||||
|
("recall", recall),
|
||||||
|
("f1_score", f1_score),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def get_metric_functions():
|
||||||
|
return [precision, recall, f1_score]
|
||||||
|
|
||||||
|
|
||||||
|
def precision(y_true, y_pred):
|
||||||
|
# Count positive samples.
|
||||||
|
true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
|
||||||
|
predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
|
||||||
|
return true_positives / (predicted_positives + K.epsilon())
|
||||||
|
|
||||||
|
|
||||||
|
def recall(y_true, y_pred):
|
||||||
|
# Count positive samples.
|
||||||
|
true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
|
||||||
|
possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
|
||||||
|
return true_positives / (possible_positives + K.epsilon())
|
||||||
|
|
||||||
|
|
||||||
|
def f1_score(y_true, y_pred):
|
||||||
|
return f_score(1)(y_true, y_pred)
|
||||||
|
|
||||||
|
|
||||||
|
def f05_score(y_true, y_pred):
|
||||||
|
return f_score(0.5)(y_true, y_pred)
|
||||||
|
|
||||||
|
|
||||||
|
def f_score(beta):
|
||||||
|
def _f(y_true, y_pred):
|
||||||
|
p = precision(y_true, y_pred)
|
||||||
|
r = recall(y_true, y_pred)
|
||||||
|
|
||||||
|
bb = beta ** 2
|
||||||
|
|
||||||
|
fbeta_score = (1 + bb) * (p * r) / (bb * p + r + K.epsilon())
|
||||||
|
|
||||||
|
return fbeta_score
|
||||||
|
|
||||||
|
return _f
|
||||||
|
146
visualize.py
Normal file
146
visualize.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from keras.utils import plot_model
|
||||||
|
from sklearn.metrics import (
|
||||||
|
auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve,
|
||||||
|
roc_auc_score, roc_curve
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def scores(y_true, y_pred):
|
||||||
|
for (path, dirnames, fnames) in os.walk("results/"):
|
||||||
|
for f in fnames:
|
||||||
|
if path[-1] == "1" and f.endswith("npy"):
|
||||||
|
y_pred = np.load(os.path.join(path, f)).flatten()
|
||||||
|
print(path)
|
||||||
|
tp = np.sum(np.logical_and(y_pred >= 0.5, y_true == 1))
|
||||||
|
tn = np.sum(np.logical_and(y_pred < 0.5, y_true == 0))
|
||||||
|
fp = np.sum(np.logical_and(y_pred >= 0.5, y_true == 0))
|
||||||
|
fn = np.sum(np.logical_and(y_pred < 0.5, y_true == 1))
|
||||||
|
precision = tp / (tp + fp)
|
||||||
|
recall = tp / (tp + fn)
|
||||||
|
accuracy = (tp + tn) / len(y_true)
|
||||||
|
f1_score = 2 * (precision * recall) / (precision + recall)
|
||||||
|
f05_score = (1 + 0.5 ** 2) * (precision * recall) / (0.5 ** 2 * precision + recall)
|
||||||
|
print(" precision:", precision)
|
||||||
|
print(" recall:", recall)
|
||||||
|
print(" accuracy:", accuracy)
|
||||||
|
print(" f1 score:", f1_score)
|
||||||
|
print(" f0.5 score:", f05_score)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_precision_recall(mask, prediction, path):
|
||||||
|
y = mask.flatten()
|
||||||
|
y_pred = prediction.flatten()
|
||||||
|
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||||
|
decreasing_max_precision = np.maximum.accumulate(precision)[::-1]
|
||||||
|
|
||||||
|
plt.clf()
|
||||||
|
# fig, ax = plt.subplots(1, 1)
|
||||||
|
# ax.hold(True)
|
||||||
|
plt.plot(recall, precision, '--b')
|
||||||
|
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
||||||
|
plt.xlabel('Recall')
|
||||||
|
plt.ylabel('Precision')
|
||||||
|
|
||||||
|
plt.savefig(path, dpi=600)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_precision_recall_curves(mask, prediction, path):
|
||||||
|
y = mask.flatten()
|
||||||
|
y_pred = prediction.flatten()
|
||||||
|
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||||
|
|
||||||
|
plt.clf()
|
||||||
|
plt.plot(recall, label="Recall")
|
||||||
|
plt.plot(precision, label="Precision")
|
||||||
|
plt.xlabel('Threshold')
|
||||||
|
plt.ylabel('Score')
|
||||||
|
|
||||||
|
plt.savefig(path, dpi=600)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def score_model(y, prediction):
|
||||||
|
y = y.flatten()
|
||||||
|
y_pred = prediction.flatten()
|
||||||
|
|
||||||
|
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||||
|
|
||||||
|
print(classification_report(y, y_pred.round()))
|
||||||
|
print("Area under PR curve", auc(recall, precision))
|
||||||
|
print("roc auc score", roc_auc_score(y, y_pred))
|
||||||
|
print("F1 Score", fbeta_score(y, y_pred.round(), 1))
|
||||||
|
print("F0.5 Score", fbeta_score(y, y_pred.round(), 0.5))
|
||||||
|
|
||||||
|
|
||||||
|
def plot_roc_curve(mask, prediction, path):
|
||||||
|
y = mask.flatten()
|
||||||
|
y_pred = prediction.flatten()
|
||||||
|
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
||||||
|
roc_auc = auc(fpr, tpr)
|
||||||
|
plt.clf()
|
||||||
|
plt.plot(fpr, tpr)
|
||||||
|
plt.savefig(path, dpi=600)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
print("roc_auc", roc_auc)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_confusion_matrix(y_true, y_pred,
|
||||||
|
normalize=False,
|
||||||
|
title='Confusion matrix',
|
||||||
|
cmap="Blues"):
|
||||||
|
"""
|
||||||
|
This function prints and plots the confusion matrix.
|
||||||
|
Normalization can be applied by setting `normalize=True`.
|
||||||
|
"""
|
||||||
|
plt.clf()
|
||||||
|
cm = confusion_matrix(y_true, y_pred)
|
||||||
|
classes = [0, 1]
|
||||||
|
plt.imshow(cm, interpolation='nearest', cmap=cmap)
|
||||||
|
plt.title(title)
|
||||||
|
plt.colorbar()
|
||||||
|
tick_marks = np.arange(len(classes))
|
||||||
|
plt.xticks(tick_marks, classes, rotation=45)
|
||||||
|
plt.yticks(tick_marks, classes)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||||
|
print("Normalized confusion matrix")
|
||||||
|
else:
|
||||||
|
print('Confusion matrix, without normalization')
|
||||||
|
|
||||||
|
print(cm)
|
||||||
|
|
||||||
|
thresh = cm.max() / 2.
|
||||||
|
for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
|
||||||
|
plt.text(j, i, cm[i, j],
|
||||||
|
horizontalalignment="center",
|
||||||
|
color="white" if cm[i, j] > thresh else "black")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.ylabel('True label')
|
||||||
|
plt.xlabel('Predicted label')
|
||||||
|
|
||||||
|
|
||||||
|
def plot_training_curve(logs, key, path, dpi=600):
|
||||||
|
plt.clf()
|
||||||
|
plt.plot(logs[f"{key}_acc"], label="accuracy")
|
||||||
|
plt.plot(logs[f"{key}_f1_score"], label="f1_score")
|
||||||
|
|
||||||
|
plt.plot(logs[f"val_{key}_acc"], label="accuracy")
|
||||||
|
plt.plot(logs[f"val_{key}_f1_score"], label="val_f1_score")
|
||||||
|
|
||||||
|
plt.xlabel('epoch')
|
||||||
|
plt.ylabel('percentage')
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(path, dpi=dpi)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_model_as(model, path):
|
||||||
|
plot_model(model, to_file=path, show_shapes=True, show_layer_names=True)
|
Loading…
Reference in New Issue
Block a user