
added some argparse arguments to main

René Knaebel, 1 year ago
commit 5f8a760a0c
2 changed files with 63 additions and 8 deletions:
  1. main.py (+61, -6)
  2. models.py (+2, -2)

main.py (+61, -6)

@@ -1,9 +1,63 @@
+import argparse
+
 import numpy as np
 from keras.utils import np_utils
 
 import dataset
 import models
 
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--modes", action="store", dest="modes", nargs="+")
+
+# parser.add_argument("--data", action="store", dest="data",
+#                     default="data/")
+#
+# parser.add_argument("--h5data", action="store", dest="h5data",
+#                     default="")
+#
+# parser.add_argument("--model", action="store", dest="model",
+#                     default="model_x")
+#
+# parser.add_argument("--pred", action="store", dest="pred",
+#                     default="")
+#
+# parser.add_argument("--type", action="store", dest="model_type",
+#                     default="simple_conv")
+#
+parser.add_argument("--batch", action="store", dest="batch_size",
+                    default=64, type=int)
+
+parser.add_argument("--epochs", action="store", dest="epochs",
+                    default=10, type=int)
+
+# parser.add_argument("--samples", action="store", dest="samples",
+#                     default=100000, type=int)
+#
+# parser.add_argument("--samples_val", action="store", dest="samples_val",
+#                     default=10000, type=int)
+#
+# parser.add_argument("--area", action="store", dest="area_size",
+#                     default=25, type=int)
+#
+# parser.add_argument("--queue", action="store", dest="queue_size",
+#                     default=50, type=int)
+#
+# parser.add_argument("--p", action="store", dest="p_train",
+#                     default=0.5, type=float)
+#
+# parser.add_argument("--p_val", action="store", dest="p_val",
+#                     default=0.01, type=float)
+#
+# parser.add_argument("--gpu", action="store", dest="gpu",
+#                     default=0, type=int)
+#
+# parser.add_argument("--tmp", action="store_true", dest="tmp")
+#
+# parser.add_argument("--test", action="store", dest="test_image",
+#                     default=6, choices=range(7), type=int)
+
+args = parser.parse_args()
 
 # config = tf.ConfigProto(log_device_placement=True)
 # config.gpu_options.per_process_gpu_memory_fraction = 0.5
@@ -31,7 +85,6 @@ def main():
     threshold = 3
     minFlowsPerUser = 10
     numEpochs = 100
-    timesNeg = -1
 
     char_dict = dataset.get_character_dict()
     user_flow_df = dataset.get_user_flow_data()
@@ -39,7 +92,7 @@ def main():
     print("create training dataset")
     (X_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
         user_flow_df, char_dict,
-        maxLen=maxLen, windowSize=windowSize)
+        max_len=maxLen, window_size=windowSize)
     # make client labels discrete with 4 different values
     # TODO: use trusted_hits_tr for client classification too
     client_labels = np.apply_along_axis(lambda x: dataset.discretize_label(x, 3), 0, np.atleast_2d(hits_tr))
@@ -65,12 +118,14 @@ def main():
                   loss='binary_crossentropy',
                   metrics=['accuracy'])
 
-    epochNumber = 0
     client_labels = np_utils.to_categorical(client_labels, 2)
     server_labels = np_utils.to_categorical(server_labels, 2)
-    model.fit(X_tr, [client_labels, server_labels], batch_size=128,
-              epochs=epochNumber + 1, shuffle=True, initial_epoch=epochNumber)  # ,
-    # validation_data=(testData,testLabel))
+    model.fit(X_tr,
+              [client_labels, server_labels],
+              batch_size=args.batch_size,
+              epochs=args.epochs,
+              shuffle=True)
+    # TODO: for validation we use future data -> validation_data=(testData,testLabel))
 
 
 if __name__ == "__main__":
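
Note: with the hard-coded batch_size=128 and the epochNumber bookkeeping removed, both training parameters now come from the command line. A minimal sketch of how the three uncommented flags behave, using only the parser lines from this diff; the argument list handed to parse_args() is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--modes", action="store", dest="modes", nargs="+")
parser.add_argument("--batch", action="store", dest="batch_size",
                    default=64, type=int)
parser.add_argument("--epochs", action="store", dest="epochs",
                    default=10, type=int)

# Simulate a command line such as: python main.py --modes train --batch 128
args = parser.parse_args(["--modes", "train", "--batch", "128"])
assert args.modes == ["train"]   # nargs="+" collects one or more values
assert args.batch_size == 128    # type=int converts the string "128"
assert args.epochs == 10         # unset flags keep their defaults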

models.py (+2, -2)

@@ -45,8 +45,8 @@ def get_top_cnn(cnn, numFeatures, maxLen, windowSize, domainFeatures, filters, k
     maxPool = GlobalMaxPooling1D()(cnn)
     cnnDropout = Dropout(cnnDropout)(maxPool)
     cnnDense = Dense(cnnHiddenDims, activation='relu')(cnnDropout)
-    cnnOutput1 = Dense(2, activation='softmax')(cnnDense)
-    cnnOutput2 = Dense(2, activation='softmax')(cnnDense)
+    cnnOutput1 = Dense(2, activation='softmax', name="client")(cnnDense)
+    cnnOutput2 = Dense(2, activation='softmax', name="server")(cnnDense)
 
     # We define a trainable model linking the
     # tweet inputs to the predictions
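
Note: naming the two softmax heads is more than cosmetic. Keras labels each output's loss and metrics with the layer name (client_loss and server_loss in the training logs), and compile() and fit() can then take dicts keyed by output name instead of position-sensitive lists. A minimal sketch, assuming the Keras 2 functional API; the input shape and hidden size are placeholders, not values from models.py:

from keras.layers import Dense, GlobalMaxPooling1D, Input
from keras.models import Model

inp = Input(shape=(100, 16))                  # hypothetical (timesteps, features)
pooled = GlobalMaxPooling1D()(inp)
hidden = Dense(64, activation='relu')(pooled)
client = Dense(2, activation='softmax', name="client")(hidden)
server = Dense(2, activation='softmax', name="server")(hidden)
model = Model(inputs=inp, outputs=[client, server])

# Losses (and, later, label arrays in fit) can be keyed by the output names.
model.compile(optimizer='adam',
              loss={"client": "binary_crossentropy",
                    "server": "binary_crossentropy"},
              metrics=['accuracy'])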
