# ma_cisco_malware/main.py
#
# 706 lines
# 30 KiB
# Python

import logging
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
from keras.models import Model
from sklearn.metrics import confusion_matrix
import arguments
import dataset
import hyperband
import models
# create logger
import visualize
from arguments import get_model_args
from utils import exists_or_make_path, get_custom_class_weights, load_model
# Application logger: records at DEBUG level and above are written both to the
# console and to the file "info.log", using one common line format.
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Console handler first, file handler second; `ch` stays bound to the file
# handler once the loop is done.
for ch in (logging.StreamHandler(), logging.FileHandler("info.log")):
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
# The command line is parsed once at import time; every entry point below
# reads its settings from `args`.
args = arguments.parse()
if args.gpu:
    # Limit this process to half of the GPU memory and allocate it on demand.
    config = tf.ConfigProto(
        log_device_placement=True,
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5,
                                  allow_growth=True),
    )
    session = tf.Session(config=config)
# Default model parameters. They are used by main_train() when it is called
# without an explicit parameter set and by train_server_only(). Apart from the
# fixed values marked below, everything comes from the command line.
PARAMS = {
    "type": args.model_type,
    "depth": args.model_depth,
    # "batch_size": 64,
    "window_size": args.window,
    "domain_length": args.domain_length,
    # "flow_features" used to be listed twice in this literal; one entry is
    # enough and the resulting dict is the same.
    "flow_features": 3,
    "dropout": 0.5,  # currently fixed
    "domain_features": args.domain_embedding,
    "embedding": args.embedding,
    "filter_embedding": args.filter_embedding,
    "dense_embedding": args.dense_embedding,
    "kernel_embedding": args.kernel_embedding,
    "filter_main": args.filter_main,
    "dense_main": args.dense_main,
    "kernel_main": args.kernel_main,
    "input_length": 40,
    "model_output": args.model_output,
}
def create_model(model, output_type):
    """Wrap the tensors of a network description in a Keras ``Model``.

    Args:
        model: Object exposing the input tensors ``in_domains`` and
            ``in_flows`` and the output tensors ``out_client`` and/or
            ``out_server``.
        output_type: Which outputs the model gets: "both", "client" or
            "server".

    Returns:
        A ``Model`` with the inputs ``[in_domains, in_flows]`` and the
        selected outputs.

    Raises:
        ValueError: If ``output_type`` is none of the supported values. The
            check happens before any attribute of ``model`` is touched.
    """
    if output_type == "both":
        outputs = (model.out_client, model.out_server)
    elif output_type == "client":
        outputs = (model.out_client,)
    elif output_type == "server":
        # The training entry points build server-only labels, so the model
        # factory has to support the server-only output as well.
        outputs = (model.out_server,)
    else:
        raise ValueError("unknown model output")
    return Model(inputs=[model.in_domains, model.in_flows], outputs=outputs)
def main_paul_best():
    """Train with the parameter set ``models.pauls_networks.best_config``."""
    main_train(models.pauls_networks.best_config)
def main_hyperband():
    """Run a Hyperband search over the model hyper-parameters.

    The search space combines settings fixed by the command line (given as
    one-element lists) with candidate lists for the tunable parameters.

    Returns:
        Whatever ``hyperband.Hyperband.run()`` returns.
    """

    def powers_of_two(first_exp, stop_exp):
        # 2**first_exp, ..., 2**(stop_exp - 1)
        return [2 ** exp for exp in range(first_exp, stop_exp)]

    odd_kernels = [1, 3, 5, 7, 9]
    search_space = {
        # static params
        "type": [args.model_type],
        "depth": [args.model_depth],
        "model_output": [args.model_output],
        "batch_size": [args.batch_size],
        "window_size": [args.window],
        "flow_features": [3],
        "domain_length": [args.domain_length],
        'input_length': [40],
        # model params
        "embedding_size": powers_of_two(3, 7),
        "filter_embedding": powers_of_two(1, 10),
        "kernel_embedding": list(odd_kernels),
        "dense_embedding": powers_of_two(4, 10),
        "dropout": [0.5],
        "filter_main": powers_of_two(1, 10),
        "kernel_main": list(odd_kernels),
        "dense_main": powers_of_two(1, 12),
    }

    logger.info("create training dataset")
    loaded = dataset.load_or_generate_h5data(args.data,
                                             args.data,
                                             args.domain_length,
                                             args.window)
    domain_tr, flow_tr, _names, client_tr, server_windows_tr = loaded

    # Default server label: maximum over axis 1 of the per-window labels.
    server_tr = np.max(server_windows_tr, axis=1)
    if args.model_type in ("inter", "staggered"):
        # These model types get the per-window labels with an extra axis.
        server_tr = np.expand_dims(server_windows_tr, 2)

    search = hyperband.Hyperband(search_space,
                                 [domain_tr, flow_tr],
                                 [client_tr, server_tr],
                                 max_iter=81,
                                 savefile=args.hyperband_results)
    return search.run()
