From a70d1cb03a6bcef6eb4af3be4287016e15bd3ad1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Wed, 5 Jul 2017 18:37:29 +0200 Subject: [PATCH] fix: replace X_tr by its elements; choose selected samples for training data too --- dataset.py | 4 ++-- main.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dataset.py b/dataset.py index e8abb7e..2f6408a 100644 --- a/dataset.py +++ b/dataset.py @@ -169,8 +169,8 @@ def create_dataset_from_lists(domains, features, vocab, max_len, names.append(np.unique(features[i]['user_hash'])) servers.append(np.max(features[i]['serverLabel'])) trusted_hits.append(np.max(features[i]['trustedHits'])) - X = [domain_features, flow_features] - return X, np.array(hits), np.array(names), np.array(servers), np.array(trusted_hits) + return (domain_features, flow_features, + np.array(hits), np.array(names), np.array(servers), np.array(trusted_hits)) def discretize_label(values, threshold): diff --git a/main.py b/main.py index c5e5862..1fc2aad 100644 --- a/main.py +++ b/main.py @@ -92,7 +92,7 @@ def main(): user_flow_df = dataset.get_user_flow_data() print("create training dataset") - (X_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows( + domain_tr, flow_tr, hits_tr, names_tr, server_tr, trusted_hits_tr = dataset.create_dataset_from_flows( user_flow_df, char_dict, max_len=args.domain_length, window_size=args.window) # make client labels discrete with 4 different values @@ -102,7 +102,9 @@ def main(): pos_idx = np.where(client_labels == 1.0)[0] neg_idx = np.where(client_labels == 0.0)[0] idx = np.concatenate((pos_idx, neg_idx)) - # select labels for prediction + # choose selected sample to train on + domain_tr = domain_tr[idx] + flow_tr = flow_tr[idx] client_labels = client_labels[idx] server_labels = server_tr[idx] @@ -121,7 +123,7 @@ def main(): client_labels = np_utils.to_categorical(client_labels, 2) server_labels = np_utils.to_categorical(server_labels, 2) - model.fit(X_tr, + model.fit([domain_tr, flow_tr], [client_labels, server_labels], batch_size=args.batch_size, epochs=args.epochs,