add some print lines for better following the process structure

This commit is contained in:
René Knaebel 2017-11-10 11:38:57 +01:00
parent 27f4d086eb
commit c19d649bc4
1 changed files with 6 additions and 13 deletions

19
main.py
View File

@ -16,7 +16,6 @@ import models
# create logger
import visualize
from arguments import get_model_args
from server import test_server_only, train_server_only
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
logger = logging.getLogger('cisco_logger')
@ -648,7 +647,8 @@ def main_beta():
return res, df_server
res = dataset.load_predictions(path)
logger.info(f"load results from {args.model_path}")
res = dataset.load_predictions(args.model_path)
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
client_preds = []
@ -659,11 +659,13 @@ def main_beta():
server_domain_preds = []
server_domain_avg_preds = []
for model_name in model_keys:
logger.info(f"load model {model_name}")
df, df_server = load_df(res[model_name])
client_preds.append(df.client_pred.as_matrix())
if "server_val" in df.columns:
server_preds.append(df.server_pred.as_matrix())
if df_server is not None:
logger.info(f" group servers")
server_flow_preds.append(df_server.server_pred.as_matrix())
df_domain = df_server.groupby(df_server.domain).max()
server_domain_preds.append(df_domain.server_pred.as_matrix())
@ -672,6 +674,7 @@ def main_beta():
curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
df.client_pred.as_matrix().round())
logger.info(f" group users")
df_user = df.groupby(df.names).max()
client_user_preds.append(df_user.client_pred.as_matrix())
if "server_val" in df.columns:
@ -840,25 +843,15 @@ def main():
main_retrain()
if "hyperband" == args.mode:
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
arg.hyper_max_iter)
args.hyper_max_iter)
if "test" == args.mode:
main_test()
if "fancy" == args.mode:
main_visualization()
if "all_fancy" == args.mode:
main_visualize_all()
if "beta" == args.mode:
main_beta()
if "all_beta" == args.mode:
plot_overall_result()
if "server" == args.mode:
train_server_only()
if "server_test" == args.mode:
test_server_only()
if "embedding" == args.mode:
main_visualize_all_embds()
if "stats" == args.mode:
main_stats()
if __name__ == "__main__":