
network predicts 2 by 2 classes, refactored threshold to main

master
René Knaebel, 1 year ago
commit c972963a19

4 changed files with 24 additions and 22 deletions
  1. dataset.py (+3, -3)
  2. main.py (+15, -13)
  3. models.py (+3, -3)
  4. scripts/make_csv_dataset.py (+3, -3)

dataset.py (+3, -3)

@@ -86,7 +86,7 @@ def get_cisco_features(curDataLine, urlSIPDict):
         return np.zeros([numCiscoFeatures, ]).ravel()
 
 
-def create_dataset_from_flows(user_flow_df, char_dict, maxLen, threshold=3, windowSize=10, use_cisco_features=False):
+def create_dataset_from_flows(user_flow_df, char_dict, maxLen, windowSize=10, use_cisco_features=False):
     domainLists = []
     dfLists = []
     print("get chunks from user data frames")
@@ -102,12 +102,12 @@ def create_dataset_from_flows(user_flow_df, char_dict, maxLen, threshold=3, wind
     print("create training dataset")
     return create_dataset_from_lists(
         domains=domainLists, dfs=dfLists, vocab=char_dict,
-        maxLen=maxLen, threshold=threshold,
+        maxLen=maxLen,
         use_cisco_features=use_cisco_features, urlSIPDIct=dict(),
         window_size=windowSize)
 
 
-def create_dataset_from_lists(domains, dfs, vocab, maxLen, threshold=3,
+def create_dataset_from_lists(domains, dfs, vocab, maxLen,
                               use_cisco_features=False, urlSIPDIct=dict(),
                               window_size=10):
     # TODO: check for hits vs vth consistency
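With threshold removed from both signatures, turning raw hit counts into labels is now the caller's job; main.py below does this via dataset.discretize_label(x, 3). As a purely illustrative, self-contained sketch of such a discretization (the four return values follow main.py's "4 different values" comment; the repository's actual discretize_label may differ):

import numpy as np

def discretize_label_sketch(hits, threshold=3):
    # Hypothetical stand-in for dataset.discretize_label, for illustration only.
    total = np.sum(np.maximum(hits, 0))   # ignore negative (unknown) counts
    if total >= threshold:
        return 1.0    # enough hits in the window: positive client
    if np.any(hits == -1):
        return -1.0   # window contains unlabeled traffic
    if total > 0:
        return 0.5    # some hits, but below the threshold
    return 0.0        # no hits: negative client

print(discretize_label_sketch(np.array([0, 2, 1, 0])))  # -> 1.0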

main.py (+15, -13)

@@ -37,23 +37,24 @@ def main():
     user_flow_df = dataset.get_user_flow_data()
 
     print("create training dataset")
-    (X_tr, y_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
+    (X_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
         user_flow_df, char_dict,
-        maxLen=maxLen, threshold=threshold, windowSize=windowSize)
-
-    pos_idx = np.where(y_tr == 1.0)[0]
-    neg_idx = np.where(y_tr == 0.0)[0]
+        maxLen=maxLen, windowSize=windowSize)
+    # make client labels discrete with 4 different values
+    # TODO: use trusted_hits_tr for client classification too
+    client_labels = np.apply_along_axis(lambda x: dataset.discretize_label(x, 3), 0, np.atleast_2d(hits_tr))
+    # select only 1.0 and 0.0 from training data
+    pos_idx = np.where(client_labels == 1.0)[0]
+    neg_idx = np.where(client_labels == 0.0)[0]
     idx = np.concatenate((pos_idx, neg_idx))
+    # select labels for prediction
+    client_labels = client_labels[idx]
+    server_labels = server_tr[idx]
 
-    y_tr = y_tr[idx]
-    hits_tr = hits_tr[idx]
-    names_tr = names_tr[idx]
-    server_tr = server_tr[idx]
-    trusted_hits_tr = trusted_hits_tr[idx]
+    # TODO: remove when features are flattened
     for i in range(len(X_tr)):
         X_tr[i] = X_tr[i][idx]
 
-    # TODO: WTF? I don't get it...
     shared_cnn = models.get_shared_cnn(len(char_dict) + 1, embeddingSize, maxLen,
                                        domainFeatures, kernel_size, domainFeatures, 0.5)
 
@@ -65,8 +66,9 @@ def main():
                   metrics=['accuracy'])
 
     epochNumber = 0
-    y_tr = np_utils.to_categorical(y_tr, 2)
-    model.fit(x=X_tr, y=y_tr, batch_size=128,
+    client_labels = np_utils.to_categorical(client_labels, 2)
+    server_labels = np_utils.to_categorical(server_labels, 2)
+    model.fit(X_tr, [client_labels, server_labels], batch_size=128,
               epochs=epochNumber + 1, shuffle=True, initial_epoch=epochNumber)  # ,
     # validation_data=(testData,testLabel))
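The filter-then-encode steps above, as a self-contained sketch with toy arrays (np_utils.to_categorical is the Keras 2 one-hot helper; the real labels come from the dataset module):

import numpy as np
from keras.utils import np_utils  # Keras 2.x

client_labels = np.array([1.0, 0.5, 0.0, -1.0, 1.0])  # 4-valued client labels
server_labels = np.array([0, 1, 0, 1, 1])             # binary server labels

# keep only rows whose client label is strictly 0.0 or 1.0
idx = np.concatenate((np.where(client_labels == 1.0)[0],
                      np.where(client_labels == 0.0)[0]))

client_y = np_utils.to_categorical(client_labels[idx], 2)  # shape (3, 2)
server_y = np_utils.to_categorical(server_labels[idx], 2)  # shape (3, 2)
# a two-output model is then trained with model.fit(X, [client_y, server_y])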

models.py (+3, -3)

@@ -45,9 +45,9 @@ def get_top_cnn(cnn, numFeatures, maxLen, windowSize, domainFeatures, filters, k
     maxPool = GlobalMaxPooling1D()(cnn)
     cnnDropout = Dropout(cnnDropout)(maxPool)
     cnnDense = Dense(cnnHiddenDims, activation='relu')(cnnDropout)
-    cnnOutput = Dense(2, activation='softmax')(cnnDense)
+    cnnOutput1 = Dense(2, activation='softmax')(cnnDense)
+    cnnOutput2 = Dense(2, activation='softmax')(cnnDense)
 
     # We define a trainable model linking the
     # tweet inputs to the predictions
-    model = Model(inputs=inputList, outputs=cnnOutput)
-    return model
+    return Model(inputs=inputList, outputs=(cnnOutput1, cnnOutput2))
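This change turns get_top_cnn into a two-headed model. A minimal standalone sketch of the same pattern with the Keras functional API (the layer sizes here are placeholders, not the repository's):

from keras.layers import Input, Dense
from keras.models import Model

inputs = Input(shape=(16,))                           # placeholder feature input
hidden = Dense(32, activation='relu')(inputs)
client_out = Dense(2, activation='softmax')(hidden)   # first head
server_out = Dense(2, activation='softmax')(hidden)   # second head

model = Model(inputs=inputs, outputs=[client_out, server_out])
# a single loss string is applied to every output; Keras optimizes their sum
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy'])

Returning the outputs as a tuple or a list is equivalent; model.fit then expects one label array per head, as in main.py above.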

scripts/make_csv_dataset.py (+3, -3)

@@ -3,8 +3,8 @@
 import joblib
 import pandas as pd
 
-datafile = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/currentData.joblib")
-user_flows = datafile["data"]
-df = pd.concat(user_flows)
+df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/currentData.joblib")
+df = df["data"]
+df = pd.concat(df)
 df.reset_index(inplace=True)
 df.to_csv("/tmp/rk/full_dataset.csv.gz", compression="gzip")
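For reference, the exported file can be read back directly; pandas infers gzip compression from the .gz suffix:

import pandas as pd

# compression is inferred from the file extension, so no extra argument is needed
df = pd.read_csv("/tmp/rk/full_dataset.csv.gz")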