def main_train(param=None):
    """Build, compile and train the classifier selected on the command line.

    The best model (by training loss) is checkpointed to ``args.clf_model``
    and the per-epoch metrics are written to ``args.train_log``. For the
    "staggered" model type the model is first trained with only the server
    loss active, the server layers are frozen, and the final training then
    uses only the client loss.

    Args:
        param: Parameter dict handed to ``models.get_models_by_params``. The
            module-level ``PARAMS`` is used when it is ``None`` or empty.

    Raises:
        ValueError: If ``args.model_output`` is not "both", "client" or
            "server" when the labels are assembled.
    """
    logger.info(f"Create model path {args.model_path}")
    exists_or_make_path(args.model_path)
    logger.info(f"Use command line arguments: {args}")

    domain_tr, flow_tr, _, client_tr, server_windows_tr = dataset.load_or_generate_h5data(
        args.data, args.data, args.domain_length, args.window)

    logger.info("define callbacks")
    callbacks = [
        ModelCheckpoint(filepath=args.clf_model,
                        monitor='loss',
                        verbose=False,
                        save_best_only=True),
        CSVLogger(args.train_log),
    ]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        # fit() below gets no validation data, so 'val_loss' never exists and
        # an EarlyStopping on it would never fire. Monitor the training loss
        # instead, exactly like the checkpoint callback does.
        callbacks.append(EarlyStopping(monitor='loss',
                                       patience=5,
                                       verbose=False))
    custom_metrics = models.get_metric_functions()

    # Default server label: maximum over axis 1 of the per-window labels.
    server_tr = np.max(server_windows_tr, axis=1)

    if args.class_weights:
        logger.info("class weights: compute custom weights")
        custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
        logger.info(custom_class_weights)
    else:
        logger.info("class weights: set default")
        custom_class_weights = None

    if not param:
        param = PARAMS
    logger.info(f"Generator model with params: {param}")
    embedding, model, new_model = models.get_models_by_params(param)
    model = create_model(model, args.model_output)
    new_model = create_model(new_model, args.model_output)

    if args.model_type in ("inter", "staggered"):
        # These model types use the second network returned above and the
        # per-window server labels with an extra axis.
        server_tr = np.expand_dims(server_windows_tr, 2)
        model = new_model

    features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
    if args.model_output == "both":
        labels = {"client": client_tr.value, "server": server_tr}
        loss_weights = {"client": 1.0, "server": 1.0}
    elif args.model_output == "client":
        labels = {"client": client_tr.value}
        loss_weights = {"client": 1.0}
    elif args.model_output == "server":
        labels = {"server": server_tr}
        loss_weights = {"server": 1.0}
    else:
        raise ValueError("unknown model output")

    logger.info(f"select model: {args.model_type}")
    if args.model_type == "staggered":
        logger.info("compile and pre-train server model")
        logger.info(model.get_config())
        # Stage 1: only the server loss contributes.
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      loss_weights={"client": 0.0, "server": 1.0},
                      metrics=['accuracy'] + custom_metrics)
        model.summary()
        model.fit(features, labels,
                  batch_size=args.batch_size,
                  epochs=args.epochs,
                  class_weight=custom_class_weights)

        logger.info("fix server model")
        # Freeze the server part; the compile() below makes this effective.
        model.get_layer("domain_cnn").trainable = False
        model.get_layer("domain_cnn").layer.trainable = False
        model.get_layer("dense_server").trainable = False
        model.get_layer("server").trainable = False
        # Stage 2 trains on the client loss only.
        loss_weights = {"client": 1.0, "server": 0.0}

    logger.info("compile and train model")
    embedding.summary()
    logger.info(model.get_config())
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  loss_weights=loss_weights,
                  metrics=['accuracy'] + custom_metrics)
    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks,
              class_weight=custom_class_weights)
