remove model selection based on validation loss
This commit is contained in:
parent
b0e0cd904e
commit
ec5a1101be
27
main.py
27
main.py
@ -135,7 +135,8 @@ def main_train(param=None):
|
||||
logger.info("define callbacks")
|
||||
callbacks = []
|
||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||
monitor='val_loss',
|
||||
monitor='loss',
|
||||
# monitor='val_loss',
|
||||
verbose=False,
|
||||
save_best_only=True))
|
||||
callbacks.append(CSVLogger(args.train_log))
|
||||
@ -199,7 +200,7 @@ def main_train(param=None):
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
shuffle=True,
|
||||
validation_split=0.2,
|
||||
# validation_split=0.2,
|
||||
class_weight=custom_class_weights)
|
||||
|
||||
logger.info("fix server model")
|
||||
@ -223,7 +224,7 @@ def main_train(param=None):
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
shuffle=True,
|
||||
validation_split=0.2,
|
||||
# validation_split=0.2,
|
||||
class_weight=custom_class_weights)
|
||||
|
||||
|
||||
@ -285,16 +286,16 @@ def main_visualization():
|
||||
logger.info("plot model")
|
||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||
|
||||
logger.info("plot training curve")
|
||||
logs = pd.read_csv(args.train_log)
|
||||
if "acc" in logs.keys():
|
||||
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
||||
elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
|
||||
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
||||
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
||||
else:
|
||||
logger.warning("Error while plotting training curves")
|
||||
|
||||
# logger.info("plot training curve")
|
||||
# logs = pd.read_csv(args.train_log)
|
||||
# if "acc" in logs.keys():
|
||||
# visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
||||
# elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
|
||||
# visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
||||
# visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
||||
# else:
|
||||
# logger.warning("Error while plotting training curves")
|
||||
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_clf()
|
||||
|
@ -23,4 +23,4 @@ df.serverLabel = df.serverLabel.astype(np.bool)
|
||||
df.virusTotalHits = df.virusTotalHits.astype(np.int8)
|
||||
df.trustedHits = df.trustedHits.astype(np.int8)
|
||||
|
||||
df.to_csv("/tmp/rk/{}.csv".format(fn), encoding="utf-8")
|
||||
df.to_csv("/tmp/rk/data/{}.csv".format(fn), encoding="utf-8")
|
||||
|
24
visualize.py
24
visualize.py
@ -58,17 +58,19 @@ def plot_precision_recall(y, y_pred, label=""):
|
||||
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
plt.ylim([0.0, 1.0])
|
||||
plt.xlim([0.0, 1.0])
|
||||
|
||||
|
||||
# def plot_precision_recall_curves(y, y_pred):
|
||||
# y = y.flatten()
|
||||
# y_pred = y_pred.flatten()
|
||||
# precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||
#
|
||||
# plt.plot(recall, label="Recall")
|
||||
# plt.plot(precision, label="Precision")
|
||||
# plt.xlabel('Threshold')
|
||||
# plt.ylabel('Score')
|
||||
def plot_pr_curves(y, y_preds, label=""):
|
||||
for idx, y in enumerate(y_preds):
|
||||
y = y.flatten()
|
||||
y_pred = y_pred.flatten()
|
||||
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||
score = fbeta_score(y, y_pred.round(), 1)
|
||||
plt.plot(recall, precision, '--', label=f"{idx}{label} - {score:5.4}")
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
|
||||
|
||||
def score_model(y, prediction):
|
||||
@ -91,6 +93,10 @@ def plot_roc_curve(mask, prediction, label=""):
|
||||
roc_auc = auc(fpr, tpr)
|
||||
plt.xscale('log')
|
||||
plt.plot(fpr, tpr, label=f"{label} - {roc_auc:5.4}")
|
||||
plt.ylim([0.0, 1.0])
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred, path,
|
||||
|
Loading…
Reference in New Issue
Block a user