diff --git a/models/networks.py b/models/networks.py
index 517399d..ac4de42 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -2,9 +2,11 @@ from collections import namedtuple
 
 import keras
 import keras.backend as K
+import numpy as np
 from keras.engine import Input, Model as KerasModel
 from keras.engine.topology import Layer
 from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed
+from keras.regularizers import Regularizer
 
 import dataset
 
@@ -72,7 +74,8 @@ def get_inter_model(dropout, flow_features, window_size, domain_length, cnn_dims
               activation="relu", name="dense_server")(merged)
     out_server = Dense(1, activation="sigmoid", name="server")(y)
 
-    merged = keras.layers.concatenate([merged, y], -1)
+    merged = keras.layers.concatenate([merged,
+                                       y], -1)
     # CNN processing a small slides of flow windows
     y = Conv1D(cnn_dims,
                kernel_size,
@@ -155,7 +158,6 @@ class CrossStitch2(Layer):
     def call(self, xs):
         x1, x2 = xs
         out = x1 * self.s + x2 * self.d
-        print("==>", x1, x2, out)
         return out
 
     def compute_output_shape(self, input_shape):
@@ -180,14 +182,39 @@ class CrossStitchMix2(Layer):
 
     def call(self, xs):
         x1, x2 = xs
-        out = (x1 * self.s, x2 * self.d)
-        out = K.concatenate(out, axis=-1)
+        out = K.concatenate((x1 * self.s, x2 * self.d), axis=-1)
         return out
 
     def compute_output_shape(self, input_shape):
         return (input_shape[0][0], input_shape[0][1] + input_shape[1][1])
 
 
+class L21(Regularizer):
+    """Regularizer for L21 (group lasso) regularization.
+
+    Penalty: C * sqrt(d) * sum(sqrt(sum(square(x), axis=1) + epsilon)), where
+    d is the static size of axis 1 of x as reported by K.int_shape.
+    Found at: https://bitbucket.org/ispamm/group-lasso-for-neural-networks-tensorflow-keras
+
+    # Arguments
+        C: Float; L21 regularization factor.
+    """
+
+    def __init__(self, C=0.):
+        self.C = K.cast_to_floatx(C)
+
+    def __call__(self, x):
+        # Needs a static size on axis 1, so x must be at least 2-D.
+        const_coeff = np.sqrt(K.int_shape(x)[1])
+        # K.epsilon() keeps the gradient finite when a whole group is zero;
+        # a plain sqrt(0) can back-propagate inf * 0 = NaN into the weights.
+        group_norms = K.sqrt(K.sum(K.square(x), axis=1) + K.epsilon())
+        return self.C * const_coeff * K.sum(group_norms)
+
+    def get_config(self):
+        return {'C': float(self.C)}
+
+
 def get_sluice_model(dropout, flow_features, window_size, domain_length, cnn_dims, kernel_size, dense_dim, cnn) -> Model:
     ipt_domains = Input(shape=(window_size, domain_length), name="ipt_domains")
 