def main_retrain():
    """Continue training a stored classifier and save it to a new location.

    The model is loaded from ``<args.model_source>/clf.h5``; the best model
    (by training loss) is checkpointed to ``<args.model_destination>/clf.h5``.
    Training resumes at ``args.initial_epoch``.

    Raises:
        ValueError: If ``args.model_output`` is not "both", "client" or
            "server".
    """
    source = os.path.join(args.model_source, "clf.h5")
    destination = os.path.join(args.model_destination, "clf.h5")
    logger.info(f"Use command line arguments: {args}")
    exists_or_make_path(args.model_destination)

    domain_tr, flow_tr, _, client_tr, server_windows_tr = dataset.load_or_generate_h5data(
        args.data, args.data, args.domain_length, args.window)

    logger.info("define callbacks")
    callbacks = [
        ModelCheckpoint(filepath=destination,
                        monitor='loss',
                        verbose=False,
                        save_best_only=True),
        CSVLogger(args.train_log),
    ]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        # fit() below gets no validation data, so 'val_loss' never exists and
        # an EarlyStopping on it would never fire. Monitor the training loss
        # instead, exactly like the checkpoint callback does.
        callbacks.append(EarlyStopping(monitor='loss',
                                       patience=5,
                                       verbose=False))

    # Default server label: maximum over axis 1 of the per-window labels.
    server_tr = np.max(server_windows_tr, axis=1)

    if args.class_weights:
        logger.info("class weights: compute custom weights")
        custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
        logger.info(custom_class_weights)
    else:
        logger.info("class weights: set default")
        custom_class_weights = None

    logger.info("Load pretrained model")
    embedding, model = load_model(source, custom_objects=models.get_custom_objects())

    if args.model_type in ("inter", "staggered"):
        # These model types use the per-window labels with an extra axis.
        server_tr = np.expand_dims(server_windows_tr, 2)

    features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
    if args.model_output == "both":
        labels = {"client": client_tr.value, "server": server_tr}
    elif args.model_output == "client":
        labels = {"client": client_tr.value}
    elif args.model_output == "server":
        labels = {"server": server_tr}
    else:
        raise ValueError("unknown model output")

    logger.info("re-train model")
    embedding.summary()
    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks,
              class_weight=custom_class_weights,
              initial_epoch=args.initial_epoch)
def main_test():
    """Predict with every configured classifier and store the results.

    For each entry of ``get_model_args(args)`` the classifier and embedding
    models are loaded, the classifier predicts on the raw data set, the
    embedding model predicts on the encoded domains, and everything is saved
    via ``dataset.save_predictions`` under the model's path.
    """
    logger.info("start test: load data")
    raw_data = dataset.load_or_generate_raw_h5data(args.data,
                                                   args.data,
                                                   args.domain_length,
                                                   args.window)
    domain_val, flow_val, _, _, _, _ = raw_data
    domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)

    for model_args in get_model_args(args):
        logger.info(f"process model {model_args['model_path']}")
        embd_model, clf_model = load_model(model_args["clf_model"],
                                           custom_objects=models.get_custom_objects())
        pred = clf_model.predict([domain_val, flow_val],
                                 batch_size=args.batch_size,
                                 verbose=1)

        results = {}
        if args.model_output == "both":
            # Two outputs: client prediction first, server prediction second.
            results["client_pred"], results["server_pred"] = pred
        else:
            # Anything that is not "client" is stored as server prediction.
            single_key = "client_pred" if args.model_output == "client" else "server_pred"
            results[single_key] = pred

        results["domain_embds"] = embd_model.predict(domain_encs,
                                                     batch_size=args.batch_size,
                                                     verbose=1)
        dataset.save_predictions(model_args["model_path"], results)
