network predicts 2 by 2 classes, refactored threshold to main
This commit is contained in:
28
main.py
28
main.py
@@ -37,23 +37,24 @@ def main():
|
||||
user_flow_df = dataset.get_user_flow_data()
|
||||
|
||||
print("create training dataset")
|
||||
(X_tr, y_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
|
||||
(X_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
maxLen=maxLen, threshold=threshold, windowSize=windowSize)
|
||||
|
||||
pos_idx = np.where(y_tr == 1.0)[0]
|
||||
neg_idx = np.where(y_tr == 0.0)[0]
|
||||
maxLen=maxLen, windowSize=windowSize)
|
||||
# make client labels discrete with 4 different values
|
||||
# TODO: use trusted_hits_tr for client classification too
|
||||
client_labels = np.apply_along_axis(lambda x: dataset.discretize_label(x, 3), 0, np.atleast_2d(hits_tr))
|
||||
# select only 1.0 and 0.0 from training data
|
||||
pos_idx = np.where(client_labels == 1.0)[0]
|
||||
neg_idx = np.where(client_labels == 0.0)[0]
|
||||
idx = np.concatenate((pos_idx, neg_idx))
|
||||
# select labels for prediction
|
||||
client_labels = client_labels[idx]
|
||||
server_labels = server_tr[idx]
|
||||
|
||||
y_tr = y_tr[idx]
|
||||
hits_tr = hits_tr[idx]
|
||||
names_tr = names_tr[idx]
|
||||
server_tr = server_tr[idx]
|
||||
trusted_hits_tr = trusted_hits_tr[idx]
|
||||
# TODO: remove when features are flattened
|
||||
for i in range(len(X_tr)):
|
||||
X_tr[i] = X_tr[i][idx]
|
||||
|
||||
# TODO: WTF? I don't get it...
|
||||
shared_cnn = models.get_shared_cnn(len(char_dict) + 1, embeddingSize, maxLen,
|
||||
domainFeatures, kernel_size, domainFeatures, 0.5)
|
||||
|
||||
@@ -65,8 +66,9 @@ def main():
|
||||
metrics=['accuracy'])
|
||||
|
||||
epochNumber = 0
|
||||
y_tr = np_utils.to_categorical(y_tr, 2)
|
||||
model.fit(x=X_tr, y=y_tr, batch_size=128,
|
||||
client_labels = np_utils.to_categorical(client_labels, 2)
|
||||
server_labels = np_utils.to_categorical(server_labels, 2)
|
||||
model.fit(X_tr, [client_labels, server_labels], batch_size=128,
|
||||
epochs=epochNumber + 1, shuffle=True, initial_epoch=epochNumber) # ,
|
||||
# validation_data=(testData,testLabel))
|
||||
|
||||
|
Reference in New Issue
Block a user