shamangary's Blog

Analysis of CNN Architecture. Using STN+PNNET, Residual Network as Examples.

| Comments

Introduction

在使用Torch7來編寫你自己的CNN架構時,我們需要先了解一下每個單位所負責的事情,就CNN的結構來說,大家通常只知道weight filter還有feature map之類的東西,如何排序整個架構在paper上是可以用箭頭和長方形來做high level的解說,但你在寫程式總不能這樣寫吧(好啦或許有些程式可以,但是我們可愛的Torch7不行),所以我們要從小到大再從大到小來看一下如何拆解和理解整個架構,我們用的範例是STN(Spatial Transformaer Network, https://github.com/qassemoquab/stnbhwd)+PNNET(PN-Net, https://github.com/vbalnt/pnnet),並在最後我們討論一下如何寫residual network (https://github.com/gcr/torch-residual-networks)

Problem definition

我們現在想處理的問題是local descriptor,也就是

input=32x32 patch --> CNN --> local descriptor=128 dim real-valued vector

本篇的重點是如何建立該CNN的model,一旦建立好了該model,你便可以使用該model去做例如Siamese network的training,本篇先不著墨如何寫training的lost function code,而先著重分析該model的每個部份為何如此寫。

對於local descriptor來說,在[arXiv16] LIFT 中明確指出有了STN的加持,在train你的siamese network時可以為你的descriptor加上一層transformation的保護,進而讓你產生出來的descriptor有對transformation更好的抵抗能力,是故本篇使用STN作為前導,在model的一開始先矯正input patch的transformation,並使用[arXiv16] PNNET作為local descriptor的network,藉以產生128-dim的向量。

STN+PNNET Model

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> output]
  (1): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> output]
    (1): nn.ConcatTable {
      input
        |`-> (1): nn.Sequential {
        |      [input -> (1) -> (2) -> output]
        |      (1): nn.Identity
        |      (2): nn.Transpose
        |    }
        |`-> (2): nn.Sequential {
        |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> output]
        |      (1): cudnn.SpatialMaxPooling(2x2, 2,2)
        |      (2): cudnn.SpatialConvolution(1 -> 20, 5x5)
        |      (3): cudnn.ReLU
        |      (4): cudnn.SpatialMaxPooling(2x2, 2,2)
        |      (5): cudnn.SpatialConvolution(20 -> 20, 5x5)
        |      (6): cudnn.ReLU
        |      (7): nn.View(80)
        |      (8): nn.Linear(80 -> 20)
        |      (9): cudnn.ReLU
        |      (10): nn.Linear(20 -> 6)
        |      (11): nn.View(2, 3)
        |      (12): nn.AffineGridGeneratorBHWD
        |    }
         ... -> output
    }
    (2): nn.BilinearSamplerBHWD
    (3): nn.Transpose
  }
  (2): cudnn.SpatialConvolution(1 -> 32, 7x7)
  (3): cudnn.Tanh
  (4): cudnn.SpatialMaxPooling(2x2, 2,2)
  (5): cudnn.SpatialConvolution(32 -> 64, 6x6)
  (6): cudnn.Tanh
  (7): nn.View(4096)
  (8): nn.Linear(4096 -> 128)
  (9): cudnn.Tanh
}

如果你第一次接觸會覺得很混亂,但是概念上我們可以把上表表示成下面的形式

nn.Sequential {
  (1): nn.Sequential {
    **STN**
  }
  **PNNET**
}

基本上就是STN先,然後再做PNNET的local descriptor。

STN (Spatial Transformer Network)

同樣我們先把上面STN複雜的形式改成概念上的說明,如下表所示

  (1): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> output]
    (1): nn.ConcatTable {
      input
        |`-> (1): first branch is there to transpose inputs to BHWD, for the bilinear sampler
        |    
        |`-> (2): second branch is the localization network
        |    
         ... -> output
    }
    (2): nn.BilinearSamplerBHWD
    (3): nn.Transpose
  }

首先第一個需要注意的地方就是,這裡使用了nn.ConcatTable,我們先看一下Torch7對他的定義
https://github.com/torch/nn/blob/master/doc/table.md

# module = nn.ConcatTable()
                  +-----------+
             +----> {member1, |
+-------+    |    |           |
| input +----+---->  member2, |
+-------+    |    |           |
   or        +---->  member3} |
 {input}          +-----------+

也就是你定義的不同member都會拿到同樣的input,而ouput也是table,會直接作為下一個module的input輸入過去。

我們在STN中所看到的input就是我們的32x32的image patch,而在標準的STN中有兩個member

input --> 32x32 image patch

member1--> 把input轉成BHWD形式

member2--> 找出transformation的grid

// Bilinear sampling is done in BHWD (coalescing is not obvious in BDHW)
B:batch, D:dimension, H: height, W:width

一旦我們得到了grid還有input的BHWD形式,我們就可以用這兩者產生轉換後用bilinear產生的output也是32x32的patch,但是是經過STN轉正過的。

特別有用的是,在**nn.BilinearSamplerBHWD**這層接收到的input是一個table,
也就是上面提到的兩個東西,但是ouput卻是一個而已,也就是說,
這種寫法對於設計CNN model是非常有用的,必須要記起來。
具體來說如何去寫,就要參考BilinearSamplerBHWD.lua裡面的東西了。

Others

既然我們都看到了nn.ConcatTable,我們也順便看一下其他支援的形式

# module = nn.ParallelTable()
+----------+         +-----------+
| {input1, +---------> {member1, |
|          |         |           |
|  input2, +--------->  member2, |
|          |         |           |
|  input3} +--------->  member3} |
+----------+         +-----------+

其他還有太多種就去官網看吧 --> https://github.com/torch/nn/blob/master/doc/table.md

Residual network

類似於nn.ConcatTable的寫法,我們特別來參考一下residual network要怎麼寫才對
https://github.com/gcr/torch-residual-networks/blob/master/residual-layers.lua

# Concept
               Input
                 |
         ,-------+-----.
   Downsampling      3x3 convolution+dimensionality reduction
        |               |
        v               v
   Zero-padding      3x3 convolution
        |               |
        `-----( Add )---'
                 |
              Output
              
# How to add them up? (Only show the important part)

   local skip = input

   -- Add them together
   net = cudnn.SpatialBatchNormalization(nOutChannels)(net)
   net = nn.CAddTable(){net, skip}
   net = cudnn.ReLU(true)(net)

其中的關鍵就是nn.CAddTable()這個方法,懂了之後想要怎麼設計都行。

至於怎麼插入這樣的東西到你的model中去,可以參考 https://github.com/gcr/torch-residual-networks/blob/master/train-cifar.lua,他把它直接寫成了一個layer可以很方便的插入

model = addResidualLayer2(model, 32, 64, 2)

由此可知,實行elementwise的運算在Torch7中也是有支援的,我們特別將他們列出來因為特別有用

# CMath Modules perform element-wise operations on a table of Tensors:

CAddTable: addition of input Tensors;
CSubTable: substraction of input Tensors;
CMulTable: multiplication of input Tensors;
CDivTable: division of input Tensors;

Conclusion

本篇文章中我們討論了如何去看還有如何去寫你所想要的CNN結構,之前想寫residual的概念但是不太懂程式上怎麼搞,概念上懂了之後我們就能迅速的在Torch7上面開發了。

Comments

comments powered by Disqus