def main_visualization():
    """Create all plots for the single model stored in ``args.model_path``.

    Produces the network graphs, precision-recall and ROC curves per window
    and per user (each compared with the baseline predictions loaded from
    "results/paul/"), absolute and normalized confusion matrices, and the
    domain embedding plot.
    """

    def plot_model(clf_model, path):
        """Plot the embedding and classifier networks into ``path``."""
        embd, model = load_model(clf_model, custom_objects=models.get_custom_objects())
        visualize.plot_model_as(embd, os.path.join(path, "model_embd.pdf"))
        visualize.plot_model_as(model, os.path.join(path, "model_clf.pdf"))

    def vis(model_name, model_path, df, df_paul, aggregation, curve):
        """Plot one curve type ("prc" or "roc") for the model and the baseline.

        With ``aggregation == "user"`` both frames are reduced to the
        per-name maximum first.
        """
        visualize.plot_clf()
        if aggregation == "user":
            df = df.groupby(df.names).max()
            df_paul = df_paul.groupby(df_paul.names).max()
        # `.values` replaces `.as_matrix()`, which pandas removed in 1.0.
        if curve == "prc":
            visualize.plot_precision_recall(df.client_val.values, df.client_pred.values, model_name)
            visualize.plot_precision_recall(df_paul.client_val.values, df_paul.client_pred.values, "paul")
        elif curve == "roc":
            visualize.plot_roc_curve(df.client_val.values, df.client_pred.values, model_name)
            visualize.plot_roc_curve(df_paul.client_val.values, df_paul.client_pred.values, "paul")
        visualize.plot_legend()
        visualize.plot_save("{}/{}_{}.png".format(model_path, aggregation, curve))

    _, _, name_val, hits_vt, hits_trusted, _ = dataset.load_or_generate_raw_h5data(
        args.data, args.data, args.domain_length, args.window)

    results = dataset.load_predictions(args.model_path)
    df = pd.DataFrame(data={
        "names": name_val, "client_pred": results["client_pred"].flatten(),
        "hits_vt": hits_vt, "hits_trusted": hits_trusted
    })
    # Ground truth: positive when hits_vt is 1.0 or hits_trusted is at least 3.
    df["client_val"] = np.logical_or(df.hits_vt == 1.0, df.hits_trusted >= 3)
    df_user = df.groupby(df.names).max()

    paul = dataset.load_predictions("results/paul/")
    df_paul = pd.DataFrame(data={
        "names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
        "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
    })
    df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)

    logger.info("plot model")
    plot_model(args.clf_model, args.model_path)

    # logger.info("plot training curve")
    # logs = pd.read_csv(args.train_log)
    # if "acc" in logs.keys():
    #     visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
    # elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
    #     visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
    #     visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
    # else:
    #     logger.warning("Error while plotting training curves")

    logger.info("plot window prc")
    vis(args.model_name, args.model_path, df, df_paul, "window", "prc")
    logger.info("plot window roc")
    vis(args.model_name, args.model_path, df, df_paul, "window", "roc")
    logger.info("plot user prc")
    vis(args.model_name, args.model_path, df, df_paul, "user", "prc")
    logger.info("plot user roc")
    vis(args.model_name, args.model_path, df, df_paul, "user", "roc")

    # Confusion matrices on rounded predictions: absolute values first, then
    # normalized; per window ("client") and per user each time.
    for normalize, suffix in ((False, ""), (True, "_norm")):
        visualize.plot_confusion_matrix(df.client_val.values, df.client_pred.values.round(),
                                        "{}/client_cov{}.png".format(args.model_path, suffix),
                                        normalize=normalize, title="Client Confusion Matrix")
        visualize.plot_confusion_matrix(df_user.client_val.values, df_user.client_pred.values.round(),
                                        "{}/user_cov{}.png".format(args.model_path, suffix),
                                        normalize=normalize, title="User Confusion Matrix")

    plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
