visualize per user stats

This commit is contained in:
René Knaebel 2017-09-04 13:37:26 +02:00
parent 3f6779fa3d
commit 1da31cc97c
1 changed files with 29 additions and 6 deletions

35
main.py
View File

@ -241,7 +241,7 @@ def main_visualization():
logger.warning(f"could not generate training curves: {e}")
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
client_pred, server_pred = client_pred.value, server_pred.value
client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
logger.info("plot pr curve")
visualize.plot_clf()
visualize.plot_precision_recall(client_val, client_pred)
@ -259,7 +259,7 @@ def main_visualization():
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_val})
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_clf()
@ -276,10 +276,10 @@ def main_visualization():
# visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
# "{}/server_cov.png".format(args.model_path),
# normalize=False, title="Server Confusion Matrix")
logger.info("visualize embedding")
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_embedding = np.load(args.model_path + "/domain_embds.npy")
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
# logger.info("visualize embedding")
# domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
# domain_embedding = np.load(args.model_path + "/domain_embds.npy")
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
def main_visualize_all():
@ -303,6 +303,29 @@ def main_visualize_all():
visualize.plot_legend()
visualize.plot_save("all_client_roc.png")
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
logger.info("plot user pr curves")
visualize.plot_clf()
for model_args in get_model_args(args):
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save("all_user_client_prc.png")
logger.info("plot user roc curves")
visualize.plot_clf()
for model_args in get_model_args(args):
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save("all_user_client_roc.png")
def main_data():
char_dict = dataset.get_character_dict()