add extended test mode for embeddings
This commit is contained in:
parent
79fc441fe1
commit
18b60e1754
@ -216,7 +216,6 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
logger.info(f"check for h5data {h5data}")
|
logger.info(f"check for h5data {h5data}")
|
||||||
try:
|
try:
|
||||||
open(h5data, "r")
|
open(h5data, "r")
|
||||||
raise FileNotFoundError
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.info("h5 data not found - load csv file")
|
logger.info("h5 data not found - load csv file")
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
|
34
main.py
34
main.py
@ -198,6 +198,25 @@ def main_test():
|
|||||||
verbose=1)
|
verbose=1)
|
||||||
np.save(args.future_prediction, y_pred)
|
np.save(args.future_prediction, y_pred)
|
||||||
|
|
||||||
|
char_dict = dataset.get_character_dict()
|
||||||
|
user_flow_df = dataset.get_user_flow_data(args.test_data)
|
||||||
|
domains = user_flow_df.domain.unique()
|
||||||
|
|
||||||
|
def get_domain_features_reduced(d):
|
||||||
|
return dataset.get_domain_features(d[0], char_dict, args.domain_length)
|
||||||
|
|
||||||
|
domain_features = []
|
||||||
|
for ds in domains:
|
||||||
|
domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds)))
|
||||||
|
|
||||||
|
model = load_model(args.embedding_model)
|
||||||
|
domain_features = np.stack(domain_features).reshape((-1, 40))
|
||||||
|
pred = model.predict(domains, batch_size=args.batch_size, verbose=1)
|
||||||
|
|
||||||
|
np.save("/tmp/rk/domains.npy", domains)
|
||||||
|
np.save("/tmp/rk/domain_features.npy", domain_features)
|
||||||
|
np.save("/tmp/rk/domain_embd.npy", pred)
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||||
@ -205,10 +224,13 @@ def main_visualization():
|
|||||||
logger.info("plot model")
|
logger.info("plot model")
|
||||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
||||||
logger.info("plot training curve")
|
try:
|
||||||
logs = pd.read_csv(args.train_log)
|
logger.info("plot training curve")
|
||||||
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
|
logs = pd.read_csv(args.train_log)
|
||||||
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
|
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
|
||||||
|
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"could not generate training curves: {e}")
|
||||||
|
|
||||||
client_pred, server_pred = np.load(args.future_prediction)
|
client_pred, server_pred = np.load(args.future_prediction)
|
||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
@ -230,8 +252,8 @@ def main_visualization():
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
model = load_model(args.embedding_model)
|
model = load_model(args.embedding_model)
|
||||||
domains = np.reshape(domain_val, (12800, 40))
|
domains = np.reshape(domain_val, (domain_val.shape[0] * domain_val.shape[1], 40))
|
||||||
domain_embedding = model.predict(domains)
|
domain_embedding = model.predict(domains, batch_size=args.batch_size, verbose=1)
|
||||||
|
|
||||||
pca = PCA(n_components=2)
|
pca = PCA(n_components=2)
|
||||||
domain_reduced = pca.fit_transform(domain_embedding)
|
domain_reduced = pca.fit_transform(domain_embedding)
|
||||||
|
Loading…
Reference in New Issue
Block a user