def plot_embedding(model_path, domain_embedding, data, domain_length):
    """Plot the domain embedding to ``<model_path>/embd_svd.png`` (method "svd").

    Args:
        model_path: Directory that receives the image.
        domain_embedding: Embedding passed on to ``visualize.plot_embedding``.
        data: Data set argument for ``dataset.load_or_generate_domains``.
        domain_length: Domain length argument for the same call.
    """
    logger.info("visualize embedding")
    # Only the labels are needed here; the encoded domains are discarded.
    _encodings, labels = dataset.load_or_generate_domains(data, domain_length)
    target = "{}/embd_svd.png".format(model_path)
    visualize.plot_embedding(domain_embedding, labels, path=target, method="svd")
def main_visualize_all():
    """Plot PR and ROC curves of all configured models into shared figures.

    Four figures are written, each comparing every model from
    ``get_model_args(args)`` with the baseline predictions loaded from
    "results/paul/": precision-recall and ROC, per window and per user
    (per-name maximum). The file names start with ``args.output_prefix``.
    """
    _, _, name_val, hits_vt, hits_trusted, _ = dataset.load_or_generate_raw_h5data(
        args.data, args.data, args.domain_length, args.window)

    def load_df(path):
        """Load the stored client predictions of one model as a data frame."""
        res = dataset.load_predictions(path)
        res = pd.DataFrame(data={
            "names": name_val, "client_pred": res["client_pred"].flatten(),
            "hits_vt": hits_vt, "hits_trusted": hits_trusted
        })
        # Ground truth: positive when hits_vt is 1.0 or hits_trusted is at least 3.
        res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
        return res

    paul = dataset.load_predictions("results/paul/")
    df_paul = pd.DataFrame(data={
        "names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
        "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
    })
    df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
    df_paul_user = df_paul.groupby(df_paul.names).max()

    # (log message, plot function, aggregate per user?, output file suffix)
    figures = [
        ("plot pr curves", visualize.plot_precision_recall, False, "window_client_prc"),
        ("plot roc curves", visualize.plot_roc_curve, False, "window_client_roc"),
        ("plot user pr curves", visualize.plot_precision_recall, True, "user_client_prc"),
        ("plot user roc curves", visualize.plot_roc_curve, True, "user_client_roc"),
    ]
    for message, plot_curve, per_user, suffix in figures:
        logger.info(message)
        visualize.plot_clf()
        for model_args in get_model_args(args):
            df = load_df(model_args["model_path"])
            if per_user:
                df = df.groupby(df.names).max()
            # `.values` replaces `.as_matrix()`, which pandas removed in 1.0.
            plot_curve(df.client_val.values, df.client_pred.values, model_args["model_name"])
        baseline = df_paul_user if per_user else df_paul
        plot_curve(baseline.client_val.values, baseline.client_pred.values, "paul")
        visualize.plot_legend()
        visualize.plot_save(f"{args.output_prefix}_{suffix}.png")
