load names with data for per-user evaluation

parent 0db8427457
commit 3f6779fa3d

dataset.py (64 lines changed)
@@ -106,42 +106,42 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
             results.append(pool.apply_async(get_user_chunks, (user_flow, window_size)))
         windows = [window for res in results for window in res.get()]
     logger.info("create training dataset")
-    domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(chunks=windows,
-                                                                                           vocab=char_dict,
-                                                                                           max_len=max_len)
+    domain, flow, hits, names, server, trusted_hits = create_dataset_from_lists(chunks=windows,
+                                                                                vocab=char_dict,
+                                                                                max_len=max_len)
     # make client labels discrete with 4 different values
-    hits_tr = np.apply_along_axis(lambda x: discretize_label(x, 3), 0, np.atleast_2d(hits_tr))
+    hits = np.apply_along_axis(lambda x: discretize_label(x, 3), 0, np.atleast_2d(hits))
     # select only 1.0 and 0.0 from training data
-    pos_idx = np.where(np.logical_or(hits_tr == 1.0, trusted_hits_tr >= 1.0))[0]
-    neg_idx = np.where(hits_tr == 0.0)[0]
+    pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]
+    neg_idx = np.where(hits == 0.0)[0]
     idx = np.concatenate((pos_idx, neg_idx))
     # choose selected sample to train on
-    domain_tr = domain_tr[idx]
-    flow_tr = flow_tr[idx]
+    domain = domain[idx]
+    flow = flow[idx]
     client_tr = np.zeros_like(idx, float)
     client_tr[:pos_idx.shape[-1]] = 1.0
-    server_tr = server_tr[idx]
+    server = server[idx]
+    names = names[idx]

     # client_tr = np_utils.to_categorical(client_tr, 2)

-    return domain_tr, flow_tr, client_tr, server_tr
+    return domain, flow, names, client_tr, server


-def store_h5dataset(path, domain_tr, flow_tr, client_tr, server_tr):
+def store_h5dataset(path, domain, flow, name, client, server):
     f = h5py.File(path, "w")
-    domain_tr = domain_tr.astype(np.int8)
-    f.create_dataset("domain", data=domain_tr)
-    f.create_dataset("flow", data=flow_tr)
-    server_tr = server_tr.astype(np.bool)
-    client_tr = client_tr.astype(np.bool)
-    f.create_dataset("client", data=client_tr)
-    f.create_dataset("server", data=server_tr)
+    domain = domain.astype(np.int8)
+    f.create_dataset("domain", data=domain)
+    f.create_dataset("flow", data=flow)
+    f.create_dataset("name", data=name)
+    server = server.astype(np.bool)
+    client = client.astype(np.bool)
+    f.create_dataset("client", data=client)
+    f.create_dataset("server", data=server)
     f.close()


 def load_h5dataset(path):
     data = h5py.File(path, "r")
-    return data["domain"], data["flow"], data["client"], data["server"]
+    return data["domain"], data["flow"], data["name"], data["client"], data["server"]


 def create_dataset_from_lists(chunks, vocab, max_len):
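The new "name" field rides along in the same HDF5 file as the numeric arrays. One caveat worth knowing: h5py cannot store arrays of arbitrary Python str objects directly, so the user hashes need to arrive as fixed-length byte strings (or be encoded first). A minimal round-trip sketch with toy shapes and values, not the project's real data:

    import h5py
    import numpy as np

    # illustrative stand-ins for the arrays returned by create_dataset_from_flows
    domain = np.zeros((5, 10, 40), dtype=np.int8)
    flow = np.zeros((5, 10, 3), dtype=np.float32)
    name = np.array([b"u1", b"u1", b"u2", b"u3", b"u3"])  # fixed-length bytes, h5py-safe
    client = np.array([1, 0, 0, 1, 0], dtype=bool)
    server = np.zeros((5, 10), dtype=bool)

    with h5py.File("/tmp/example.h5", "w") as f:
        for key, arr in [("domain", domain), ("flow", flow), ("name", name),
                         ("client", client), ("server", server)]:
            f.create_dataset(key, data=arr)

    with h5py.File("/tmp/example.h5", "r") as f:
        assert f["name"].shape == (5,)
        assert f["name"][0] == b"u1"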
@@ -166,7 +166,9 @@ def create_dataset_from_lists(chunks, vocab, max_len):
     logger.info(" select hits")
     hits = np.max(np.stack(map(lambda f: f.virusTotalHits, chunks)), axis=1)
     logger.info(" select names")
-    names = np.unique(np.stack(map(lambda f: f.user_hash, chunks)))
+    names = np.stack(map(lambda f: f.user_hash, chunks))
+    assert (names[:, :1].repeat(10, axis=1) == names).all()
+    names = names[:, 0]
     logger.info(" select servers")
     servers = np.stack(map(lambda f: f.serverLabel, chunks))
     logger.info(" select trusted hits")
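The assert added above guards the collapse from one hash per flow to one hash per window: every flow in a window must belong to the same user (the hard-coded 10 is the window size). A standalone illustration of the same check on toy values:

    import numpy as np

    window_size = 10
    # two windows, each filled with hashes from a single user's flows
    names = np.array([["u1"] * window_size, ["u2"] * window_size])

    # every column must equal the first column, i.e. one user per window
    assert (names[:, :1].repeat(window_size, axis=1) == names).all()

    names = names[:, 0]  # collapse each window to a single user name
    print(names)  # ['u1' 'u2']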
@@ -222,15 +224,29 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
         logger.info("h5 data not found - load csv file")
         user_flow_df = get_user_flow_data(train_data)
         logger.info("create training dataset")
-        domain_tr, flow_tr, client_tr, server_tr = create_dataset_from_flows(user_flow_df, char_dict,
+        domain, flow, names, client, server = create_dataset_from_flows(user_flow_df, char_dict,
                                                                         max_len=domain_length,
                                                                         window_size=window_size)
         logger.info("store training dataset as h5 file")
-        store_h5dataset(h5data, domain_tr, flow_tr, client_tr, server_tr)
+        store_h5dataset(h5data, domain, flow, names, client, server)
     logger.info("load h5 dataset")
     return load_h5dataset(h5data)


+def generate_names(train_data, window_size):
+    user_flow_df = get_user_flow_data(train_data)
+    with Pool() as pool:
+        results = []
+        for user_flow in tqdm(get_flow_per_user(user_flow_df),
+                              total=len(user_flow_df['user_hash'].unique().tolist())):
+            results.append(pool.apply_async(get_user_chunks, (user_flow, window_size)))
+        windows = [window for res in results for window in res.get()]
+    names = np.stack(map(lambda f: f.user_hash, windows))
+    names = names[:, 0]
+
+    return names
+
+
 def load_or_generate_domains(train_data, domain_length):
     fn = f"{train_data}_domains.gz"
     char_dict = get_character_dict()
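load_or_generate_h5data, partially visible above, is a cache-or-build helper: it reads the h5 file when present, and otherwise builds the dataset from CSV, stores it, and reads it back. A generic sketch of that pattern; the function names below are placeholders, not the module's API:

    import os

    def load_or_generate(cache_path, build_fn, store_fn, load_fn):
        # build and persist the dataset only when no cached copy exists
        if not os.path.exists(cache_path):
            store_fn(cache_path, build_fn())
        # always read back from the cache so callers get h5py-backed arrays
        return load_fn(cache_path)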
main.py (52 lines changed)
@@ -106,8 +106,8 @@ def main_hyperband():
     }

     logger.info("create training dataset")
-    domain_tr, flow_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
-                                                                       args.domain_length, args.window)
+    domain_tr, flow_tr, name_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
+                                                                                args.domain_length, args.window)
     hp = hyperband.Hyperband(params,
                              [domain_tr, flow_tr],
                              [client_tr, server_tr])
@@ -120,8 +120,10 @@ def main_train(param=None):
     exists_or_make_path(args.model_path)
     logger.info(f"Use command line arguments: {args}")

-    domain_tr, flow_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
-                                                                               args.domain_length, args.window)
+    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data,
+                                                                                        args.train_data,
+                                                                                        args.domain_length,
+                                                                                        args.window)

     if not param:
         param = PARAMS
@@ -169,6 +171,8 @@ def main_train(param=None):
         labels = [client_tr]
     elif args.model_output == "server":
         labels = [server_tr]
+    else:
+        raise ValueError("unknown model output")

     model.fit([domain_tr, flow_tr],
               labels,
@@ -184,8 +188,10 @@ def main_train(param=None):

 def main_test():
     logger.info("start test: load data")
-    domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
-                                                                           args.domain_length, args.window)
+    domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
+                                                                                     args.test_data,
+                                                                                     args.domain_length,
+                                                                                     args.window)
     domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)

     for model_args in get_model_args(args):
@@ -212,8 +218,10 @@ def main_test():


 def main_visualization():
-    domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
-                                                                           args.domain_length, args.window)
+    domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
+                                                                                     args.test_data,
+                                                                                     args.domain_length,
+                                                                                     args.window)
     # client_val, server_val = client_val.value, server_val.value
     client_val = client_val.value

@@ -246,6 +254,22 @@ def main_visualization():
     visualize.plot_roc_curve(client_val, client_pred)
     visualize.plot_save("{}/client_roc.png".format(args.model_path))
     # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
+
+    print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
+
+    df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
+    user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
+    df_pred = pd.DataFrame(data={"names": name_val, "client_pred": client_pred.flatten()})
+    user_preds = df_pred.groupby(df_pred.names).max().client_pred.as_matrix().astype(float)
+
+    visualize.plot_clf()
+    visualize.plot_precision_recall(user_vals, user_preds)
+    visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
+
+    visualize.plot_clf()
+    visualize.plot_roc_curve(user_vals, user_preds)
+    visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
+
     visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
                                     "{}/client_cov.png".format(args.model_path),
                                     normalize=False, title="Client Confusion Matrix")
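This added block is the heart of the commit: window-level ground truth and predictions are grouped by user name and reduced with max, so a user counts as positive if any of their windows is flagged. as_matrix() was the pandas API of the time; on current pandas the same reduction looks like this (toy arrays, to_numpy() in place of as_matrix()):

    import numpy as np
    import pandas as pd

    name_val = np.array(["u1", "u1", "u2", "u2"])
    client_val = np.array([0.0, 1.0, 0.0, 0.0])   # window-level ground truth
    client_pred = np.array([0.2, 0.9, 0.1, 0.4])  # window-level model scores

    # a user is positive if any of their windows is positive
    user_vals = (pd.DataFrame({"names": name_val, "y": client_val})
                 .groupby("names").y.max().to_numpy().astype(float))
    user_preds = (pd.DataFrame({"names": name_val, "y": client_pred})
                  .groupby("names").y.max().to_numpy().astype(float))

    print(user_vals)   # [1. 0.]
    print(user_preds)  # [0.9 0.4]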
@@ -259,8 +283,10 @@ def main_visualization():


 def main_visualize_all():
-    domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
-                                                                           args.domain_length, args.window)
+    domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
+                                                                                     args.test_data,
+                                                                                     args.domain_length,
+                                                                                     args.window)
     logger.info("plot pr curves")
     visualize.plot_clf()
     for model_args in get_model_args(args):
@@ -282,9 +308,9 @@ def main_data():
     char_dict = dataset.get_character_dict()
     user_flow_df = dataset.get_user_flow_data(args.train_data)
     logger.info("create training dataset")
-    domain_tr, flow_tr, client_tr, server_tr, _ = dataset.create_dataset_from_flows(user_flow_df, char_dict,
-                                                                                    max_len=args.domain_length,
-                                                                                    window_size=args.window)
+    domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
+                                                                                          max_len=args.domain_length,
+                                                                                          window_size=args.window)
     print(f"domain shape {domain_tr.shape}")
     print(f"flow shape {flow_tr.shape}")
     print(f"client shape {client_tr.shape}")
visualize.py (25 lines changed)
@@ -9,7 +9,7 @@ from sklearn.metrics import (
 )


-def scores(y_true, y_pred):
+def scores(y_true):
     for (path, dirnames, fnames) in os.walk("results/"):
         for f in fnames:
             if path[-1] == "1" and f.endswith("npy"):
@@ -48,7 +48,7 @@ def plot_precision_recall(y, y_pred, label=""):
     y = y.flatten()
     y_pred = y_pred.flatten()
     precision, recall, thresholds = precision_recall_curve(y, y_pred)
-    decreasing_max_precision = np.maximum.accumulate(precision)[::-1]
+    # decreasing_max_precision = np.maximum.accumulate(precision)[::-1]

     # fig, ax = plt.subplots(1, 1)
     # ax.hold(True)
@@ -58,15 +58,15 @@ def plot_precision_recall(y, y_pred, label=""):
     plt.ylabel('Precision')


-def plot_precision_recall_curves(y, y_pred):
-    y = y.flatten()
-    y_pred = y_pred.flatten()
-    precision, recall, thresholds = precision_recall_curve(y, y_pred)
-
-    plt.plot(recall, label="Recall")
-    plt.plot(precision, label="Precision")
-    plt.xlabel('Threshold')
-    plt.ylabel('Score')
+# def plot_precision_recall_curves(y, y_pred):
+#     y = y.flatten()
+#     y_pred = y_pred.flatten()
+#     precision, recall, thresholds = precision_recall_curve(y, y_pred)
+#
+#     plt.plot(recall, label="Recall")
+#     plt.plot(precision, label="Precision")
+#     plt.xlabel('Threshold')
+#     plt.ylabel('Score')


 def score_model(y, prediction):
@@ -87,8 +87,7 @@ def plot_roc_curve(mask, prediction, label=""):
     y_pred = prediction.flatten()
     fpr, tpr, thresholds = roc_curve(y, y_pred)
     roc_auc = auc(fpr, tpr)
-    plt.plot(fpr, tpr, label=label)
-    print("roc_auc", roc_auc)
+    plt.plot(fpr, tpr, label=f"{label} - {roc_auc}")


 def plot_confusion_matrix(y_true, y_pred, path,
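The change above folds the AUC into the legend entry instead of printing it. A self-contained sketch of the same idea with plain scikit-learn and matplotlib, on synthetic scores and with the AUC rounded for readability:

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import auc, roc_curve

    y = np.array([0, 0, 1, 1])
    y_pred = np.array([0.1, 0.4, 0.35, 0.8])

    fpr, tpr, thresholds = roc_curve(y, y_pred)
    roc_auc = auc(fpr, tpr)

    # the AUC appears directly in the legend next to the curve's label
    plt.plot(fpr, tpr, label=f"client - {roc_auc:.3f}")
    plt.legend()
    plt.savefig("/tmp/roc_example.png")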