add group lasso regularizer impl

This commit is contained in:
René Knaebel 2017-11-30 09:34:45 +01:00
parent f382d06eb5
commit b3d646c9e7

View File

@ -2,9 +2,11 @@ from collections import namedtuple
import keras
import keras.backend as K
import numpy as np
from keras.engine import Input, Model as KerasModel
from keras.engine.topology import Layer
from keras.layers import Conv1D, Dense, Dropout, Embedding, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed
from keras.regularizers import Regularizer
import dataset
@ -72,7 +74,8 @@ def get_inter_model(dropout, flow_features, window_size, domain_length, cnn_dims
activation="relu",
name="dense_server")(merged)
out_server = Dense(1, activation="sigmoid", name="server")(y)
merged = keras.layers.concatenate([merged, y], -1)
merged = keras.layers.concatenate([merged,
y], -1)
# CNN processing a small slides of flow windows
y = Conv1D(cnn_dims,
kernel_size,
@ -155,7 +158,6 @@ class CrossStitch2(Layer):
def call(self, xs):
x1, x2 = xs
out = x1 * self.s + x2 * self.d
print("==>", x1, x2, out)
return out
def compute_output_shape(self, input_shape):
@ -180,14 +182,31 @@ class CrossStitchMix2(Layer):
def call(self, xs):
x1, x2 = xs
out = (x1 * self.s, x2 * self.d)
out = K.concatenate(out, axis=-1)
out = K.concatenate((x1 * self.s, x2 * self.d), axis=-1)
return out
def compute_output_shape(self, input_shape):
return (input_shape[0][0], input_shape[0][1] + input_shape[1][1])
class L21(Regularizer):
"""Regularizer for L21 regularization.
Found at: https://bitbucket.org/ispamm/group-lasso-for-neural-networks-tensorflow-keras
# Arguments
C: Float; L21 regularization factor.
"""
def __init__(self, C=0.):
self.C = K.cast_to_floatx(C)
def __call__(self, x):
const_coeff = np.sqrt(K.int_shape(x)[1])
return self.C * const_coeff * K.sum(K.sqrt(K.sum(K.square(x), axis=1)))
def get_config(self):
return {'C': float(self.C)}
def get_sluice_model(dropout, flow_features, window_size, domain_length, cnn_dims, kernel_size,
dense_dim, cnn) -> Model:
ipt_domains = Input(shape=(window_size, domain_length), name="ipt_domains")