import joblib
def main_beta():
    """Plot mean PR/ROC curves over all configured models and persist them.

    For each of the four views (window/user x PR/ROC) the predictions of all
    models from ``get_model_args(args)`` are collected, the mean curve is
    computed and stored under ``results[<prefix>]["all"]``, and a figure
    comparing it with the baseline from "results/paul/" is written. The
    results are merged into ``curves.joblib`` next to ``args.output_prefix``
    and ``plot_overall_result()`` is called at the end.

    Raises:
        ValueError: If ``get_model_args(args)`` yields no model.
    """
    _, _, name_val, hits_vt, hits_trusted, _ = dataset.load_or_generate_raw_h5data(
        args.data, args.data, args.domain_length, args.window)

    path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
    # os.path.join keeps the file relative when output_prefix has no directory
    # part; formatting "{path}/..." would point at the file system root then.
    curves_file = os.path.join(path, "curves.joblib")
    try:
        results = joblib.load(curves_file)
    except FileNotFoundError:
        # Nothing stored yet, start fresh.
        results = {}
    except Exception as err:
        logger.warning(f"could not load {curves_file} ({err}), start with empty results")
        results = {}
    results[model_prefix] = {"all": {}}

    def load_df(model_path):
        """Load the stored client predictions of one model as a data frame."""
        res = dataset.load_predictions(model_path)
        res = pd.DataFrame(data={
            "names": name_val, "client_pred": res["client_pred"].flatten(),
            "hits_vt": hits_vt, "hits_trusted": hits_trusted
        })
        # Ground truth: positive when hits_vt is 1.0 or hits_trusted is at least 3.
        res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
        return res

    paul = dataset.load_predictions("results/paul/")
    df_paul = pd.DataFrame(data={
        "names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
        "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
    })
    df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
    df_paul_user = df_paul.groupby(df_paul.names).max()

    # (log message, result key, aggregate per user?, mean calculation,
    #  mean plot, output file suffix)
    views = [
        ("plot pr curves", "window_prc", False,
         visualize.calc_pr_mean, visualize.plot_pr_mean, "window_client_prc_all"),
        ("plot roc curves", "window_roc", False,
         visualize.calc_roc_mean, visualize.plot_roc_mean, "window_client_roc_all"),
        ("plot user pr curves", "user_prc", True,
         visualize.calc_pr_mean, visualize.plot_pr_mean, "user_client_prc_all"),
        ("plot user roc curves", "user_roc", True,
         visualize.calc_roc_mean, visualize.plot_roc_mean, "user_client_roc_all"),
    ]
    for message, curve_key, per_user, calc_mean, plot_mean, suffix in views:
        logger.info(message)
        visualize.plot_clf()
        predictions = []
        df = None
        for model_args in get_model_args(args):
            df = load_df(model_args["model_path"])
            if per_user:
                df = df.groupby(df.names).max()
            # `.values` replaces `.as_matrix()`, which pandas removed in 1.0.
            predictions.append(df.client_pred.values)
            # The per-model confusion matrix is overwritten in every view, so
            # the persisted one is the user-level matrix (as before).
            # NOTE(review): confirm that user level is the intended one.
            results[model_prefix][model_args["model_name"]] = confusion_matrix(
                df.client_val.values, df.client_pred.values.round())
        if df is None:
            raise ValueError("no models to evaluate")

        # All models share the same ground truth; take it from the last frame.
        truth = df.client_val.values
        results[model_prefix]["all"][curve_key] = calc_mean(truth, predictions)
        plot_mean(truth, predictions, "mean")
        baseline = df_paul_user if per_user else df_paul
        plot_mean(baseline.client_val.values, [baseline.client_pred.values], "paul")
        visualize.plot_legend()
        visualize.plot_save(f"{args.output_prefix}_{suffix}.png")

    joblib.dump(results, curves_file)
    plot_overall_result()
def plot_overall_result():
    """Plot the stored mean curves of all model groups plus their error bars.

    Reads ``curves.joblib`` from the directory part of ``args.output_prefix``
    and writes, into the same directory, one figure per curve type with a
    line and a +/- std band per stored group, and one error-bar figure per
    group.
    """
    path, _ = os.path.split(os.path.normpath(args.output_prefix))
    # os.path.join keeps all files relative when output_prefix has no
    # directory part; formatting "{path}/..." would point at the root then.
    curves_file = os.path.join(path, "curves.joblib")
    try:
        results = joblib.load(curves_file)
    except Exception as err:
        # Best effort: without stored curves only empty figures are produced.
        logger.warning(f"could not load {curves_file} ({err}), nothing to plot")
        results = {}

    import matplotlib.pyplot as plt

    # The stored curves are plotted against 10000 evenly spaced points on [0, 1].
    x = np.linspace(0, 1, 10000)
    for vis in ["window_prc", "window_roc", "user_prc", "user_roc"]:
        logger.info(f"plot {vis}")
        visualize.plot_clf()
        for model_key in results:
            ys_mean, ys_std, score = results[model_key]["all"][vis]
            plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
            plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
        if vis.endswith("prc"):
            plt.xlabel('Recall')
            plt.ylabel('Precision')
        else:
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
        plt.ylim([0.0, 1.0])
        plt.xlim([0.0, 1.0])
        visualize.plot_legend()
        visualize.plot_save(os.path.join(path, f"{vis}_all.png"))

    # `group_results` instead of `models`: do not shadow the imported module.
    for cat, group_results in results.items():
        visualize.plot_clf()
        visualize.plot_error_bars(group_results)
        visualize.plot_legend()
        visualize.plot_save(os.path.join(path, f"error_bars_{cat}.png"))
