shamangary's Blog

Code explanation in center loss github

| Comments

由於我之前看了蘇大大的這篇文章 https://kexue.fm/archives/4493/
覺得寫得很棒,又剛好想要做center loss,所以就做了一個github,
https://github.com/shamangary/Keras-MNIST-center-loss-with-visualization
除了蘇大大的center loss概念外,我加入了MNIST的訓練過程,
還有視覺化來確認是否有真正聚合特徵點到所謂的中心特徵,
算是成為了我keras的入門,可是好像有人寄信問我說覺得我寫的看不懂,
所以特此解釋一下。

1. 可以解釋一下概念嗎?

所謂的損失函數,在keras中的最終目的是要training,
所以在model.compile過後可以被視為損失函數,

由於center loss這東西除了softmax之外又加了所謂的center和特徵之間的距離項,
所以我定義的model_centerloss.compile的那行會有兩種loss,

第一個是categorical_crossentropy,也就是一般classfication的loss,
第二的應該是center loss了吧,可是為何要在loss的地方寫

lambda y_true,y_pred: y_pred

我猜你是在疑惑這個

我一開始學keras的時候也看不懂蘇大大這裡寫的概念,其實這是說keras的損失函數定義,
必然要有y_true和y_pred,就是keras他定死的,
那麼這裡寫lambda y_true,y_pred: y_pred的意思就是一個匿名函數,
就是不管你輸入了什麼ground-truth(y_true),我都回傳你y_pred,
那又是為什麼這樣寫?因為前面的輸入是l2_loss的層,這層的輸出就直接是要最小化的目標了,
所以有別於一般的categorical_crossentropy,
在l2_loss的fit,由於keras定義你必須要輸入對應標籤對(y_true),
但是實際上你又不需要(因為l2_loss直接是想要的損失函數項),
所以我們才用了lambda y_true,y_pred: y_pred,
並且在fit的過程中會定義隨機但是剛好維度符合的y值作為一個假的y_true,

random_y_train = np.random.rand(x_train.shape[0],1)
random_y_test = np.random.rand(x_test.shape[0],1)

這部分牽扯到keras對損失函數的定義問題,所以剛開始看可能會看不懂。
而兩個損失函數的結合有權重項loss_weights=[1,lambda_c],
最後在訓練的過程中,如果你不要center loss,當然是直接用model.fit就好了,
相反如果你需要center loss,那也是用model_centerloss.fit就好了,
因為model_centerloss同時包含了有softmax和center loss,並且已經用權重項組合起來了。

2. L2 loss怎麼看?感覺好怪?

在center loss裡面有個一行,定義中心和特徵之間的距離:

l2_loss = Lambda(lambda x: K.sum(K.square(x[0]-x[1][:,0]),1,keepdims=True),name='l2_loss')([ip1,centers])

乍看之下會有點難懂,我把它展開來看應該會清楚一點

def l2_loss_fun(x):
    print(x)
    ip1 = x[0]
    centers = x[1][:,0]
    print(ip1.shape)
    print(centers.shape)
    sq_dis = K.square(ip1-centers)
    print(sq_dis.shape)
    l2_loss_temp = K.sum(sq_dis,1,keepdims=True)
    print(l2_loss_temp.shape)
    return l2_loss_temp

l2_loss = Lambda(l2_loss_fun)([ip1, centers])

上面這個展開和原本的一行式子是完全等價的,
只是我們可以在函數裡比較好看到各種特徵的維度,
執行程式後回傳

[<tf.Tensor 'ip1/add:0' shape=(?, 2) dtype=float32>, <tf.Tensor 'embedding_1/Gather:0' shape=(?, 1, 2) dtype=float32>]
(?, 2)
(?, 2)
(?, 2)
(?, 1)

首先第一個回傳值:

[<tf.Tensor 'ip1/add:0' shape=(?, 2) dtype=float32>, <tf.Tensor 'embedding_1/Gather:0' shape=(?, 1, 2) dtype=float32>]

輸入的x有兩個東西,ip1和center,其維度分別為(?,2)和(?,1,2),
為了使其相減,我們用x[0]和x[1][:,0]重新命名他們,
可以看到印出來的維度是

(?, 2)
(?, 2)

接著我們算距離,先相減再平方,回傳維度為

(?, 2)

接著我們加起來,回傳維度為

(?, 1)

這個地方我就沒有開根號了,如果要開也可以就再加一行就好。

Comments

comments powered by Disqus