visualize per user stats
This commit is contained in:
parent
3f6779fa3d
commit
1da31cc97c
35
main.py
35
main.py
@ -241,7 +241,7 @@ def main_visualization():
|
||||
logger.warning(f"could not generate training curves: {e}")
|
||||
|
||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred = client_pred.value, server_pred.value
|
||||
client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(client_val, client_pred)
|
||||
@ -259,7 +259,7 @@ def main_visualization():
|
||||
|
||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
|
||||
visualize.plot_clf()
|
||||
@ -276,10 +276,10 @@ def main_visualization():
|
||||
# visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
|
||||
# "{}/server_cov.png".format(args.model_path),
|
||||
# normalize=False, title="Server Confusion Matrix")
|
||||
logger.info("visualize embedding")
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
domain_embedding = np.load(args.model_path + "/domain_embds.npy")
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||
# logger.info("visualize embedding")
|
||||
# domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
# domain_embedding = np.load(args.model_path + "/domain_embds.npy")
|
||||
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||
|
||||
|
||||
def main_visualize_all():
|
||||
@ -303,6 +303,29 @@ def main_visualize_all():
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_client_roc.png")
|
||||
|
||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||
|
||||
logger.info("plot user pr curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_user_client_prc.png")
|
||||
|
||||
logger.info("plot user roc curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_user_client_roc.png")
|
||||
|
||||
|
||||
def main_data():
|
||||
char_dict = dataset.get_character_dict()
|
||||
|
Loading…
Reference in New Issue
Block a user