add initial version of sluice network with alphas, betas, and soft share
This commit is contained in:
parent
349bc92a61
commit
f382d06eb5
@ -73,6 +73,19 @@ def get_models_by_params(params: dict):
|
||||
l1 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(conv_server, conv_client)]
|
||||
model.add_loss(l1)
|
||||
|
||||
dense_server = model.get_layer("dense_server").trainable_weights
|
||||
dense_client = model.get_layer("dense_client").trainable_weights
|
||||
l2 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(dense_server, dense_client)]
|
||||
model.add_loss(l2)
|
||||
elif network_type == "sluice":
|
||||
model = networks.get_sluice_model(0.25, flow_features, window_size, domain_length,
|
||||
filter_main, kernel_main, dense_dim, domain_cnn)
|
||||
model = create_model(model, model_output)
|
||||
conv_server = model.get_layer("conv_server").trainable_weights
|
||||
conv_client = model.get_layer("conv_client").trainable_weights
|
||||
l1 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(conv_server, conv_client)]
|
||||
model.add_loss(l1)
|
||||
|
||||
dense_server = model.get_layer("dense_server").trainable_weights
|
||||
dense_client = model.get_layer("dense_client").trainable_weights
|
||||
l2 = [0.001 * K.sum(K.abs(x - y)) for (x, y) in zip(dense_server, dense_client)]
|
||||
|
@ -1,7 +1,9 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import keras
|
||||
import keras.backend as K
|
||||
from keras.engine import Input, Model as KerasModel
|
||||
from keras.engine.topology import Layer
|
||||
from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed
|
||||
|
||||
import dataset
|
||||
@ -132,3 +134,97 @@ def get_long_model(dropout, flow_features, window_size, domain_length, cnn_dims,
|
||||
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
||||
|
||||
return Model(ipt_domains, ipt_flows, out_client, out_server)
|
||||
|
||||
|
||||
class CrossStitch2(Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(CrossStitch2, self).__init__(**kwargs)
|
||||
|
||||
def build(self, input_shape):
|
||||
# Create a trainable weight variable for this layer.
|
||||
self.s = self.add_weight(name='cross-stitch-s',
|
||||
shape=(1,),
|
||||
initializer='uniform',
|
||||
trainable=True)
|
||||
self.d = self.add_weight(name='cross-stitch-d',
|
||||
shape=(1,),
|
||||
initializer='uniform',
|
||||
trainable=True)
|
||||
super(CrossStitch2, self).build(input_shape)
|
||||
|
||||
def call(self, xs):
|
||||
x1, x2 = xs
|
||||
out = x1 * self.s + x2 * self.d
|
||||
print("==>", x1, x2, out)
|
||||
return out
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape[0]
|
||||
|
||||
|
||||
class CrossStitchMix2(Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(CrossStitchMix2, self).__init__(**kwargs)
|
||||
|
||||
def build(self, input_shape):
|
||||
# Create a trainable weight variable for this layer.
|
||||
self.s = self.add_weight(name='cross-stitch-s',
|
||||
shape=(1,),
|
||||
initializer='uniform',
|
||||
trainable=True)
|
||||
self.d = self.add_weight(name='cross-stitch-d',
|
||||
shape=(1,),
|
||||
initializer='uniform',
|
||||
trainable=True)
|
||||
super(CrossStitchMix2, self).build(input_shape)
|
||||
|
||||
def call(self, xs):
|
||||
x1, x2 = xs
|
||||
out = (x1 * self.s, x2 * self.d)
|
||||
out = K.concatenate(out, axis=-1)
|
||||
return out
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return (input_shape[0][0], input_shape[0][1] + input_shape[1][1])
|
||||
|
||||
|
||||
def get_sluice_model(dropout, flow_features, window_size, domain_length, cnn_dims, kernel_size,
|
||||
dense_dim, cnn) -> Model:
|
||||
ipt_domains = Input(shape=(window_size, domain_length), name="ipt_domains")
|
||||
ipt_flows = Input(shape=(window_size, flow_features), name="ipt_flows")
|
||||
encoded = TimeDistributed(cnn, name="domain_cnn")(ipt_domains)
|
||||
merged = keras.layers.concatenate([encoded, ipt_flows], -1)
|
||||
y1 = Conv1D(cnn_dims,
|
||||
kernel_size,
|
||||
activation='relu', name="conv_server")(merged)
|
||||
y1 = GlobalMaxPooling1D()(y1)
|
||||
|
||||
y2 = Conv1D(cnn_dims,
|
||||
kernel_size,
|
||||
activation='relu', name="conv_client")(merged)
|
||||
y2 = GlobalMaxPooling1D()(y2)
|
||||
|
||||
c11 = CrossStitch2()([y1, y2])
|
||||
c12 = CrossStitch2()([y1, y2])
|
||||
|
||||
y1 = Dropout(dropout)(c11)
|
||||
y1 = Dense(dense_dim,
|
||||
activation="relu",
|
||||
name="dense_server")(y1)
|
||||
|
||||
y2 = Dropout(dropout)(c12)
|
||||
y2 = Dense(dense_dim,
|
||||
activation='relu',
|
||||
name="dense_client")(y2)
|
||||
|
||||
c21 = CrossStitch2()([y1, y2])
|
||||
c22 = CrossStitch2()([y1, y2])
|
||||
|
||||
beta1 = CrossStitchMix2()([c11, c21])
|
||||
beta2 = CrossStitchMix2()([c12, c22])
|
||||
|
||||
out_server = Dense(1, activation="sigmoid", name="server")(beta1)
|
||||
|
||||
out_client = Dense(1, activation='sigmoid', name="client")(beta2)
|
||||
|
||||
return Model(ipt_domains, ipt_flows, out_client, out_server)
|
||||
|
Loading…
Reference in New Issue
Block a user