reorder curve storing
This commit is contained in:
parent
d58dbcb101
commit
349bc92a61
38
main.py
38
main.py
@ -610,17 +610,12 @@ def main_beta():
|
|||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
|
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
|
||||||
print(path, model_prefix)
|
curves = {
|
||||||
try:
|
model_prefix: {"all": {}}
|
||||||
curves = joblib.load(f"{path}/curves.joblib")
|
}
|
||||||
logger.info(f"load file {path}/curves.joblib successfully")
|
|
||||||
except Exception:
|
|
||||||
curves = {}
|
|
||||||
logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
|
|
||||||
curves[model_prefix] = {"all": {}}
|
|
||||||
|
|
||||||
domains = domain_val.value.reshape(-1, 40)
|
# domains = domain_val.value.reshape(-1, 40)
|
||||||
domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
|
# domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
|
||||||
|
|
||||||
def load_df(res):
|
def load_df(res):
|
||||||
df_server = None
|
df_server = None
|
||||||
@ -634,12 +629,12 @@ def main_beta():
|
|||||||
data["server_pred"] = server.flatten()
|
data["server_pred"] = server.flatten()
|
||||||
data["server_val"] = val.flatten()
|
data["server_val"] = val.flatten()
|
||||||
|
|
||||||
if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
|
# if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
|
||||||
df_server = pd.DataFrame(data={
|
# df_server = pd.DataFrame(data={
|
||||||
"server_pred": res["server_pred"].flatten(),
|
# "server_pred": res["server_pred"].flatten(),
|
||||||
"domain": domains,
|
# "domain": domains,
|
||||||
"server_val": server_val.value.flatten()
|
# "server_val": server_val.value.flatten()
|
||||||
})
|
# })
|
||||||
|
|
||||||
res = pd.DataFrame(data=data)
|
res = pd.DataFrame(data=data)
|
||||||
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
||||||
@ -716,8 +711,15 @@ def main_beta():
|
|||||||
df_domain_avg.server_val.as_matrix(),
|
df_domain_avg.server_val.as_matrix(),
|
||||||
server_domain_avg_preds)
|
server_domain_avg_preds)
|
||||||
|
|
||||||
joblib.dump(curves, f"{path}/curves.joblib")
|
joblib.dump(curves, f"{args.model_path}_curves.joblib")
|
||||||
|
try:
|
||||||
|
curves_all: dict = joblib.load(f"{path}/curves.joblib")
|
||||||
|
logger.info(f"load file {path}/curves.joblib successfully")
|
||||||
|
curves_all[model_prefix] = curves[model_prefix]
|
||||||
|
except Exception:
|
||||||
|
curves_all = curves
|
||||||
|
logger.info(f"currently {len(curves_all)} models in file: {curves_all.keys()}")
|
||||||
|
joblib.dump(curves_all, f"{path}/curves.joblib")
|
||||||
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user