diff --git a/main.py b/main.py index bf6830c..f561b16 100644 --- a/main.py +++ b/main.py @@ -16,7 +16,6 @@ import models # create logger import visualize from arguments import get_model_args -from server import test_server_only, train_server_only from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model logger = logging.getLogger('cisco_logger') @@ -648,7 +647,8 @@ def main_beta(): return res, df_server - res = dataset.load_predictions(path) + logger.info(f"load results from {args.model_path}") + res = dataset.load_predictions(args.model_path) model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3])) client_preds = [] @@ -659,11 +659,13 @@ def main_beta(): server_domain_preds = [] server_domain_avg_preds = [] for model_name in model_keys: + logger.info(f"load model {model_name}") df, df_server = load_df(res[model_name]) client_preds.append(df.client_pred.as_matrix()) if "server_val" in df.columns: server_preds.append(df.server_pred.as_matrix()) if df_server is not None: + logger.info(f" group servers") server_flow_preds.append(df_server.server_pred.as_matrix()) df_domain = df_server.groupby(df_server.domain).max() server_domain_preds.append(df_domain.server_pred.as_matrix()) @@ -672,6 +674,7 @@ def main_beta(): curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round()) + logger.info(f" group users") df_user = df.groupby(df.names).max() client_user_preds.append(df_user.client_pred.as_matrix()) if "server_val" in df.columns: @@ -840,25 +843,15 @@ def main(): main_retrain() if "hyperband" == args.mode: main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results, - arg.hyper_max_iter) + args.hyper_max_iter) if "test" == args.mode: main_test() - if "fancy" == args.mode: - main_visualization() - if "all_fancy" == args.mode: - main_visualize_all() if "beta" == args.mode: main_beta() if "all_beta" == args.mode: plot_overall_result() - if "server" == args.mode: - train_server_only() - if "server_test" == args.mode: - test_server_only() if "embedding" == args.mode: main_visualize_all_embds() - if "stats" == args.mode: - main_stats() if __name__ == "__main__":