remove model selection based on validation loss

This commit is contained in:
René Knaebel 2017-09-16 15:25:34 +02:00
parent b0e0cd904e
commit ec5a1101be
3 changed files with 30 additions and 23 deletions

27
main.py
View File

@ -135,7 +135,8 @@ def main_train(param=None):
logger.info("define callbacks")
callbacks = []
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
monitor='val_loss',
monitor='loss',
# monitor='val_loss',
verbose=False,
save_best_only=True))
callbacks.append(CSVLogger(args.train_log))
@ -199,7 +200,7 @@ def main_train(param=None):
batch_size=args.batch_size,
epochs=args.epochs,
shuffle=True,
validation_split=0.2,
# validation_split=0.2,
class_weight=custom_class_weights)
logger.info("fix server model")
@ -223,7 +224,7 @@ def main_train(param=None):
epochs=args.epochs,
callbacks=callbacks,
shuffle=True,
validation_split=0.2,
# validation_split=0.2,
class_weight=custom_class_weights)
@ -285,16 +286,16 @@ def main_visualization():
logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics())
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
logger.info("plot training curve")
logs = pd.read_csv(args.train_log)
if "acc" in logs.keys():
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
else:
logger.warning("Error while plotting training curves")
# logger.info("plot training curve")
# logs = pd.read_csv(args.train_log)
# if "acc" in logs.keys():
# visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
# elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
# visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
# visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
# else:
# logger.warning("Error while plotting training curves")
logger.info("plot pr curve")
visualize.plot_clf()

View File

@ -23,4 +23,4 @@ df.serverLabel = df.serverLabel.astype(np.bool)
df.virusTotalHits = df.virusTotalHits.astype(np.int8)
df.trustedHits = df.trustedHits.astype(np.int8)
df.to_csv("/tmp/rk/{}.csv".format(fn), encoding="utf-8")
df.to_csv("/tmp/rk/data/{}.csv".format(fn), encoding="utf-8")

View File

@ -58,17 +58,19 @@ def plot_precision_recall(y, y_pred, label=""):
# ax.step(recall[::-1], decreasing_max_precision, '-r')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
# def plot_precision_recall_curves(y, y_pred):
# y = y.flatten()
# y_pred = y_pred.flatten()
# precision, recall, thresholds = precision_recall_curve(y, y_pred)
#
# plt.plot(recall, label="Recall")
# plt.plot(precision, label="Precision")
# plt.xlabel('Threshold')
# plt.ylabel('Score')
def plot_pr_curves(y, y_preds, label=""):
    """Draw one precision-recall curve per prediction array via ``plt``.

    Every entry of ``y_preds`` is compared against the same ground truth
    ``y``. Each curve's legend label is the entry's index followed by
    ``label`` and the F1 score of the rounded predictions.

    Args:
        y: ground-truth labels; must provide ``.flatten()``
            (presumably a numpy array -- confirm against callers).
        y_preds: iterable of prediction arrays, each providing
            ``.flatten()`` and ``.round()``.
        label: suffix appended to the curve index in the legend label.
    """
    # Flatten the ground truth once; it is shared by every curve.
    y_true = y.flatten()
    # The loop variable must not be called ``y``: that shadowed the ground
    # truth and left ``y_pred`` unbound, so the first iteration crashed.
    for idx, y_pred in enumerate(y_preds):
        y_pred = y_pred.flatten()
        precision, recall, _ = precision_recall_curve(y_true, y_pred)
        # ``beta`` is passed by keyword; recent scikit-learn releases reject
        # it as a positional argument.
        score = fbeta_score(y_true, y_pred.round(), beta=1)
        plt.plot(recall, precision, '--', label=f"{idx}{label} - {score:5.4}")
    # Axis labels are the same for every curve, so set them once.
    plt.xlabel('Recall')
    plt.ylabel('Precision')
def score_model(y, prediction):
@ -91,6 +93,10 @@ def plot_roc_curve(mask, prediction, label=""):
roc_auc = auc(fpr, tpr)
plt.xscale('log')
plt.plot(fpr, tpr, label=f"{label} - {roc_auc:5.4}")
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
def plot_confusion_matrix(y_true, y_pred, path,