add some print lines for better following the process structure
This commit is contained in:
parent
27f4d086eb
commit
c19d649bc4
19
main.py
19
main.py
@ -16,7 +16,6 @@ import models
|
|||||||
# create logger
|
# create logger
|
||||||
import visualize
|
import visualize
|
||||||
from arguments import get_model_args
|
from arguments import get_model_args
|
||||||
from server import test_server_only, train_server_only
|
|
||||||
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
|
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
|
||||||
|
|
||||||
logger = logging.getLogger('cisco_logger')
|
logger = logging.getLogger('cisco_logger')
|
||||||
@ -648,7 +647,8 @@ def main_beta():
|
|||||||
|
|
||||||
return res, df_server
|
return res, df_server
|
||||||
|
|
||||||
res = dataset.load_predictions(path)
|
logger.info(f"load results from {args.model_path}")
|
||||||
|
res = dataset.load_predictions(args.model_path)
|
||||||
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
|
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
|
||||||
|
|
||||||
client_preds = []
|
client_preds = []
|
||||||
@ -659,11 +659,13 @@ def main_beta():
|
|||||||
server_domain_preds = []
|
server_domain_preds = []
|
||||||
server_domain_avg_preds = []
|
server_domain_avg_preds = []
|
||||||
for model_name in model_keys:
|
for model_name in model_keys:
|
||||||
|
logger.info(f"load model {model_name}")
|
||||||
df, df_server = load_df(res[model_name])
|
df, df_server = load_df(res[model_name])
|
||||||
client_preds.append(df.client_pred.as_matrix())
|
client_preds.append(df.client_pred.as_matrix())
|
||||||
if "server_val" in df.columns:
|
if "server_val" in df.columns:
|
||||||
server_preds.append(df.server_pred.as_matrix())
|
server_preds.append(df.server_pred.as_matrix())
|
||||||
if df_server is not None:
|
if df_server is not None:
|
||||||
|
logger.info(f" group servers")
|
||||||
server_flow_preds.append(df_server.server_pred.as_matrix())
|
server_flow_preds.append(df_server.server_pred.as_matrix())
|
||||||
df_domain = df_server.groupby(df_server.domain).max()
|
df_domain = df_server.groupby(df_server.domain).max()
|
||||||
server_domain_preds.append(df_domain.server_pred.as_matrix())
|
server_domain_preds.append(df_domain.server_pred.as_matrix())
|
||||||
@ -672,6 +674,7 @@ def main_beta():
|
|||||||
|
|
||||||
curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
|
curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
|
||||||
df.client_pred.as_matrix().round())
|
df.client_pred.as_matrix().round())
|
||||||
|
logger.info(f" group users")
|
||||||
df_user = df.groupby(df.names).max()
|
df_user = df.groupby(df.names).max()
|
||||||
client_user_preds.append(df_user.client_pred.as_matrix())
|
client_user_preds.append(df_user.client_pred.as_matrix())
|
||||||
if "server_val" in df.columns:
|
if "server_val" in df.columns:
|
||||||
@ -840,25 +843,15 @@ def main():
|
|||||||
main_retrain()
|
main_retrain()
|
||||||
if "hyperband" == args.mode:
|
if "hyperband" == args.mode:
|
||||||
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
|
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
|
||||||
arg.hyper_max_iter)
|
args.hyper_max_iter)
|
||||||
if "test" == args.mode:
|
if "test" == args.mode:
|
||||||
main_test()
|
main_test()
|
||||||
if "fancy" == args.mode:
|
|
||||||
main_visualization()
|
|
||||||
if "all_fancy" == args.mode:
|
|
||||||
main_visualize_all()
|
|
||||||
if "beta" == args.mode:
|
if "beta" == args.mode:
|
||||||
main_beta()
|
main_beta()
|
||||||
if "all_beta" == args.mode:
|
if "all_beta" == args.mode:
|
||||||
plot_overall_result()
|
plot_overall_result()
|
||||||
if "server" == args.mode:
|
|
||||||
train_server_only()
|
|
||||||
if "server_test" == args.mode:
|
|
||||||
test_server_only()
|
|
||||||
if "embedding" == args.mode:
|
if "embedding" == args.mode:
|
||||||
main_visualize_all_embds()
|
main_visualize_all_embds()
|
||||||
if "stats" == args.mode:
|
|
||||||
main_stats()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user