from . import pauls_networks
from . import renes_networks


def get_models_by_params(params: dict):
    """Build the embedding model and the prediction model described by ``params``.

    Every setting is looked up with ``params.get``, so a missing key is passed
    on as ``None``. When ``params["type"]`` equals ``"rene"`` the builders from
    ``renes_networks`` are used; any other value (including a missing key)
    falls back to ``pauls_networks``.

    Args:
        params: configuration mapping. Keys read for the embedding model:
            ``vocab_size``, ``embedding_size``, ``input_length``,
            ``filter_embedding``, ``kernel_embedding``, ``hidden_embedding``,
            ``dropout``. Keys read for the prediction model: ``flow_features``,
            ``domain_features``, ``window_size``, ``domain_length``,
            ``filter_main``, ``kernels_main``, ``dense_main`` (plus ``dropout``).

    Returns:
        A ``(embedding_model, predict_model)`` tuple; the embedding model is
        also handed to ``get_model`` as its last positional argument.
    """
    fetch = params.get

    # Gather all settings up front, in the same order the builders consume them.
    network_type = fetch("type")
    embedding_args = [
        fetch(key)
        for key in (
            "vocab_size",
            "embedding_size",
            "input_length",
            "filter_embedding",
            "kernel_embedding",
            "hidden_embedding",
        )
    ]
    dropout = fetch("dropout")
    # NOTE(review): the kernel size key is spelled "kernels_main" (plural),
    # unlike "kernel_embedding" above; kept as-is because callers' configs
    # presumably use this exact key.
    main_args = [
        fetch(key)
        for key in (
            "flow_features",
            "domain_features",
            "window_size",
            "domain_length",
            "filter_main",
            "kernels_main",
            "dense_main",
        )
    ]

    # Choose the module that provides get_embedding / get_model.
    if network_type == "rene":
        networks = renes_networks
    else:
        networks = pauls_networks

    embedding_model = networks.get_embedding(*embedding_args, drop_out=dropout)
    predict_model = networks.get_model(dropout, *main_args, embedding_model)
    return embedding_model, predict_model