add initial version of sluice network with alphas, betas, and soft share
This commit is contained in:
parent
349bc92a61
commit
f382d06eb5
@ -73,6 +73,19 @@ def get_models_by_params(params: dict):
|
|||||||
l1 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(conv_server, conv_client)]
|
l1 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(conv_server, conv_client)]
|
||||||
model.add_loss(l1)
|
model.add_loss(l1)
|
||||||
|
|
||||||
|
dense_server = model.get_layer("dense_server").trainable_weights
|
||||||
|
dense_client = model.get_layer("dense_client").trainable_weights
|
||||||
|
l2 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(dense_server, dense_client)]
|
||||||
|
model.add_loss(l2)
|
||||||
|
elif network_type == "sluice":
|
||||||
|
model = networks.get_sluice_model(0.25, flow_features, window_size, domain_length,
|
||||||
|
filter_main, kernel_main, dense_dim, domain_cnn)
|
||||||
|
model = create_model(model, model_output)
|
||||||
|
conv_server = model.get_layer("conv_server").trainable_weights
|
||||||
|
conv_client = model.get_layer("conv_client").trainable_weights
|
||||||
|
l1 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(conv_server, conv_client)]
|
||||||
|
model.add_loss(l1)
|
||||||
|
|
||||||
dense_server = model.get_layer("dense_server").trainable_weights
|
dense_server = model.get_layer("dense_server").trainable_weights
|
||||||
dense_client = model.get_layer("dense_client").trainable_weights
|
dense_client = model.get_layer("dense_client").trainable_weights
|
||||||
l2 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(dense_server, dense_client)]
|
l2 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(dense_server, dense_client)]
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import keras
|
import keras
|
||||||
|
import keras.backend as K
|
||||||
from keras.engine import Input, Model as KerasModel
|
from keras.engine import Input, Model as KerasModel
|
||||||
|
from keras.engine.topology import Layer
|
||||||
from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed
|
from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed
|
||||||
|
|
||||||
import dataset
|
import dataset
|
||||||
@ -132,3 +134,97 @@ def get_long_model(dropout, flow_features, window_size, domain_length, cnn_dims,
|
|||||||
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
||||||
|
|
||||||
return Model(ipt_domains, ipt_flows, out_client, out_server)
|
return Model(ipt_domains, ipt_flows, out_client, out_server)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossStitch2(Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(CrossStitch2, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
# Create a trainable weight variable for this layer.
|
||||||
|
self.s = self.add_weight(name='cross-stitch-s',
|
||||||
|
shape=(1,),
|
||||||
|
initializer='uniform',
|
||||||
|
trainable=True)
|
||||||
|
self.d = self.add_weight(name='cross-stitch-d',
|
||||||
|
shape=(1,),
|
||||||
|
initializer='uniform',
|
||||||
|
trainable=True)
|
||||||
|
super(CrossStitch2, self).build(input_shape)
|
||||||
|
|
||||||
|
def call(self, xs):
|
||||||
|
x1, x2 = xs
|
||||||
|
out = x1 * self.s + x2 * self.d
|
||||||
|
print("==>", x1, x2, out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def compute_output_shape(self, input_shape):
|
||||||
|
return input_shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
class CrossStitchMix2(Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(CrossStitchMix2, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
# Create a trainable weight variable for this layer.
|
||||||
|
self.s = self.add_weight(name='cross-stitch-s',
|
||||||
|
shape=(1,),
|
||||||
|
initializer='uniform',
|
||||||
|
trainable=True)
|
||||||
|
self.d = self.add_weight(name='cross-stitch-d',
|
||||||
|
shape=(1,),
|
||||||
|
initializer='uniform',
|
||||||
|
trainable=True)
|
||||||
|
super(CrossStitchMix2, self).build(input_shape)
|
||||||
|
|
||||||
|
def call(self, xs):
|
||||||
|
x1, x2 = xs
|
||||||
|
out = (x1 * self.s, x2 * self.d)
|
||||||
|
out = K.concatenate(out, axis=-1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def compute_output_shape(self, input_shape):
|
||||||
|
return (input_shape[0][0], input_shape[0][1] + input_shape[1][1])
|
||||||
|
|
||||||
|
|
||||||
|
def get_sluice_model(dropout, flow_features, window_size, domain_length, cnn_dims, kernel_size,
|
||||||
|
dense_dim, cnn) -> Model:
|
||||||
|
ipt_domains = Input(shape=(window_size, domain_length), name="ipt_domains")
|
||||||
|
ipt_flows = Input(shape=(window_size, flow_features), name="ipt_flows")
|
||||||
|
encoded = TimeDistributed(cnn, name="domain_cnn")(ipt_domains)
|
||||||
|
merged = keras.layers.concatenate([encoded, ipt_flows], -1)
|
||||||
|
y1 = Conv1D(cnn_dims,
|
||||||
|
kernel_size,
|
||||||
|
activation='relu', name="conv_server")(merged)
|
||||||
|
y1 = GlobalMaxPooling1D()(y1)
|
||||||
|
|
||||||
|
y2 = Conv1D(cnn_dims,
|
||||||
|
kernel_size,
|
||||||
|
activation='relu', name="conv_client")(merged)
|
||||||
|
y2 = GlobalMaxPooling1D()(y2)
|
||||||
|
|
||||||
|
c11 = CrossStitch2()([y1, y2])
|
||||||
|
c12 = CrossStitch2()([y1, y2])
|
||||||
|
|
||||||
|
y1 = Dropout(dropout)(c11)
|
||||||
|
y1 = Dense(dense_dim,
|
||||||
|
activation="relu",
|
||||||
|
name="dense_server")(y1)
|
||||||
|
|
||||||
|
y2 = Dropout(dropout)(c12)
|
||||||
|
y2 = Dense(dense_dim,
|
||||||
|
activation='relu',
|
||||||
|
name="dense_client")(y2)
|
||||||
|
|
||||||
|
c21 = CrossStitch2()([y1, y2])
|
||||||
|
c22 = CrossStitch2()([y1, y2])
|
||||||
|
|
||||||
|
beta1 = CrossStitchMix2()([c11, c21])
|
||||||
|
beta2 = CrossStitchMix2()([c12, c22])
|
||||||
|
|
||||||
|
out_server = Dense(1, activation="sigmoid", name="server")(beta1)
|
||||||
|
|
||||||
|
out_client = Dense(1, activation='sigmoid', name="client")(beta2)
|
||||||
|
|
||||||
|
return Model(ipt_domains, ipt_flows, out_client, out_server)
|
||||||
|
Loading…
Reference in New Issue
Block a user