From ec5a1101be735c5750d237e7158c612057a222d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Knaebel?=
Date: Sat, 16 Sep 2017 15:25:34 +0200
Subject: [PATCH] remove model selection based on validation loss

---
 main.py                     | 27 ++++++++++++++-------------
 scripts/make_csv_dataset.py |  2 +-
 visualize.py                | 24 +++++++++++++++---------
 3 files changed, 30 insertions(+), 23 deletions(-)

diff --git a/main.py b/main.py
index 2a2f736..5851314 100644
--- a/main.py
+++ b/main.py
@@ -135,7 +135,8 @@ def main_train(param=None):
     logger.info("define callbacks")
     callbacks = []
     callbacks.append(ModelCheckpoint(filepath=args.clf_model,
-                                     monitor='val_loss',
+                                     monitor='loss',
+                                     # monitor='val_loss',
                                      verbose=False,
                                      save_best_only=True))
     callbacks.append(CSVLogger(args.train_log))
@@ -199,7 +200,7 @@ def main_train(param=None):
               batch_size=args.batch_size,
               epochs=args.epochs,
               shuffle=True,
-              validation_split=0.2,
+              # validation_split=0.2,
               class_weight=custom_class_weights)
 
     logger.info("fix server model")
@@ -223,7 +224,7 @@ def main_train(param=None):
               epochs=args.epochs,
               callbacks=callbacks,
               shuffle=True,
-              validation_split=0.2,
+              # validation_split=0.2,
               class_weight=custom_class_weights)
 
 
@@ -285,16 +286,16 @@ def main_visualization():
     logger.info("plot model")
     model = load_model(args.clf_model, custom_objects=models.get_metrics())
     visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
-
-    logger.info("plot training curve")
-    logs = pd.read_csv(args.train_log)
-    if "acc" in logs.keys():
-        visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
-    elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
-        visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
-        visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
-    else:
-        logger.warning("Error while plotting training curves")
+
+    # logger.info("plot training curve")
+    # logs = pd.read_csv(args.train_log)
+    # if "acc" in logs.keys():
+    #     visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
+    # elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
+    #     visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
+    #     visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
+    # else:
+    #     logger.warning("Error while plotting training curves")
 
     logger.info("plot pr curve")
     visualize.plot_clf()
diff --git a/scripts/make_csv_dataset.py b/scripts/make_csv_dataset.py
index 2af0279..21414f3 100644
--- a/scripts/make_csv_dataset.py
+++ b/scripts/make_csv_dataset.py
@@ -23,4 +23,4 @@
 df.serverLabel = df.serverLabel.astype(np.bool)
 df.virusTotalHits = df.virusTotalHits.astype(np.int8)
 df.trustedHits = df.trustedHits.astype(np.int8)
-df.to_csv("/tmp/rk/{}.csv".format(fn), encoding="utf-8")
+df.to_csv("/tmp/rk/data/{}.csv".format(fn), encoding="utf-8")
diff --git a/visualize.py b/visualize.py
index bd330e6..243ce4b 100644
--- a/visualize.py
+++ b/visualize.py
@@ -58,17 +58,19 @@ def plot_precision_recall(y, y_pred, label=""):
     # ax.step(recall[::-1], decreasing_max_precision, '-r')
     plt.xlabel('Recall')
     plt.ylabel('Precision')
+    plt.ylim([0.0, 1.0])
+    plt.xlim([0.0, 1.0])
 
 
-# def plot_precision_recall_curves(y, y_pred):
-#     y = y.flatten()
-#     y_pred = y_pred.flatten()
-#     precision, recall, thresholds = precision_recall_curve(y, y_pred)
-#
-#     plt.plot(recall, label="Recall")
-#     plt.plot(precision, label="Precision")
-#     plt.xlabel('Threshold')
-#     plt.ylabel('Score')
+def plot_pr_curves(y, y_preds, label=""):
+    y = y.flatten()
+    for idx, y_pred in enumerate(y_preds):
+        y_pred = y_pred.flatten()
+        precision, recall, thresholds = precision_recall_curve(y, y_pred)
+        score = fbeta_score(y, y_pred.round(), 1)
+        plt.plot(recall, precision, '--', label=f"{idx}{label} - {score:5.4}")
+    plt.xlabel('Recall')
+    plt.ylabel('Precision')
 
 
 def score_model(y, prediction):
@@ -91,6 +93,10 @@ def plot_roc_curve(mask, prediction, label=""):
     roc_auc = auc(fpr, tpr)
     plt.xscale('log')
     plt.plot(fpr, tpr, label=f"{label} - {roc_auc:5.4}")
+    plt.ylim([0.0, 1.0])
+    plt.xlim([0.0, 1.0])
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
 
 
 def plot_confusion_matrix(y_true, y_pred, path,
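
Note on the checkpointing change: once validation_split=0.2 is commented out, Keras no longer computes val_loss, so a ModelCheckpoint monitoring 'val_loss' would warn every epoch and never save; monitoring the training 'loss' keeps save_best_only working. Below is a minimal sketch of the resulting setup, not part of the patch itself; the toy model, random data, and file paths are placeholders, not this repository's actual network, dataset, or arguments.

# Sketch only: checkpoint on training loss when no validation split exists.
# Model, data, and paths are stand-ins for args.clf_model / args.train_log.
import numpy as np
from keras.callbacks import CSVLogger, ModelCheckpoint
from keras.layers import Dense
from keras.models import Sequential

model = Sequential([Dense(1, activation="sigmoid", input_dim=10)])
model.compile(optimizer="adam", loss="binary_crossentropy")

callbacks = [
    # monitor='loss' instead of 'val_loss': without validation data,
    # 'val_loss' is never computed and the checkpoint would never fire
    ModelCheckpoint(filepath="/tmp/clf.h5", monitor="loss",
                    verbose=False, save_best_only=True),
    CSVLogger("/tmp/train.log"),
]

x = np.random.rand(256, 10)
y = np.random.randint(0, 2, size=256)
model.fit(x, y,
          batch_size=32,
          epochs=5,
          callbacks=callbacks,
          shuffle=True)  # no validation_split, matching the patch

The trade-off is that the saved model is now selected by fit to the training data rather than by generalization, which is presumably why the old monitor is kept as a comment for easy reverting.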