Fix test-mode predictions to follow the model output specification (--model_output)

This commit is contained in:
René Knaebel 2017-08-03 07:51:58 +02:00
parent 8ac195ba6f
commit 787f43b328
3 changed files with 38 additions and 19 deletions

14
main.py
View File

@ -183,9 +183,17 @@ def main_test():
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
args.domain_length, args.window) args.domain_length, args.window)
clf = load_model(args.clf_model, custom_objects=models.get_metrics()) clf = load_model(args.clf_model, custom_objects=models.get_metrics())
c_pred, s_pred = clf.predict([domain_val, flow_val], pred = clf.predict([domain_val, flow_val],
batch_size=args.batch_size, batch_size=args.batch_size,
verbose=1) verbose=1)
if args.model_output == "both":
c_pred, s_pred = pred
elif args.model_output == "client":
c_pred = pred
s_pred = np.array()
else:
c_pred = np.array()
s_pred = pred
dataset.save_predictions(args.future_prediction, c_pred, s_pred) dataset.save_predictions(args.future_prediction, c_pred, s_pred)

26
run.sh
View File

@ -1,49 +1,51 @@
python3 main.py --mode train \ #!/usr/bin/env bash
python main.py --mode train \
--train /tmp/rk/currentData.csv \ --train /tmp/rk/currentData.csv \
--model /tmp/rk/results/simple_both \ --model /tmp/rk/results/simple_both \
--epochs 25 \ --epochs 25 \
--hidden_char_dims 64 \ --hidden_char_dims 128 \
--domain_embd 32 \ --domain_embd 32 \
--batch 256 \ --batch 256 \
--balanced_weights \ --balanced_weights \
--model_output both --model_output both
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_both --test /tmp/rk/futureData.csv python main.py --mode test --batch 512 --model /tmp/rk/results/simple_both --test /tmp/rk/futureData.csv --model_output both
python3 main.py --mode train \ python main.py --mode train \
--train /tmp/rk/currentData.csv \ --train /tmp/rk/currentData.csv \
--model /tmp/rk/results/simple_client \ --model /tmp/rk/results/simple_client \
--epochs 25 \ --epochs 25 \
--hidden_char_dims 64 \ --hidden_char_dims 128 \
--domain_embd 32 \ --domain_embd 32 \
--batch 256 \ --batch 256 \
--balanced_weights \ --balanced_weights \
--model_output client --model_output client
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_client --test /tmp/rk/futureData.csv python main.py --mode test --batch 512 --model /tmp/rk/results/simple_client --test /tmp/rk/futureData.csv --model_output client
python3 main.py --mode train \ python main.py --mode train \
--train /tmp/rk/currentData.csv \ --train /tmp/rk/currentData.csv \
--model /tmp/rk/results/simple_new_both \ --model /tmp/rk/results/simple_new_both \
--epochs 25 \ --epochs 25 \
--hidden_char_dims 64 \ --hidden_char_dims 128 \
--domain_embd 32 \ --domain_embd 32 \
--batch 256 \ --batch 256 \
--balanced_weights \ --balanced_weights \
--model_output both \ --model_output both \
--new_model --new_model
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_both --test /tmp/rk/futureData.csv python main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_both --test /tmp/rk/futureData.csv --model_output both
python3 main.py --mode train \ python main.py --mode train \
--train /tmp/rk/currentData.csv \ --train /tmp/rk/currentData.csv \
--model /tmp/rk/results/simple_new_client \ --model /tmp/rk/results/simple_new_client \
--epochs 25 \ --epochs 25 \
--hidden_char_dims 64 \ --hidden_char_dims 128 \
--domain_embd 32 \ --domain_embd 32 \
--batch 256 \ --batch 256 \
--balanced_weights \ --balanced_weights \
--model_output client \ --model_output client \
--new_model --new_model
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_client --test /tmp/rk/futureData.csv python main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_client --test /tmp/rk/futureData.csv --model_output client

View File

@ -1,17 +1,26 @@
#!/usr/bin/python2 #!/usr/bin/python2
import sys
import joblib import joblib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/currentData.joblib") fn = sys.argv[1]
df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/{}.joblib".format(fn))
df = pd.concat(df["data"]) df = pd.concat(df["data"])
df.reset_index(inplace=True) df.reset_index(inplace=True)
df.dropna(axis=0, how="any", inplace=True) df.dropna(axis=0, how="any", inplace=True)
df[["duration", "bytes_down", "bytes_up"]] = df[["duration", "bytes_down", "bytes_up"]].astype(np.int)
df[["domain", "server_ip"]] = df[["domain", "server_ip"]].astype(str) df.serverLabel = pd.to_numeric(df.serverLabel, errors='coerce')
df.duration = pd.to_numeric(df.duration, errors='coerce')
df.bytes_down = pd.to_numeric(df.bytes_down, errors='coerce')
df.bytes_up = pd.to_numeric(df.bytes_up, errors='coerce')
df.http_method = df.http_method.astype("category")
df.serverLabel = df.serverLabel.astype(np.bool) df.serverLabel = df.serverLabel.astype(np.bool)
df.virusTotalHits = df.virusTotalHits.astype(np.int8) df.virusTotalHits = df.virusTotalHits.astype(np.int8)
df.trustedHits = df.trustedHits.astype(np.int8) df.trustedHits = df.trustedHits.astype(np.int8)
df.to_csv("/tmp/rk/full_future_dataset.csv.gz", compression="gzip") df.to_csv("/tmp/rk/{}.csv".format(fn), encoding="utf-8")