separating logical sections into dataset, models and main.
continued initial refactoring
This commit is contained in:
68
main.py
Normal file
68
main.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import numpy as np
|
||||
from keras.utils import np_utils
|
||||
|
||||
import dataset
|
||||
import models
|
||||
|
||||
|
||||
def main():
|
||||
# parameter
|
||||
innerCNNFilters = 512
|
||||
innerCNNKernelSize = 2
|
||||
cnnDropout = 0.5
|
||||
cnnHiddenDims = 1024
|
||||
domainFeatures = 512
|
||||
flowFeatures = 3
|
||||
numCiscoFeatures = 30
|
||||
windowSize = 10
|
||||
maxLen = 40
|
||||
embeddingSize = 100
|
||||
kernel_size = 2
|
||||
drop_out = 0.5
|
||||
filters = 2
|
||||
hidden_dims = 100
|
||||
vocabSize = 40
|
||||
threshold = 3
|
||||
minFlowsPerUser = 10
|
||||
numEpochs = 100
|
||||
timesNeg = -1
|
||||
|
||||
char_dict = dataset.get_character_dict()
|
||||
user_flow_df = dataset.get_user_flow_data()
|
||||
|
||||
print("create training dataset")
|
||||
(X_tr, y_tr, hits_tr, names_tr) = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
maxLen=maxLen, threshold=threshold, windowSize=windowSize)
|
||||
|
||||
pos_idx = np.where(y_tr == 1.0)[0]
|
||||
neg_idx = np.where(y_tr == 0.0)[0]
|
||||
|
||||
use_idx = np.concatenate((pos_idx, neg_idx))
|
||||
|
||||
y_tr = y_tr[use_idx]
|
||||
# hits_tr = hits_tr[use_idx]
|
||||
# names_tr = names_tr[use_idx]
|
||||
for i in range(len(X_tr)):
|
||||
X_tr[i] = X_tr[i][use_idx]
|
||||
|
||||
# TODO: WTF? I don't get it...
|
||||
shared_cnn = models.get_shared_cnn(len(char_dict) + 1, embeddingSize, maxLen,
|
||||
domainFeatures, kernel_size, domainFeatures, 0.5)
|
||||
|
||||
model = models.get_top_cnn(shared_cnn, flowFeatures, maxLen, windowSize, domainFeatures, filters, kernel_size,
|
||||
cnnHiddenDims, cnnDropout)
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
loss='binary_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
epochNumber = 0
|
||||
y_tr = np_utils.to_categorical(y_tr, 2)
|
||||
model.fit(x=X_tr, y=y_tr, batch_size=128,
|
||||
epochs=epochNumber + 1, shuffle=True, initial_epoch=epochNumber) # ,
|
||||
# validation_data=(testData,testLabel))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user