diff --git a/main.py b/main.py index 3a9b9ef..bafbd5a 100644 --- a/main.py +++ b/main.py @@ -610,17 +610,12 @@ def main_beta(): args.domain_length, args.window) path, model_prefix = os.path.split(os.path.normpath(args.model_path)) - print(path, model_prefix) - try: - curves = joblib.load(f"{path}/curves.joblib") - logger.info(f"load file {path}/curves.joblib successfully") - except Exception: - curves = {} - logger.info(f"currently {len(curves)} models in file: {curves.keys()}") - curves[model_prefix] = {"all": {}} - - domains = domain_val.value.reshape(-1, 40) - domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains) + curves = { + model_prefix: {"all": {}} + } + + # domains = domain_val.value.reshape(-1, 40) + # domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains) def load_df(res): df_server = None @@ -634,12 +629,12 @@ def main_beta(): data["server_pred"] = server.flatten() data["server_val"] = val.flatten() - if res["server_pred"].flatten().shape == server_val.value.flatten().shape: - df_server = pd.DataFrame(data={ - "server_pred": res["server_pred"].flatten(), - "domain": domains, - "server_val": server_val.value.flatten() - }) + # if res["server_pred"].flatten().shape == server_val.value.flatten().shape: + # df_server = pd.DataFrame(data={ + # "server_pred": res["server_pred"].flatten(), + # "domain": domains, + # "server_val": server_val.value.flatten() + # }) res = pd.DataFrame(data=data) res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3) @@ -716,8 +711,15 @@ def main_beta(): df_domain_avg.server_val.as_matrix(), server_domain_avg_preds) - joblib.dump(curves, f"{path}/curves.joblib") - + joblib.dump(curves, f"{args.model_path}_curves.joblib") + try: + curves_all: dict = joblib.load(f"{path}/curves.joblib") + logger.info(f"load file {path}/curves.joblib successfully") + curves_all[model_prefix] = curves[model_prefix] + except Exception: + curves_all = curves + logger.info(f"currently {len(curves_all)} models in file: {curves_all.keys()}") + joblib.dump(curves_all, f"{path}/curves.joblib") import matplotlib