Fix test predictions so they depend on the model output specification
This commit is contained in:
		
							
								
								
									
										14
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								main.py
									
									
									
									
									
								
							@@ -183,9 +183,17 @@ def main_test():
 | 
			
		||||
    domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
 | 
			
		||||
                                                                           args.domain_length, args.window)
 | 
			
		||||
    clf = load_model(args.clf_model, custom_objects=models.get_metrics())
 | 
			
		||||
    c_pred, s_pred = clf.predict([domain_val, flow_val],
 | 
			
		||||
                                 batch_size=args.batch_size,
 | 
			
		||||
                                 verbose=1)
 | 
			
		||||
    pred = clf.predict([domain_val, flow_val],
 | 
			
		||||
                       batch_size=args.batch_size,
 | 
			
		||||
                       verbose=1)
 | 
			
		||||
    if args.model_output == "both":
 | 
			
		||||
        c_pred, s_pred = pred
 | 
			
		||||
    elif args.model_output == "client":
 | 
			
		||||
        c_pred = pred
 | 
			
		||||
        s_pred = np.array()
 | 
			
		||||
    else:
 | 
			
		||||
        c_pred = np.array()
 | 
			
		||||
        s_pred = pred
 | 
			
		||||
    dataset.save_predictions(args.future_prediction, c_pred, s_pred)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										26
									
								
								run.sh
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								run.sh
									
									
									
									
									
								
							@@ -1,49 +1,51 @@
 | 
			
		||||
python3 main.py --mode train \
 | 
			
		||||
#!/usr/bin/env bash
 | 
			
		||||
 | 
			
		||||
python main.py --mode train \
 | 
			
		||||
                --train /tmp/rk/currentData.csv \
 | 
			
		||||
                --model /tmp/rk/results/simple_both \
 | 
			
		||||
                --epochs 25 \
 | 
			
		||||
                --hidden_char_dims 64 \
 | 
			
		||||
                --hidden_char_dims 128 \
 | 
			
		||||
                --domain_embd 32 \
 | 
			
		||||
                --batch 256 \
 | 
			
		||||
                --balanced_weights \
 | 
			
		||||
                --model_output both
 | 
			
		||||
 | 
			
		||||
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_both --test /tmp/rk/futureData.csv
 | 
			
		||||
python main.py --mode test --batch 512 --model /tmp/rk/results/simple_both --test /tmp/rk/futureData.csv --model_output both
 | 
			
		||||
 | 
			
		||||
python3 main.py --mode train \
 | 
			
		||||
python main.py --mode train \
 | 
			
		||||
                --train /tmp/rk/currentData.csv \
 | 
			
		||||
                --model /tmp/rk/results/simple_client \
 | 
			
		||||
                --epochs 25 \
 | 
			
		||||
                --hidden_char_dims 64 \
 | 
			
		||||
                --hidden_char_dims 128 \
 | 
			
		||||
                --domain_embd 32 \
 | 
			
		||||
                --batch 256 \
 | 
			
		||||
                --balanced_weights \
 | 
			
		||||
                --model_output client
 | 
			
		||||
 | 
			
		||||
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_client --test /tmp/rk/futureData.csv
 | 
			
		||||
python main.py --mode test --batch 512 --model /tmp/rk/results/simple_client --test /tmp/rk/futureData.csv --model_output client
 | 
			
		||||
 | 
			
		||||
python3 main.py --mode train \
 | 
			
		||||
python main.py --mode train \
 | 
			
		||||
                --train /tmp/rk/currentData.csv \
 | 
			
		||||
                --model /tmp/rk/results/simple_new_both \
 | 
			
		||||
                --epochs 25 \
 | 
			
		||||
                --hidden_char_dims 64 \
 | 
			
		||||
                --hidden_char_dims 128 \
 | 
			
		||||
                --domain_embd 32 \
 | 
			
		||||
                --batch 256 \
 | 
			
		||||
                --balanced_weights \
 | 
			
		||||
                --model_output both \
 | 
			
		||||
                --new_model
 | 
			
		||||
 | 
			
		||||
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_both --test /tmp/rk/futureData.csv
 | 
			
		||||
python main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_both --test /tmp/rk/futureData.csv --model_output both
 | 
			
		||||
 | 
			
		||||
python3 main.py --mode train \
 | 
			
		||||
python main.py --mode train \
 | 
			
		||||
                --train /tmp/rk/currentData.csv \
 | 
			
		||||
                --model /tmp/rk/results/simple_new_client \
 | 
			
		||||
                --epochs 25 \
 | 
			
		||||
                --hidden_char_dims 64 \
 | 
			
		||||
                --hidden_char_dims 128 \
 | 
			
		||||
                --domain_embd 32 \
 | 
			
		||||
                --batch 256 \
 | 
			
		||||
                --balanced_weights \
 | 
			
		||||
                --model_output client \
 | 
			
		||||
                --new_model
 | 
			
		||||
 | 
			
		||||
python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_client --test /tmp/rk/futureData.csv
 | 
			
		||||
python main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_client --test /tmp/rk/futureData.csv --model_output client
 | 
			
		||||
@@ -1,17 +1,26 @@
 | 
			
		||||
#!/usr/bin/python2
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
import joblib
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pandas as pd
 | 
			
		||||
 | 
			
		||||
df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/currentData.joblib")
 | 
			
		||||
fn = sys.argv[1]
 | 
			
		||||
 | 
			
		||||
df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/{}.joblib".format(fn))
 | 
			
		||||
df = pd.concat(df["data"])
 | 
			
		||||
df.reset_index(inplace=True)
 | 
			
		||||
df.dropna(axis=0, how="any", inplace=True)
 | 
			
		||||
df[["duration", "bytes_down", "bytes_up"]] = df[["duration", "bytes_down", "bytes_up"]].astype(np.int)
 | 
			
		||||
df[["domain", "server_ip"]] = df[["domain", "server_ip"]].astype(str)
 | 
			
		||||
 | 
			
		||||
df.serverLabel = pd.to_numeric(df.serverLabel, errors='coerce')
 | 
			
		||||
df.duration = pd.to_numeric(df.duration, errors='coerce')
 | 
			
		||||
df.bytes_down = pd.to_numeric(df.bytes_down, errors='coerce')
 | 
			
		||||
df.bytes_up = pd.to_numeric(df.bytes_up, errors='coerce')
 | 
			
		||||
 | 
			
		||||
df.http_method = df.http_method.astype("category")
 | 
			
		||||
df.serverLabel = df.serverLabel.astype(np.bool)
 | 
			
		||||
df.virusTotalHits = df.virusTotalHits.astype(np.int8)
 | 
			
		||||
df.trustedHits = df.trustedHits.astype(np.int8)
 | 
			
		||||
 | 
			
		||||
df.to_csv("/tmp/rk/full_future_dataset.csv.gz", compression="gzip")
 | 
			
		||||
df.to_csv("/tmp/rk/{}.csv".format(fn), encoding="utf-8")
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user