diff --git a/models/__init__.py b/models/__init__.py index 7bb5918..cf4ca0e 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -73,6 +73,19 @@ def get_models_by_params(params: dict): l1 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(conv_server, conv_client)] model.add_loss(l1) + dense_server = model.get_layer("dense_server").trainable_weights + dense_client = model.get_layer("dense_client").trainable_weights + l2 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(dense_server, dense_client)] + model.add_loss(l2) + elif network_type == "sluice": + model = networks.get_sluice_model(0.25, flow_features, window_size, domain_length, + filter_main, kernel_main, dense_dim, domain_cnn) + model = create_model(model, model_output) + conv_server = model.get_layer("conv_server").trainable_weights + conv_client = model.get_layer("conv_client").trainable_weights + l1 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(conv_server, conv_client)] + model.add_loss(l1) + dense_server = model.get_layer("dense_server").trainable_weights dense_client = model.get_layer("dense_client").trainable_weights l2 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(dense_server, dense_client)] diff --git a/models/networks.py b/models/networks.py index c02633d..517399d 100644 --- a/models/networks.py +++ b/models/networks.py @@ -1,7 +1,9 @@ from collections import namedtuple import keras +import keras.backend as K from keras.engine import Input, Model as KerasModel +from keras.engine.topology import Layer from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed import dataset @@ -132,3 +134,97 @@ def get_long_model(dropout, flow_features, window_size, domain_length, cnn_dims, out_client = Dense(1, activation='sigmoid', name="client")(y) return Model(ipt_domains, ipt_flows, out_client, out_server) + + +class CrossStitch2(Layer): + def __init__(self, **kwargs): + super(CrossStitch2, self).__init__(**kwargs) + + def build(self, input_shape): + # Create a trainable weight variable for this layer. + self.s = self.add_weight(name='cross-stitch-s', + shape=(1,), + initializer='uniform', + trainable=True) + self.d = self.add_weight(name='cross-stitch-d', + shape=(1,), + initializer='uniform', + trainable=True) + super(CrossStitch2, self).build(input_shape) + + def call(self, xs): + x1, x2 = xs + out = x1 * self.s + x2 * self.d + print("==>", x1, x2, out) + return out + + def compute_output_shape(self, input_shape): + return input_shape[0] + + +class CrossStitchMix2(Layer): + def __init__(self, **kwargs): + super(CrossStitchMix2, self).__init__(**kwargs) + + def build(self, input_shape): + # Create a trainable weight variable for this layer. + self.s = self.add_weight(name='cross-stitch-s', + shape=(1,), + initializer='uniform', + trainable=True) + self.d = self.add_weight(name='cross-stitch-d', + shape=(1,), + initializer='uniform', + trainable=True) + super(CrossStitchMix2, self).build(input_shape) + + def call(self, xs): + x1, x2 = xs + out = (x1 * self.s, x2 * self.d) + out = K.concatenate(out, axis=-1) + return out + + def compute_output_shape(self, input_shape): + return (input_shape[0][0], input_shape[0][1] + input_shape[1][1]) + + +def get_sluice_model(dropout, flow_features, window_size, domain_length, cnn_dims, kernel_size, + dense_dim, cnn) -> Model: + ipt_domains = Input(shape=(window_size, domain_length), name="ipt_domains") + ipt_flows = Input(shape=(window_size, flow_features), name="ipt_flows") + encoded = TimeDistributed(cnn, name="domain_cnn")(ipt_domains) + merged = keras.layers.concatenate([encoded, ipt_flows], -1) + y1 = Conv1D(cnn_dims, + kernel_size, + activation='relu', name="conv_server")(merged) + y1 = GlobalMaxPooling1D()(y1) + + y2 = Conv1D(cnn_dims, + kernel_size, + activation='relu', name="conv_client")(merged) + y2 = GlobalMaxPooling1D()(y2) + + c11 = CrossStitch2()([y1, y2]) + c12 = CrossStitch2()([y1, y2]) + + y1 = Dropout(dropout)(c11) + y1 = Dense(dense_dim, + activation="relu", + name="dense_server")(y1) + + y2 = Dropout(dropout)(c12) + y2 = Dense(dense_dim, + activation='relu', + name="dense_client")(y2) + + c21 = CrossStitch2()([y1, y2]) + c22 = CrossStitch2()([y1, y2]) + + beta1 = CrossStitchMix2()([c11, c21]) + beta2 = CrossStitchMix2()([c12, c22]) + + out_server = Dense(1, activation="sigmoid", name="server")(beta1) + + out_client = Dense(1, activation='sigmoid', name="client")(beta2) + + return Model(ipt_domains, ipt_flows, out_client, out_server)