diff --git a/main.py b/main.py index 733b3f4..76da15f 100644 --- a/main.py +++ b/main.py @@ -241,7 +241,7 @@ def main_visualization(): logger.warning(f"could not generate training curves: {e}") client_pred, server_pred = dataset.load_predictions(args.future_prediction) - client_pred, server_pred = client_pred.value, server_pred.value + client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten() logger.info("plot pr curve") visualize.plot_clf() visualize.plot_precision_recall(client_val, client_pred) @@ -259,7 +259,7 @@ def main_visualization(): df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) - df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_val}) + df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) visualize.plot_clf() @@ -276,10 +276,10 @@ def main_visualization(): # visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1), # "{}/server_cov.png".format(args.model_path), # normalize=False, title="Server Confusion Matrix") - logger.info("visualize embedding") - domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) - domain_embedding = np.load(args.model_path + "/domain_embds.npy") - visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) + # logger.info("visualize embedding") + # domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) + # domain_embedding = np.load(args.model_path + "/domain_embds.npy") + # visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) def main_visualize_all(): @@ -303,6 +303,29 @@ def main_visualize_all(): visualize.plot_legend() visualize.plot_save("all_client_roc.png") + df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) + user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) + + logger.info("plot user pr curves") + visualize.plot_clf() + for model_args in get_model_args(args): + client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) + df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()}) + user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) + visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"]) + visualize.plot_legend() + visualize.plot_save("all_user_client_prc.png") + + logger.info("plot user roc curves") + visualize.plot_clf() + for model_args in get_model_args(args): + client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) + df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()}) + user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) + visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"]) + visualize.plot_legend() + visualize.plot_save("all_user_client_roc.png") + def main_data(): char_dict = dataset.get_character_dict()