def train_server_only():
    """Train the stand-alone server model on flattened data.

    Domains are reshaped to rows of 40 values, flows to rows of 3 values and
    the per-window server labels to a flat vector. The model comes from
    ``models.get_server_model_by_params`` with the default ``PARAMS`` and is
    trained with a validation split of 0.2.

    Raises:
        ValueError: If ``args.model_output`` is not "both", "client" or
            "server".
    """
    logger.info(f"Create model path {args.model_path}")
    exists_or_make_path(args.model_path)
    logger.info(f"Use command line arguments: {args}")

    loaded = dataset.load_or_generate_h5data(args.data,
                                             args.data,
                                             args.domain_length,
                                             args.window)
    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = loaded
    domain_tr = domain_tr.value.reshape(-1, 40)
    flow_tr = flow_tr.value.reshape(-1, 3)
    server_tr = server_windows_tr.value.reshape(-1)

    logger.info("define callbacks")
    checkpoint = ModelCheckpoint(filepath=args.clf_model,
                                 monitor='loss',
                                 verbose=False,
                                 save_best_only=True)
    callbacks = [checkpoint, CSVLogger(args.train_log)]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        callbacks.append(EarlyStopping(monitor='val_loss',
                                       patience=5,
                                       verbose=False))

    custom_metrics = models.get_metric_functions()
    model = models.get_server_model_by_params(params=PARAMS)

    features = {"ipt_domains": domain_tr, "ipt_flows": flow_tr}
    # Label sets per output selection; the client labels are passed through
    # as loaded.
    label_sets = {
        "both": {"client": client_tr, "server": server_tr},
        "client": {"client": client_tr},
        "server": {"server": server_tr},
    }
    if args.model_output not in label_sets:
        raise ValueError("unknown model output")
    labels = label_sets[args.model_output]

    logger.info("compile and train model")
    logger.info(model.get_config())
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'] + custom_metrics)
    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              validation_split=0.2,
              callbacks=callbacks)
def test_server_only():
    """Predict with every configured server model and store the results.

    The raw domains and flows are flattened to rows of 40 and 3 values. For
    each entry of ``get_model_args(args)`` the server prediction and the
    domain embeddings are saved via ``dataset.save_predictions``.
    """
    logger.info("start test: load data")
    raw_data = dataset.load_or_generate_raw_h5data(args.data,
                                                   args.data,
                                                   args.domain_length,
                                                   args.window)
    domain_val, flow_val, _, _, _, _ = raw_data
    domain_val = domain_val.value.reshape(-1, 40)
    flow_val = flow_val.value.reshape(-1, 3)
    domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)

    for model_args in get_model_args(args):
        logger.info(f"process model {model_args['model_path']}")
        embd_model, clf_model = load_model(model_args["clf_model"],
                                           custom_objects=models.get_custom_objects())
        server_pred = clf_model.predict([domain_val, flow_val],
                                        batch_size=args.batch_size,
                                        verbose=1)
        domain_embeddings = embd_model.predict(domain_encs,
                                               batch_size=args.batch_size,
                                               verbose=1)
        results = {"server_pred": server_pred, "domain_embds": domain_embeddings}
        dataset.save_predictions(model_args["model_path"], results)
def main():
    """Run the entry point that belongs to ``args.mode``.

    A mode without an entry in the table below does nothing.
    """
    handlers = {
        "train": main_train,
        "retrain": main_retrain,
        "hyperband": main_hyperband,
        "test": main_test,
        "fancy": main_visualization,
        "all_fancy": main_visualize_all,
        "beta": main_beta,
        "all_beta": plot_overall_result,
        "server": train_server_only,
        "server_test": test_server_only,
    }
    handler = handlers.get(args.mode)
    if handler is not None:
        handler()


# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()