shamangary's Blog

[Source code study] Rewrite StarGAN. From Pytorch to Keras. Part. 1


I had been meaning to use GANs to fuse Pokémon, but in my experiments both DCGAN and WGAN gave pretty poor results. Then StarGAN appeared out of nowhere, and its demo amazed me: multi-domain translation, beautiful image outputs, and released code. It instantly became a must-read, must-implement topic for me, so let's rewrite their PyTorch code into Keras form :D

When reading code, I like to start from the easiest entry point. Although both PyTorch and Keras are Python based, it is hard to tell at a glance whether main.py needs rewriting; from a standard CNN point of view, however, model.py definitely does, so that is where we begin.

[Eng. ver.] This is a note on rewriting StarGAN from PyTorch to Keras. I record my steps and thoughts here. If you find anything wrong or anything you would like to add, please contact me or comment below.

In this note, we review model.py first. Since every framework (PyTorch, Keras, ...) has its own coding conventions, this file clearly contains pure PyTorch code, so it is the natural place to start editing.

Goal

# Pytorch -> Keras

# Sequential model -> functional API

1. Conv2D/Conv2DTranspose

# Pytorch

# ref
  torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

# source code
  layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        

----------------------------------------------
|
| (rewrite the layer into keras)
|
v
----------------------------------------------

# Keras

# ref
  keras.layers.Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)

  keras.layers.ZeroPadding2D(padding=(1, 1), data_format=None)

# reimplementation

  x = ZeroPadding2D(padding=(3, 3))(x)
  x = Conv2D(conv_dim, (7, 7), strides=(1, 1), use_bias=False)(x)

  or

  x = Conv2D(conv_dim, (7, 7), strides=(1, 1), padding='same', use_bias=False)(x)
  • Note that there is no need to specify the layer's input channel dimension in Keras; it is inferred from the previous layer.
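To see why the two versions above are equivalent, you can check the output size with the standard convolution formula (a plain-Python sketch, no framework required; `conv_out_size` is just a helper name I made up):

```python
def conv_out_size(in_size, kernel, stride, pad):
    # PyTorch Conv2d output size: floor((in + 2*pad - kernel) / stride) + 1
    return (in_size + 2 * pad - kernel) // stride + 1

# StarGAN's first conv: kernel_size=7, stride=1, padding=3 keeps the spatial size,
# which is exactly what Keras padding='same' guarantees for stride 1
assert conv_out_size(128, 7, 1, 3) == 128
```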

Conv2DTranspose

*** Combining explicit zero padding with Conv2DTranspose in Keras does not reproduce nn.ConvTranspose2d in PyTorch. Use padding='same' directly! ***

# Pytorch
# source code
  layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
----------------------------------------------
|
| (rewrite the layer into keras)
|
v
----------------------------------------------
# Keras
# reimplementation
    x = Conv2DTranspose(curr_dim//2, (4, 4), strides=(2, 2), padding='same', use_bias=False)(x)
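The shape arithmetic shows why padding='same' is the right choice here (again a plain-Python sketch; `deconv_out_size` is a made-up helper):

```python
def deconv_out_size(in_size, kernel, stride, pad):
    # PyTorch ConvTranspose2d output size (output_padding=0, dilation=1):
    # (in - 1) * stride - 2*pad + kernel
    return (in_size - 1) * stride - 2 * pad + kernel

# StarGAN's upsampling layer: kernel_size=4, stride=2, padding=1 doubles the size,
# matching Keras Conv2DTranspose with padding='same' (out = in * stride)
assert deconv_out_size(64, 4, 2, 1) == 128
```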
            

2. Instance normalization

(Not directly supported by Keras. You need to install keras-contrib: https://github.com/keras-team/keras-contrib)

# Pytorch

# ref
  torch.nn.InstanceNorm2d(num_features, eps=1e-05, momentum=0.1, affine=False)
  affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``False``

# source code
  layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
        
----------------------------------------------
|
| (rewrite the layer into keras)
|
v
----------------------------------------------
# Keras

# ref
  https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/layers/normalization.py

# reimplementation

  x = InstanceNormalization()(x)
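If you are unsure what the layer computes: instance normalization normalizes every channel of every sample over its own spatial dimensions. A minimal NumPy sketch (ignoring the learnable affine parameters that `affine=True` adds):

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    # x: (batch, H, W, C), channels-last as in Keras;
    # normalize each (sample, channel) slice over its spatial dims
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.rand(2, 8, 8, 3)
y = instance_norm(x)
# every (sample, channel) slice now has mean ~0 and std ~1
```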

3. ReLU/LeakyReLU

# Pytorch

# ref
  torch.nn.ReLU(inplace=False)

  torch.nn.LeakyReLU(negative_slope=0.01, inplace=False)

# source code
  layers.append(nn.ReLU(inplace=True))

  layers.append(nn.LeakyReLU(0.01, inplace=True))

----------------------------------------------
|
| (rewrite the layer into keras)
|
v
----------------------------------------------

# Keras

# ref
  relu(x, alpha=0.0, max_value=None)

  keras.layers.LeakyReLU(alpha=0.3)

# reimplementation
  x = Activation('relu')(x)

  x = LeakyReLU(alpha=0.01)(x)
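One porting gotcha worth spelling out: PyTorch's nn.LeakyReLU defaults to negative_slope=0.01, while Keras' LeakyReLU defaults to alpha=0.3, so the alpha must be set explicitly when porting. A NumPy sketch of what the layer computes:

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    # PyTorch default slope is 0.01; Keras defaults to 0.3, so pass alpha=0.01
    return np.where(x >= 0, x, alpha * x)

out = leaky_relu(np.array([-1.0, 2.0]))
# negative inputs are scaled by alpha, positive inputs pass through
```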

4. unsqueeze and expand (in Pytorch)

# Pytorch

def forward(self, x, c):
        # replicate spatially and concatenate domain information
        c = c.unsqueeze(2).unsqueeze(3)
        c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        return self.main(x)

----------------------------------------------
|
| (rewrite the layer into keras)
|
v
----------------------------------------------
# Keras

inputs = Input(shape=self._input_shape) # remember input image channel is (RGB+mask channel)

fake_c = Input(shape=(self.c_dim,))

def labels_to_maps(x, image_size):
    x_temp = K.expand_dims(x,1)
    x_temp = K.expand_dims(x_temp,2)
    x_temp = K.tile(x_temp, (1, image_size, image_size, 1))
    return x_temp

fake_c_2d = Lambda(labels_to_maps, arguments={'image_size': self.image_size})(fake_c)
x = Concatenate(axis=-1)([inputs, fake_c_2d])

Note that unsqueeze() in PyTorch corresponds to expand_dims() in Keras,
and expand() in PyTorch corresponds to tile() in Keras.
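The same shape manipulation can be checked in plain NumPy, whose expand_dims/tile behave like the Keras backend functions used above:

```python
import numpy as np

# PyTorch: c.unsqueeze(2).unsqueeze(3), then expand to (N, c_dim, H, W).
# Keras (channels-last): expand_dims twice, then tile over the spatial dims.
c = np.random.rand(1, 5)                  # (batch, c_dim) label vector
c = np.expand_dims(c, 1)                  # (1, 1, 5)
c = np.expand_dims(c, 2)                  # (1, 1, 1, 5)
c_map = np.tile(c, (1, 128, 128, 1))      # (1, 128, 128, 5)
assert c_map.shape == (1, 128, 128, 5)
```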

5. squeeze (in Pytorch)

# pytorch

Returns a tensor with all the dimensions of input of size 1 removed.
http://pytorch.org/docs/master/torch.html

# source code
    out_real.squeeze(), out_aux.squeeze()
    
----------------------------------------------
|
| (rewrite the layer into keras)
|
v
----------------------------------------------

# keras

def squeeze_all(x):
    x_temp = x
    delta_temp = 0
    for i in range(1, 4):
        if x.shape[i] == 1:
            x_temp = K.squeeze(x_temp, axis=(i - delta_temp))
            delta_temp = delta_temp + 1
    return x_temp
    
out_real = Lambda(squeeze_all)(out_real)
out_aux = Lambda(squeeze_all)(out_aux)

Note that squeeze() in TensorFlow/Keras is not equivalent to squeeze() in PyTorch: K.squeeze() requires an explicit axis instead of removing every size-1 dimension at once.
You must define your own custom layer to do that.
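A NumPy sketch of the difference (np.squeeze behaves like tf.squeeze here; `squeeze_keep_batch` is a made-up illustration of what the custom Lambda layer achieves):

```python
import numpy as np

out_real = np.random.rand(1, 2, 2, 1)    # e.g. a one-channel PatchGAN output

# A full squeeze with no axis drops EVERY size-1 dimension,
# including the batch dimension -- which breaks batched prediction:
assert np.squeeze(out_real).shape == (2, 2)

# The custom layer only squeezes axes 1..3, so the batch axis survives:
def squeeze_keep_batch(x):
    keep = [0] + [i for i in range(1, x.ndim) if x.shape[i] != 1]
    return x.reshape([x.shape[i] for i in keep])

assert squeeze_keep_batch(out_real).shape == (1, 2, 2)
```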

Check (Full code will be released soon)

KERAS_BACKEND=tensorflow python
# python code
from model_keras import Generator, Discriminator
import numpy as np

img_size = 128
g_model = Generator(img_size)()
d_model = Discriminator(img_size)()

m = np.random.rand(1,img_size,img_size,3)
fake_c = np.random.rand(1, 5)

pred_g = g_model.predict([m, fake_c])
pred_d = d_model.predict(m)


pred_g.shape
pred_d[0].shape
pred_d[1].shape

# output
>>> pred_g.shape
(1, 128, 128, 3)
>>> pred_d[0].shape
(1, 2, 2)
>>> pred_d[1].shape
(1, 5)
