From 68254d6629f639149900b633f3575cd7ebf1462e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Mon, 2 Oct 2017 07:34:04 +0200 Subject: [PATCH] add load function for hyper band results --- models/renes_networks.py | 2 +- utils.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/models/renes_networks.py b/models/renes_networks.py index 1f5085e..94f3caf 100644 --- a/models/renes_networks.py +++ b/models/renes_networks.py @@ -39,7 +39,7 @@ def get_embedding(embedding_size, input_length, filter_size, kernel_size, hidden def get_model(cnnDropout, flow_features, domain_features, window_size, domain_length, cnn_dims, kernel_size, dense_dim, cnn, model_output="both"): ipt_domains = Input(shape=(window_size, domain_length), name="ipt_domains") - encoded = TimeDistributed(cnn)(ipt_domains) + encoded = TimeDistributed(cnn, name="domain_cnn")(ipt_domains) ipt_flows = Input(shape=(window_size, flow_features), name="ipt_flows") merged = keras.layers.concatenate([encoded, ipt_flows], -1) # CNN processing a small slides of flow windows diff --git a/utils.py b/utils.py index 636e3a7..ccab16b 100644 --- a/utils.py +++ b/utils.py @@ -1,5 +1,7 @@ import os +from operator import itemgetter +import joblib import numpy as np from sklearn.utils import class_weight @@ -20,3 +22,8 @@ def get_custom_class_weights(client, server): def get_custom_sample_weights(client, server): return class_weight.compute_sample_weight("balanced", np.vstack((client, server)).T) + + +def load_ordered_hyperband_results(path): + results = joblib.load(path) + return sorted(results, itemgetter("loss"))