add extended test mode for embeddings

This commit is contained in:
René Knaebel 2017-07-17 19:30:56 +02:00
parent 79fc441fe1
commit 18b60e1754
2 changed files with 28 additions and 7 deletions

View File

@ -216,7 +216,6 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
logger.info(f"check for h5data {h5data}")
try:
open(h5data, "r")
raise FileNotFoundError
except FileNotFoundError:
logger.info("h5 data not found - load csv file")
user_flow_df = get_user_flow_data(train_data)

34
main.py
View File

@ -198,6 +198,25 @@ def main_test():
verbose=1)
np.save(args.future_prediction, y_pred)
char_dict = dataset.get_character_dict()
user_flow_df = dataset.get_user_flow_data(args.test_data)
domains = user_flow_df.domain.unique()
def get_domain_features_reduced(d):
return dataset.get_domain_features(d[0], char_dict, args.domain_length)
domain_features = []
for ds in domains:
domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds)))
model = load_model(args.embedding_model)
domain_features = np.stack(domain_features).reshape((-1, 40))
pred = model.predict(domains, batch_size=args.batch_size, verbose=1)
np.save("/tmp/rk/domains.npy", domains)
np.save("/tmp/rk/domain_features.npy", domain_features)
np.save("/tmp/rk/domain_embd.npy", pred)
def main_visualization():
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
@ -205,10 +224,13 @@ def main_visualization():
logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics())
visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
logger.info("plot training curve")
logs = pd.read_csv(args.train_log)
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
try:
logger.info("plot training curve")
logs = pd.read_csv(args.train_log)
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
except Exception as e:
logger.warning(f"could not generate training curves: {e}")
client_pred, server_pred = np.load(args.future_prediction)
logger.info("plot pr curve")
@ -230,8 +252,8 @@ def main_visualization():
import matplotlib.pyplot as plt
model = load_model(args.embedding_model)
domains = np.reshape(domain_val, (12800, 40))
domain_embedding = model.predict(domains)
domains = np.reshape(domain_val, (domain_val.shape[0] * domain_val.shape[1], 40))
domain_embedding = model.predict(domains, batch_size=args.batch_size, verbose=1)
pca = PCA(n_components=2)
domain_reduced = pca.fit_transform(domain_embedding)