added pauls extensions for new predictions
This commit is contained in:
15
main.py
15
main.py
@@ -37,20 +37,21 @@ def main():
|
||||
user_flow_df = dataset.get_user_flow_data()
|
||||
|
||||
print("create training dataset")
|
||||
(X_tr, y_tr, hits_tr, names_tr) = dataset.create_dataset_from_flows(
|
||||
(X_tr, y_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
maxLen=maxLen, threshold=threshold, windowSize=windowSize)
|
||||
|
||||
pos_idx = np.where(y_tr == 1.0)[0]
|
||||
neg_idx = np.where(y_tr == 0.0)[0]
|
||||
idx = np.concatenate((pos_idx, neg_idx))
|
||||
|
||||
use_idx = np.concatenate((pos_idx, neg_idx))
|
||||
|
||||
y_tr = y_tr[use_idx]
|
||||
# hits_tr = hits_tr[use_idx]
|
||||
# names_tr = names_tr[use_idx]
|
||||
y_tr = y_tr[idx]
|
||||
hits_tr = hits_tr[idx]
|
||||
names_tr = names_tr[idx]
|
||||
server_tr = server_tr[idx]
|
||||
trusted_hits_tr = trusted_hits_tr[idx]
|
||||
for i in range(len(X_tr)):
|
||||
X_tr[i] = X_tr[i][use_idx]
|
||||
X_tr[i] = X_tr[i][idx]
|
||||
|
||||
# TODO: WTF? I don't get it...
|
||||
shared_cnn = models.get_shared_cnn(len(char_dict) + 1, embeddingSize, maxLen,
|
||||
|
Reference in New Issue
Block a user