remove model selection based on validation loss

René Knaebel 2017-09-16 15:25:34 +02:00
parent b0e0cd904e
commit ec5a1101be
3 changed files with 30 additions and 23 deletions

main.py

@@ -135,7 +135,8 @@ def main_train(param=None):
     logger.info("define callbacks")
     callbacks = []
     callbacks.append(ModelCheckpoint(filepath=args.clf_model,
-                                     monitor='val_loss',
+                                     monitor='loss',
+                                     # monitor='val_loss',
                                      verbose=False,
                                      save_best_only=True))
     callbacks.append(CSVLogger(args.train_log))
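Note: the fit hunks below drop validation_split=0.2, so Keras no longer produces a val_loss metric and a checkpoint monitoring it would never save. Switching the monitor to the training loss keeps save_best_only meaningful. A minimal sketch of the resulting callback setup, with placeholder paths standing in for args.clf_model and args.train_log:

    from keras.callbacks import ModelCheckpoint, CSVLogger

    callbacks = []
    callbacks.append(ModelCheckpoint(filepath="clf.h5",   # placeholder for args.clf_model
                                     monitor='loss',      # no val_loss without a validation split
                                     verbose=False,
                                     save_best_only=True))
    callbacks.append(CSVLogger("train.log"))              # placeholder for args.train_log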
@@ -199,7 +200,7 @@ def main_train(param=None):
                   batch_size=args.batch_size,
                   epochs=args.epochs,
                   shuffle=True,
-                  validation_split=0.2,
+                  # validation_split=0.2,
                   class_weight=custom_class_weights)

     logger.info("fix server model")
@@ -223,7 +224,7 @@ def main_train(param=None):
                   epochs=args.epochs,
                   callbacks=callbacks,
                   shuffle=True,
-                  validation_split=0.2,
+                  # validation_split=0.2,
                   class_weight=custom_class_weights)
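Both fit calls keep class_weight=custom_class_weights; how those weights are computed is outside this diff. For reference, a common way to build such a dict, assuming scikit-learn's compute_class_weight and a hypothetical y_train:

    import numpy as np
    from sklearn.utils.class_weight import compute_class_weight

    y_train = np.array([0, 0, 0, 1, 0, 1, 0, 0])          # hypothetical binary labels
    classes = np.unique(y_train)
    weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)
    custom_class_weights = dict(zip(classes, weights))    # Keras expects {class_index: weight}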
@@ -285,16 +286,16 @@ def main_visualization():
     logger.info("plot model")
     model = load_model(args.clf_model, custom_objects=models.get_metrics())
     visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))

-    logger.info("plot training curve")
-    logs = pd.read_csv(args.train_log)
-    if "acc" in logs.keys():
-        visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
-    elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
-        visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
-        visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
-    else:
-        logger.warning("Error while plotting training curves")
+    # logger.info("plot training curve")
+    # logs = pd.read_csv(args.train_log)
+    # if "acc" in logs.keys():
+    #     visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
+    # elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
+    #     visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
+    #     visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
+    # else:
+    #     logger.warning("Error while plotting training curves")

     logger.info("plot pr curve")
     visualize.plot_clf()
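The training-curve block is commented out rather than deleted: without the validation split, the CSV written by CSVLogger carries only training-side columns, so the old accuracy-based branching no longer fits. If revived, a loss-only sketch could look like this (the log path is a placeholder for args.train_log, and the column names are assumptions about the CSVLogger output):

    import pandas as pd
    import matplotlib.pyplot as plt

    logs = pd.read_csv("train.log")                # placeholder for args.train_log
    plt.plot(logs["loss"], label="train loss")     # 'loss' is always logged; 'val_loss' is gone
    if "acc" in logs.columns:
        plt.plot(logs["acc"], label="train acc")
    plt.xlabel("epoch")
    plt.legend()
    plt.savefig("client_train.png")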


@@ -23,4 +23,4 @@ df.serverLabel = df.serverLabel.astype(np.bool)
 df.virusTotalHits = df.virusTotalHits.astype(np.int8)
 df.trustedHits = df.trustedHits.astype(np.int8)
-df.to_csv("/tmp/rk/{}.csv".format(fn), encoding="utf-8")
+df.to_csv("/tmp/rk/data/{}.csv".format(fn), encoding="utf-8")
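A caveat for rerunning this script on a current stack: np.bool (used in the context line above) was deprecated in NumPy 1.20 and removed in 1.24, so the cast fails on modern NumPy; the built-in bool is the drop-in replacement:

    df.serverLabel = df.serverLabel.astype(bool)   # np.bool no longer exists in NumPy >= 1.24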

visualize.py

@@ -58,17 +58,19 @@ def plot_precision_recall(y, y_pred, label=""):
     # ax.step(recall[::-1], decreasing_max_precision, '-r')
     plt.xlabel('Recall')
     plt.ylabel('Precision')
+    plt.ylim([0.0, 1.0])
+    plt.xlim([0.0, 1.0])

-# def plot_precision_recall_curves(y, y_pred):
-#     y = y.flatten()
-#     y_pred = y_pred.flatten()
-#     precision, recall, thresholds = precision_recall_curve(y, y_pred)
-#
-#     plt.plot(recall, label="Recall")
-#     plt.plot(precision, label="Precision")
-#     plt.xlabel('Threshold')
-#     plt.ylabel('Score')
+def plot_pr_curves(y, y_preds, label=""):
+    for idx, y in enumerate(y_preds):
+        y = y.flatten()
+        y_pred = y_pred.flatten()
+        precision, recall, thresholds = precision_recall_curve(y, y_pred)
+        score = fbeta_score(y, y_pred.round(), 1)
+        plt.plot(recall, precision, '--', label=f"{idx}{label} - {score:5.4}")
+    plt.xlabel('Recall')
+    plt.ylabel('Precision')

 def score_model(y, prediction):
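As committed, plot_pr_curves shadows its y argument with the loop variable and reads y_pred before it is ever assigned, so the first call raises a NameError. A corrected sketch, assuming the loop was meant to run over the predictions:

    import matplotlib.pyplot as plt
    from sklearn.metrics import precision_recall_curve, fbeta_score

    def plot_pr_curves(y, y_preds, label=""):
        y = y.flatten()                            # ground truth, flattened once
        for idx, y_pred in enumerate(y_preds):     # iterate predictions instead of rebinding y
            y_pred = y_pred.flatten()
            precision, recall, _ = precision_recall_curve(y, y_pred)
            score = fbeta_score(y, y_pred.round(), beta=1)
            plt.plot(recall, precision, '--', label=f"{idx}{label} - {score:5.4}")
        plt.xlabel('Recall')
        plt.ylabel('Precision')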
@@ -91,6 +93,10 @@ def plot_roc_curve(mask, prediction, label=""):
     roc_auc = auc(fpr, tpr)
     plt.xscale('log')
     plt.plot(fpr, tpr, label=f"{label} - {roc_auc:5.4}")
+    plt.ylim([0.0, 1.0])
+    plt.xlim([0.0, 1.0])
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')

 def plot_confusion_matrix(y_true, y_pred, path,
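Note on the added limits: plot_roc_curve sets plt.xscale('log'), and a log axis cannot start at 0.0 (matplotlib drops the nonpositive limit), so the effective lower bound is the smallest plotted FPR; a small positive xlim is the usual workaround. A self-contained usage sketch with made-up labels and scores:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, auc

    y_true = np.array([0, 0, 1, 1, 0, 1])                 # hypothetical labels
    y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9])   # hypothetical scores

    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    plt.xscale('log')
    plt.plot(fpr, tpr, label=f"roc - {roc_auc:5.4}")
    plt.xlim([1e-2, 1.0])                                 # 0.0 is invalid on a log axis
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend()
    plt.show()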