diff --git a/README.md b/README.md
index 4800bd2..18e890d 100755
--- a/README.md
+++ b/README.md
@@ -11,16 +11,15 @@ Tensorflow implementation for reproducing main results in the paper [StackGAN: T
### Dependencies
-python 2.7
+python 3.6+
-[TensorFlow 0.12](https://www.tensorflow.org/get_started/os_setup)
+[TensorFlow 1.13+](https://www.tensorflow.org/get_started/os_setup)
[Optional] [Torch](http://torch.ch/docs/getting-started.html#_) is needed if you use the pre-trained char-CNN-RNN text encoder.
[Optional] [skip-thought](https://github.com/ryankiros/skip-thoughts) is needed if you use the skip-thought text encoder.
In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
-- `prettytensor`
- `progressbar`
- `python-dateutil`
- `easydict`
@@ -32,7 +31,12 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f
**Data**
1. Download our preprocessed char-CNN-RNN text embeddings for [birds](https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE) and [flowers](https://drive.google.com/open?id=0B3y_msrWZaXLaUc0UXpmcnhaVmM) and save them to `Data/`.
+
- [Optional] Follow the instructions [reedscot/icml2016](https://github.com/reedscot/icml2016) to download the pretrained char-CNN-RNN text encoders and extract text embeddings.
+
+ - [Optional] Download our preprocessed skip-thoughts text embeddings for [birds](https://drive.google.com/open?id=10jlSsU3g2ywDFXgUmn2Dh_UJCkQectzy) and save them to `Data/`.
+
+
2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and [flowers](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/) image data. Extract them to `Data/birds/` and `Data/flowers/`, respectively.
3. Preprocess images.
- For birds: `python misc/preprocess_birds.py`
@@ -51,9 +55,9 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f
**Pretrained Model**
-- [StackGAN for birds](https://drive.google.com/open?id=0B3y_msrWZaXLNUNKa3BaRjAyTzQ) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
-- [StackGAN for flowers](https://drive.google.com/open?id=0B3y_msrWZaXLX01FMC1JQW9vaFk) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
-- [StackGAN for birds](https://drive.google.com/open?id=0B3y_msrWZaXLZVNRNFg4d055Q1E) trained from skip-thought text embeddings. Download and save it to `models/` (Just used the same setting as the char-CNN-RNN. We assume better results can be achieved by playing with the hyper-parameters).
+- [StackGAN for birds](https://drive.google.com/open?id=1O1JHIoYO3h_qB5o27Td8KklvuLgTgpdV) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
+- [StackGAN for flowers]() trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
+- [StackGAN for birds]() trained from skip-thought text embeddings. Download and save it to `models/` (we simply reused the char-CNN-RNN settings; better results can likely be achieved by tuning the hyper-parameters).
@@ -96,6 +100,12 @@ booktitle = {{ICCV}},
- [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916)
- [AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks](https://arxiv.org/abs/1711.10485) [[supplementary]](https://1drv.ms/b/s!Aj4exx_cRA4ghK5-kUG-EqH7hgknUA) [[code]](https://github.com/taoxugit/AttnGAN)
+**Future**
+
+[Fashion Expansion](https://github.com/1o0ko/StackGAN-v1-TensorFlow)
+
+[Fashion Dataset](https://github.com/ayushidalmia/awesome-fashion-ai#datasets)
+
**References**
- Generative Adversarial Text-to-Image Synthesis [Paper](https://arxiv.org/abs/1605.05396) [Code](https://github.com/reedscot/icml2016)
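The README still asks for the project folder on PYTHONPATH, but the updated demo scripts in this patch sidestep that by appending the relevant subfolders at runtime. A minimal sketch of that pattern, mirroring the `sys.path.append` calls the patch adds to `demo/demo.py` (assumes the script is run from the repository root):

```python
# Import setup used by the patched demos instead of a PYTHONPATH export.
import sys
sys.path.append('misc')      # config.py, utils.py, custom_ops.py, ...
sys.path.append('stageII')   # model.py (CondGAN)

from config import cfg, cfg_from_file
from model import CondGAN
```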
diff --git a/demo/birds_demo.sh b/demo/birds_demo.sh
old mode 100644
new mode 100755
index 66ff9ab..162cec8
--- a/demo/birds_demo.sh
+++ b/demo/birds_demo.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
#
# Extract text embeddings from the encoder
#
@@ -15,7 +16,7 @@ th demo/get_embedding.lua
#
# Generate image from text embeddings
#
-python demo/demo.py \
+python3 demo/demo.py \
--cfg demo/cfg/birds-demo.yml \
--gpu ${GPU} \
--caption_path ${CAPTION_PATH}.t7
diff --git a/demo/birds_skip_thought_demo.py b/demo/birds_skip_thought_demo.py
index cddb21f..f25303a 100644
--- a/demo/birds_skip_thought_demo.py
+++ b/demo/birds_skip_thought_demo.py
@@ -1,30 +1,29 @@
from __future__ import division
from __future__ import print_function
-import prettytensor as pt
import tensorflow as tf
import numpy as np
-import scipy.misc
+import imageio
import os
import argparse
from PIL import Image, ImageDraw, ImageFont
-from misc.config import cfg, cfg_from_file
-from misc.utils import mkdir_p
-from misc import skipthoughts
-from stageII.model import CondGAN
+import sys
+sys.path.append('misc')
+sys.path.append('stageII')
+
+import skipthoughts
+from config import cfg, cfg_from_file
+from utils import mkdir_p
+from model import CondGAN
+from skimage.transform import resize
def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN network')
- parser.add_argument('--cfg', dest='cfg_file',
- help='optional config file',
- default=None, type=str)
- parser.add_argument('--gpu', dest='gpu_id',
- help='GPU device id to use [0]',
- default=-1, type=int)
- parser.add_argument('--caption_path', type=str, default=None,
- help='Path to the file with text sentences')
+ parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str)
+ parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=-1, type=int)
+ parser.add_argument('--caption_path', type=str, default=None, help='Path to the file with text sentences')
# if len(sys.argv) == 1:
# parser.print_help()
# sys.exit(1)
@@ -49,21 +48,17 @@ def sample_encoded_context(embeddings, model, bAugmentation=True):
def build_model(sess, embedding_dim, batch_size):
- model = CondGAN(
- lr_imsize=cfg.TEST.LR_IMSIZE,
- hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))
-
- embeddings = tf.placeholder(
- tf.float32, [batch_size, embedding_dim],
- name='conditional_embeddings')
- with pt.defaults_scope(phase=pt.Phase.test):
- with tf.variable_scope("g_net"):
- c = sample_encoded_context(embeddings, model)
- z = tf.random_normal([batch_size, cfg.Z_DIM])
- fake_images = model.get_generator(tf.concat(1, [c, z]))
- with tf.variable_scope("hr_g_net"):
- hr_c = sample_encoded_context(embeddings, model)
- hr_fake_images = model.hr_get_generator(fake_images, hr_c)
+ model = CondGAN(lr_imsize=cfg.TEST.LR_IMSIZE, hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))
+
+ embeddings = tf.placeholder(tf.float32, [batch_size, embedding_dim], name='conditional_embeddings')
+
+ with tf.variable_scope("g_net"):
+ c = sample_encoded_context(embeddings, model)
+ z = tf.random_normal([batch_size, cfg.Z_DIM])
+        fake_images = model.get_generator(tf.concat([c, z], 1), False)
+ with tf.variable_scope("hr_g_net"):
+ hr_c = sample_encoded_context(embeddings, model)
+ hr_fake_images = model.hr_get_generator(fake_images, hr_c, False)
ckt_path = cfg.TEST.PRETRAINED_MODEL
if ckt_path.find('.ckpt') != -1:
@@ -101,9 +96,7 @@ def drawCaption(img, caption):
return img_txt
-def save_super_images(sample_batchs, hr_sample_batchs,
- captions_batch, batch_size,
- startID, save_dir):
+def save_super_images(sample_batchs, hr_sample_batchs, captions_batch, batch_size, startID, save_dir):
if not os.path.isdir(save_dir):
print('Make a new folder: ', save_dir)
mkdir_p(save_dir)
@@ -119,7 +112,7 @@ def save_super_images(sample_batchs, hr_sample_batchs,
lr_img = sample_batchs[i][j]
+        lr_img = (lr_img + 1.0) * 127.5
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
- re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
+ re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
@@ -134,27 +127,23 @@ def save_super_images(sample_batchs, hr_sample_batchs,
lr_img = sample_batchs[i][j]
+        lr_img = (lr_img + 1.0) * 127.5
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
- re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
+ re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
row2 = np.concatenate(row2, axis=1)
super_row = np.concatenate([row1, row2], axis=0)
superimage2 = np.zeros_like(superimage)
- superimage2[:super_row.shape[0],
- :super_row.shape[1],
- :super_row.shape[2]] = super_row
+ superimage2[:super_row.shape[0], :super_row.shape[1], :super_row.shape[2]] = super_row
mid_padding = np.zeros((64, superimage.shape[1], 3))
- superimage =\
- np.concatenate([superimage, mid_padding, superimage2], axis=0)
+ superimage = np.concatenate([superimage, mid_padding, superimage2], axis=0)
top_padding = np.zeros((128, superimage.shape[1], 3))
- superimage =\
- np.concatenate([top_padding, superimage], axis=0)
+ superimage = np.concatenate([top_padding, superimage], axis=0)
fullpath = '%s/sentence%d.jpg' % (save_dir, startID + j)
superimage = drawCaption(np.uint8(superimage), captions_batch[j])
- scipy.misc.imsave(fullpath, superimage)
+    imageio.imwrite(fullpath, superimage)
if __name__ == "__main__":
@@ -188,8 +177,8 @@ def save_super_images(sample_batchs, hr_sample_batchs,
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
with tf.device("/gpu:%d" % cfg.GPU_ID):
- embeddings_holder, fake_images_opt, hr_fake_images_opt =\
- build_model(sess, embeddings.shape[-1], batch_size)
+ embeddings_holder, fake_images_opt, hr_fake_images_opt = build_model(sess, embeddings.shape[-1],
+ batch_size)
count = 0
while count < num_embeddings:
@@ -205,19 +194,14 @@ def save_super_images(sample_batchs, hr_sample_batchs,
# Generate up to 16 images for each sentence with
# randomness from noise z and conditioning augmentation.
for i in range(np.minimum(16, cfg.TEST.NUM_COPY)):
- hr_samples, samples =\
- sess.run([hr_fake_images_opt, fake_images_opt],
- {embeddings_holder: embeddings_batch})
+ hr_samples, samples = sess.run([hr_fake_images_opt, fake_images_opt],
+ {embeddings_holder: embeddings_batch})
samples_batchs.append(samples)
hr_samples_batchs.append(hr_samples)
- save_super_images(samples_batchs,
- hr_samples_batchs,
- captions_batch,
- batch_size,
- count, save_dir)
+ save_super_images(samples_batchs, hr_samples_batchs, captions_batch, batch_size, count, save_dir)
count += batch_size
print('Finish generating samples for %d sentences:' % num_embeddings)
print('Example sentences:')
- for i in xrange(np.minimum(10, num_embeddings)):
+ for i in range(np.minimum(10, num_embeddings)):
print('Sentence %d: %s' % (i, captions_list[i]))
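The largest mechanical change in this file is the `tf.concat` argument order: TensorFlow 0.x took the axis first, 1.0+ takes the values first. The extra `False` passed to `get_generator`/`hr_get_generator` is presumably the new `is_training` flag threaded into the rewritten batch-norm helpers. A hedged before/after sketch (shapes chosen arbitrarily for illustration):

```python
import tensorflow as tf  # assumes TensorFlow 1.x

c = tf.zeros([64, 128])   # conditioning vector
z = tf.zeros([64, 100])   # noise vector
# TF 0.12 (removed lines above): tf.concat(1, [c, z])
gen_input = tf.concat([c, z], axis=1)  # TF 1.x: values first -> shape (64, 228)
```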
diff --git a/demo/cfg/birds-demo.yml b/demo/cfg/birds-demo.yml
index 8526652..96b5286 100644
--- a/demo/cfg/birds-demo.yml
+++ b/demo/cfg/birds-demo.yml
@@ -5,7 +5,7 @@ GPU_ID: 0
Z_DIM: 100
TEST:
- PRETRAINED_MODEL: './models/birds_model_164000.ckpt'
+ PRETRAINED_MODEL: './models/stageII/model_330000.ckpt'
BATCH_SIZE: 64
NUM_COPY: 8
diff --git a/demo/cfg/birds-eval.yml b/demo/cfg/birds-eval.yml
index 78ba936..1a44393 100644
--- a/demo/cfg/birds-eval.yml
+++ b/demo/cfg/birds-eval.yml
@@ -7,7 +7,7 @@ Z_DIM: 100
TRAIN:
FLAG: False
- PRETRAINED_MODEL: './models/birds_model_164000.ckpt'
+ PRETRAINED_MODEL: './models/stageII/model_330000.ckpt'
BATCH_SIZE: 64
NUM_COPY: 8
diff --git a/demo/cfg/birds-skip-thought-demo.yml b/demo/cfg/birds-skip-thought-demo.yml
index e346428..1c00129 100644
--- a/demo/cfg/birds-skip-thought-demo.yml
+++ b/demo/cfg/birds-skip-thought-demo.yml
@@ -6,7 +6,7 @@ Z_DIM: 100
TEST:
CAPTION_PATH: './Data/birds/example_captions.txt'
- PRETRAINED_MODEL: './models/birds_skip_thought_model_164000.ckpt'
+ PRETRAINED_MODEL: './models/stageII/model_330000.ckpt'
BATCH_SIZE: 64
NUM_COPY: 8
diff --git a/demo/demo.py b/demo/demo.py
index 6f21a72..a4bc9d3 100644
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -1,31 +1,30 @@
from __future__ import division
from __future__ import print_function
-import prettytensor as pt
import tensorflow as tf
import numpy as np
-import scipy.misc
+import imageio
import os
import argparse
import torchfile
from PIL import Image, ImageDraw, ImageFont
import re
-from misc.config import cfg, cfg_from_file
-from misc.utils import mkdir_p
-from stageII.model import CondGAN
+import sys
+sys.path.append('misc')
+sys.path.append('stageII')
+
+from config import cfg, cfg_from_file
+from utils import mkdir_p, caption_convert
+from model import CondGAN
+from skimage.transform import resize
def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN network')
- parser.add_argument('--cfg', dest='cfg_file',
- help='optional config file',
- default=None, type=str)
- parser.add_argument('--gpu', dest='gpu_id',
- help='GPU device id to use [0]',
- default=-1, type=int)
- parser.add_argument('--caption_path', type=str, default=None,
- help='Path to the file with text sentences')
+ parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str)
+ parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=-1, type=int)
+ parser.add_argument('--caption_path', type=str, default=None, help='Path to the file with text sentences')
# if len(sys.argv) == 1:
# parser.print_help()
# sys.exit(1)
@@ -50,26 +49,22 @@ def sample_encoded_context(embeddings, model, bAugmentation=True):
def build_model(sess, embedding_dim, batch_size):
- model = CondGAN(
- lr_imsize=cfg.TEST.LR_IMSIZE,
- hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))
-
- embeddings = tf.placeholder(
- tf.float32, [batch_size, embedding_dim],
- name='conditional_embeddings')
- with pt.defaults_scope(phase=pt.Phase.test):
- with tf.variable_scope("g_net"):
- c = sample_encoded_context(embeddings, model)
- z = tf.random_normal([batch_size, cfg.Z_DIM])
- fake_images = model.get_generator(tf.concat(1, [c, z]))
- with tf.variable_scope("hr_g_net"):
- hr_c = sample_encoded_context(embeddings, model)
- hr_fake_images = model.hr_get_generator(fake_images, hr_c)
+ model = CondGAN(lr_imsize=cfg.TEST.LR_IMSIZE, hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))
+
+ embeddings = tf.placeholder(tf.float32, [batch_size, embedding_dim], name='conditional_embeddings')
+
+ with tf.variable_scope("g_net"):
+ c = sample_encoded_context(embeddings, model)
+ z = tf.random_normal([batch_size, cfg.Z_DIM])
+ fake_images = model.get_generator(tf.concat([c, z], 1), False)
+ with tf.variable_scope("hr_g_net"):
+ hr_c = sample_encoded_context(embeddings, model)
+ hr_fake_images = model.hr_get_generator(fake_images, hr_c, False)
ckt_path = cfg.TEST.PRETRAINED_MODEL
if ckt_path.find('.ckpt') != -1:
print("Reading model parameters from %s" % ckt_path)
- saver = tf.train.Saver(tf.all_variables())
+ saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, ckt_path)
else:
print("Input a valid model path.")
@@ -77,6 +72,7 @@ def build_model(sess, embedding_dim, batch_size):
def drawCaption(img, caption):
+ caption = caption_convert(caption)
img_txt = Image.fromarray(img)
# get a font
fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
@@ -102,9 +98,7 @@ def drawCaption(img, caption):
return img_txt
-def save_super_images(sample_batchs, hr_sample_batchs,
- captions_batch, batch_size,
- startID, save_dir):
+def save_super_images(sample_batchs, hr_sample_batchs, captions_batch, batch_size, startID, save_dir):
if not os.path.isdir(save_dir):
print('Make a new folder: ', save_dir)
mkdir_p(save_dir)
@@ -112,7 +106,7 @@ def save_super_images(sample_batchs, hr_sample_batchs,
# Save up to 16 samples for each text embedding/sentence
img_shape = hr_sample_batchs[0][0].shape
for j in range(batch_size):
- if not re.search('[a-zA-Z]+', captions_batch[j]):
+ if not re.search(b'[a-zA-Z]+', captions_batch[j]):
continue
padding = np.zeros(img_shape)
@@ -121,9 +115,10 @@ def save_super_images(sample_batchs, hr_sample_batchs,
# First row with up to 8 samples
for i in range(np.minimum(8, len(sample_batchs))):
lr_img = sample_batchs[i][j]
+ lr_img = (lr_img + 1.0) * 127.5
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
- re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
+ re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
@@ -136,29 +131,26 @@ def save_super_images(sample_batchs, hr_sample_batchs,
row2 = [padding]
for i in range(8, len(sample_batchs)):
lr_img = sample_batchs[i][j]
+ lr_img = (lr_img + 1.0) * 127.5
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
- re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
+ re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
row2 = np.concatenate(row2, axis=1)
super_row = np.concatenate([row1, row2], axis=0)
superimage2 = np.zeros_like(superimage)
- superimage2[:super_row.shape[0],
- :super_row.shape[1],
- :super_row.shape[2]] = super_row
+ superimage2[:super_row.shape[0], :super_row.shape[1], :super_row.shape[2]] = super_row
mid_padding = np.zeros((64, superimage.shape[1], 3))
- superimage =\
- np.concatenate([superimage, mid_padding, superimage2], axis=0)
+ superimage = np.concatenate([superimage, mid_padding, superimage2], axis=0)
top_padding = np.zeros((128, superimage.shape[1], 3))
- superimage =\
- np.concatenate([top_padding, superimage], axis=0)
+ superimage = np.concatenate([top_padding, superimage], axis=0)
fullpath = '%s/sentence%d.jpg' % (save_dir, startID + j)
superimage = drawCaption(np.uint8(superimage), captions_batch[j])
- scipy.misc.imsave(fullpath, superimage)
+ imageio.imwrite(fullpath, superimage)
if __name__ == "__main__":
@@ -188,8 +180,8 @@ def save_super_images(sample_batchs, hr_sample_batchs,
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
with tf.device("/gpu:%d" % cfg.GPU_ID):
- embeddings_holder, fake_images_opt, hr_fake_images_opt =\
- build_model(sess, embeddings.shape[-1], batch_size)
+ embeddings_holder, fake_images_opt, hr_fake_images_opt = build_model(sess, embeddings.shape[-1],
+ batch_size)
count = 0
while count < num_embeddings:
@@ -205,19 +197,14 @@ def save_super_images(sample_batchs, hr_sample_batchs,
# Generate up to 16 images for each sentence with
# randomness from noise z and conditioning augmentation.
for i in range(np.minimum(16, cfg.TEST.NUM_COPY)):
- hr_samples, samples =\
- sess.run([hr_fake_images_opt, fake_images_opt],
- {embeddings_holder: embeddings_batch})
+ hr_samples, samples = sess.run([hr_fake_images_opt, fake_images_opt],
+ {embeddings_holder: embeddings_batch})
samples_batchs.append(samples)
hr_samples_batchs.append(hr_samples)
- save_super_images(samples_batchs,
- hr_samples_batchs,
- captions_batch,
- batch_size,
- count, save_dir)
+ save_super_images(samples_batchs, hr_samples_batchs, captions_batch, batch_size, count, save_dir)
count += batch_size
print('Finish generating samples for %d sentences:' % num_embeddings)
print('Example sentences:')
- for i in xrange(np.minimum(10, num_embeddings)):
- print('Sentence %d: %s' % (i, captions_list[i]))
+ for i in range(np.minimum(10, num_embeddings)):
+ print('Sentence %d: %s' % (i, caption_convert(captions_list[i])))
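Under Python 3, `torchfile` returns the `.t7` captions as byte strings, which is why the regex above switches to a bytes pattern and why the new `caption_convert()` helper decodes before display. A small sketch (the caption text is just an illustrative sentence):

```python
# Bytes captions from torchfile under Python 3.
import re

caption = b'this bird is blue with white and has a very short beak'
assert re.search(b'[a-zA-Z]+', caption) is not None  # bytes pattern, as above
print(caption.decode('utf-8'))  # what utils.caption_convert() does
```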
diff --git a/demo/flowers_demo.sh b/demo/flowers_demo.sh
index 287699b..28d854b 100644
--- a/demo/flowers_demo.sh
+++ b/demo/flowers_demo.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
#
# Extract text embeddings from the encoder
#
@@ -16,7 +17,7 @@ th demo/get_embedding.lua
#
# Generate image from text embeddings
#
-python demo/demo.py \
+python3 demo/demo.py \
--cfg demo/cfg/flowers-demo.yml \
--gpu ${GPU} \
--caption_path ${CAPTION_PATH}.t7
diff --git a/misc/config.py b/misc/config.py
index 3ff777b..0653da3 100644
--- a/misc/config.py
+++ b/misc/config.py
@@ -1,7 +1,6 @@
from __future__ import division
from __future__ import print_function
-import os.path as osp
import numpy as np
from easydict import EasyDict as edict
@@ -48,6 +47,7 @@
__C.TRAIN.COEFF = edict()
__C.TRAIN.COEFF.KL = 2.0
+# For Stage II training
__C.TRAIN.FINETUNE_LR = False
__C.TRAIN.FT_LR_RETIO = 0.1
@@ -66,9 +66,9 @@ def _merge_a_into_b(a, b):
if type(a) is not edict:
return
- for k, v in a.iteritems():
+ for k, v in a.items():
# a must specify keys that are in b
- if not b.has_key(k):
+ if k not in b:
raise KeyError('{} is not a valid config key'.format(k))
# the types must match, too
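The config-merging fix above is the standard Python 3 port of the dict API; `iteritems()` and `has_key()` were removed. For reference:

```python
# Python 2 -> 3 dict idioms used in _merge_a_into_b above.
cfg = {'GPU_ID': 0, 'Z_DIM': 100}

for k, v in cfg.items():   # Py2: cfg.iteritems()
    print(k, v)

if 'Z_DIM' in cfg:         # Py2: cfg.has_key('Z_DIM')
    print(cfg['Z_DIM'])
```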
diff --git a/misc/custom_ops.py b/misc/custom_ops.py
index 11b48e8..458edd8 100644
--- a/misc/custom_ops.py
+++ b/misc/custom_ops.py
@@ -5,132 +5,102 @@
from __future__ import division
from __future__ import print_function
-import prettytensor as pt
-from tensorflow.python.training import moving_averages
import tensorflow as tf
-from prettytensor.pretty_tensor_class import Phase
import numpy as np
-class conv_batch_norm(pt.VarStoreMethod):
- """Code modification of:
- http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
- and
- https://github.com/tensorflow/models/blob/master/inception/inception/slim/ops.py"""
-
- def __call__(self, input_layer, epsilon=1e-5, decay=0.9, name="batch_norm",
- in_dim=None, phase=Phase.train):
- shape = input_layer.shape
- shp = in_dim or shape[-1]
- with tf.variable_scope(name) as scope:
- self.mean = self.variable('mean', [shp], init=tf.constant_initializer(0.), train=False)
- self.variance = self.variable('variance', [shp], init=tf.constant_initializer(1.0), train=False)
-
- self.gamma = self.variable("gamma", [shp], init=tf.random_normal_initializer(1., 0.02))
- self.beta = self.variable("beta", [shp], init=tf.constant_initializer(0.))
-
- if phase == Phase.train:
- mean, variance = tf.nn.moments(input_layer.tensor, [0, 1, 2])
- mean.set_shape((shp,))
- variance.set_shape((shp,))
-
- update_moving_mean = moving_averages.assign_moving_average(self.mean, mean, decay)
- update_moving_variance = moving_averages.assign_moving_average(self.variance, variance, decay)
-
- with tf.control_dependencies([update_moving_mean, update_moving_variance]):
- normalized_x = tf.nn.batch_norm_with_global_normalization(
- input_layer.tensor, mean, variance, self.beta, self.gamma, epsilon,
- scale_after_normalization=True)
- else:
- normalized_x = tf.nn.batch_norm_with_global_normalization(
- input_layer.tensor, self.mean, self.variance,
- self.beta, self.gamma, epsilon,
- scale_after_normalization=True)
- return input_layer.with_tensor(normalized_x, parameters=self.vars)
-
-
-pt.Register(assign_defaults=('phase'))(conv_batch_norm)
-
-
-@pt.Register(assign_defaults=('phase'))
-class fc_batch_norm(conv_batch_norm):
- def __call__(self, input_layer, *args, **kwargs):
- ori_shape = input_layer.shape
- if ori_shape[0] is None:
- ori_shape[0] = -1
- new_shape = [ori_shape[0], 1, 1, ori_shape[1]]
- x = tf.reshape(input_layer.tensor, new_shape)
- normalized_x = super(self.__class__, self).__call__(input_layer.with_tensor(x), *args, **kwargs) # input_layer)
- return normalized_x.reshape(ori_shape)
-
-
-def leaky_rectify(x, leakiness=0.01):
- assert leakiness <= 1
- ret = tf.maximum(x, leakiness * x)
- # import ipdb; ipdb.set_trace()
- return ret
-
-
-@pt.Register
-class custom_conv2d(pt.VarStoreMethod):
- def __call__(self, input_layer, output_dim,
- k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, in_dim=None, padding='SAME',
- name="conv2d"):
- with tf.variable_scope(name):
- w = self.variable('w', [k_h, k_w, in_dim or input_layer.shape[-1], output_dim],
- init=tf.truncated_normal_initializer(stddev=stddev))
- conv = tf.nn.conv2d(input_layer.tensor, w, strides=[1, d_h, d_w, 1], padding=padding)
-
- # biases = self.variable('biases', [output_dim], init=tf.constant_initializer(0.0))
- # import ipdb; ipdb.set_trace()
- # return input_layer.with_tensor(tf.nn.bias_add(conv, biases), parameters=self.vars)
- return input_layer.with_tensor(conv, parameters=self.vars)
-
-
-@pt.Register
-class custom_deconv2d(pt.VarStoreMethod):
- def __call__(self, input_layer, output_shape,
- k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
- name="deconv2d"):
- output_shape[0] = input_layer.shape[0]
- ts_output_shape = tf.pack(output_shape)
- with tf.variable_scope(name):
- # filter : [height, width, output_channels, in_channels]
- w = self.variable('w', [k_h, k_w, output_shape[-1], input_layer.shape[-1]],
- init=tf.random_normal_initializer(stddev=stddev))
-
- try:
- deconv = tf.nn.conv2d_transpose(input_layer, w,
- output_shape=ts_output_shape,
- strides=[1, d_h, d_w, 1])
-
- # Support for versions of TensorFlow before 0.7.0
- except AttributeError:
- deconv = tf.nn.deconv2d(input_layer, w, output_shape=ts_output_shape,
- strides=[1, d_h, d_w, 1])
-
- # biases = self.variable('biases', [output_shape[-1]], init=tf.constant_initializer(0.0))
- # deconv = tf.reshape(tf.nn.bias_add(deconv, biases), [-1] + output_shape[1:])
- deconv = tf.reshape(deconv, [-1] + output_shape[1:])
-
- return deconv
-
-
-@pt.Register
-class custom_fully_connected(pt.VarStoreMethod):
- def __call__(self, input_layer, output_size, scope=None, in_dim=None, stddev=0.02, bias_start=0.0):
- shape = input_layer.shape
- input_ = input_layer.tensor
- try:
- if len(shape) == 4:
- input_ = tf.reshape(input_, tf.pack([tf.shape(input_)[0], np.prod(shape[1:])]))
- input_.set_shape([None, np.prod(shape[1:])])
- shape = input_.get_shape().as_list()
-
- with tf.variable_scope(scope or "Linear"):
- matrix = self.variable("Matrix", [in_dim or shape[1], output_size], dt=tf.float32,
- init=tf.random_normal_initializer(stddev=stddev))
- bias = self.variable("bias", [output_size], init=tf.constant_initializer(bias_start))
- return input_layer.with_tensor(tf.matmul(input_, matrix) + bias, parameters=self.vars)
- except Exception:
- import ipdb; ipdb.set_trace()
+def fc(inputs, num_out, name, activation_fn=None, reuse=None):
+ shape = inputs.get_shape()
+ if len(shape) == 4:
+ inputs = tf.reshape(inputs, tf.stack([tf.shape(inputs)[0], np.prod(shape[1:])]))
+ inputs.set_shape([None, np.prod(shape[1:])])
+
+ w_init = tf.random_normal_initializer(stddev=0.02)
+
+ return tf.contrib.layers.fully_connected(inputs, num_out, activation_fn=activation_fn, weights_initializer=w_init,
+ reuse=reuse, scope=name)
+
+
+def concat(inputs, axis):
+ return tf.concat(values=inputs, axis=axis)
+
+
+def conv_batch_normalization(inputs, name, epsilon=1e-5, is_training=True, activation_fn=None, reuse=None):
+ return tf.contrib.layers.batch_norm(inputs, decay=0.9, center=True, scale=True, epsilon=epsilon,
+ activation_fn=activation_fn,
+ param_initializers={'beta': tf.constant_initializer(0.),
+ 'gamma': tf.random_normal_initializer(1., 0.02)},
+ reuse=reuse, is_training=is_training, scope=name)
+
+
+def fc_batch_normalization(inputs, name, epsilon=1e-5, is_training=True, activation_fn=None, reuse=None):
+    ori_shape = inputs.get_shape().as_list()
+    if ori_shape[0] is None:
+        ori_shape[0] = -1  # unknown batch size: let tf.reshape infer it
+ new_shape = [ori_shape[0], 1, 1, ori_shape[1]]
+ x = tf.reshape(inputs, new_shape)
+ normalized_x = conv_batch_normalization(x, name, epsilon=epsilon, is_training=is_training,
+ activation_fn=activation_fn, reuse=reuse)
+ return tf.reshape(normalized_x, ori_shape)
+
+
+def reshape(inputs, shape, name):
+ return tf.reshape(inputs, shape, name)
+
+
+def Conv2d(inputs, k_h, k_w, c_o, s_h, s_w, name, activation_fn=None, reuse=None, padding='SAME', biased=False):
+ c_i = inputs.get_shape()[-1]
+ w_init = tf.random_normal_initializer(stddev=0.02)
+
+ convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
+ with tf.variable_scope(name, reuse=reuse) as scope:
+ kernel = tf.get_variable(name='weights', shape=[k_h, k_w, c_i, c_o], initializer=w_init)
+ output = convolve(inputs, kernel)
+
+ if biased:
+ biases = tf.get_variable(name='biases', shape=[c_o])
+ output = tf.nn.bias_add(output, biases)
+ if activation_fn is not None:
+ output = activation_fn(output, name=scope.name)
+
+ return output
+
+
+def Deconv2d(inputs, output_shape, name, k_h, k_w, s_h=2, s_w=2, reuse=None, activation_fn=None, biased=False):
+ output_shape[0] = inputs.get_shape()[0]
+ ts_output_shape = tf.stack(output_shape)
+ w_init = tf.random_normal_initializer(stddev=0.02)
+
+ deconvolve = lambda i, k: tf.nn.conv2d_transpose(i, k, output_shape=ts_output_shape, strides=[1, s_h, s_w, 1])
+ with tf.variable_scope(name, reuse=reuse) as scope:
+ kernel = tf.get_variable(name='weights', shape=[k_h, k_w, output_shape[-1], inputs.get_shape()[-1]],
+ initializer=w_init)
+ output = deconvolve(inputs, kernel)
+
+ if biased:
+ biases = tf.get_variable(name='biases', shape=[output_shape[-1]])
+ output = tf.nn.bias_add(output, biases)
+ if activation_fn is not None:
+ output = activation_fn(output, name=scope.name)
+
+ deconv = tf.reshape(output, [-1] + output_shape[1:])
+
+ return deconv
+
+
+def add(inputs, name):
+ return tf.add_n(inputs, name=name)
+
+
+def UpSample(inputs, size, method, align_corners, name):
+ return tf.image.resize_images(inputs, size, method, align_corners)
+
+
+def flatten(inputs, name):
+ input_shape = inputs.get_shape()
+ dim = 1
+ for d in input_shape[1:].as_list():
+ dim *= d
+ inputs = tf.reshape(inputs, [-1, dim])
+
+ return inputs
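The prettytensor templates above are replaced by plain functions. A hedged usage sketch of the new helpers (assuming `misc/` is on `sys.path` as in the demos, and a TF 1.x install with `tf.contrib` available):

```python
import tensorflow as tf
from custom_ops import Conv2d, conv_batch_normalization, fc

x = tf.placeholder(tf.float32, [64, 64, 64, 3], name='images')

# 5x5 conv, 32 output channels, stride 2 -> (64, 32, 32, 32)
h = Conv2d(x, 5, 5, 32, 2, 2, name='d_conv1')
h = conv_batch_normalization(h, 'd_bn1', is_training=True,
                             activation_fn=tf.nn.relu)

# fc() flattens 4-D inputs before the dense layer -> (64, 1)
logits = fc(h, 1, name='d_logit')
```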
diff --git a/misc/datasets.py b/misc/datasets.py
index 624e6af..cd7bcee 100644
--- a/misc/datasets.py
+++ b/misc/datasets.py
@@ -1,16 +1,13 @@
from __future__ import division
from __future__ import print_function
-
import numpy as np
import pickle
import random
class Dataset(object):
- def __init__(self, images, imsize, embeddings=None,
- filenames=None, workdir=None,
- labels=None, aug_flag=True,
+ def __init__(self, images, imsize, embeddings=None, filenames=None, workdir=None, labels=None, aug_flag=True,
class_id=None, class_range=None):
self._images = images
self._embeddings = embeddings
@@ -59,8 +56,7 @@ def readCaptions(self, filenames, class_id):
if name.find('jpg/') != -1: # flowers dataset
class_name = 'class_%05d/' % class_id
name = name.replace('jpg/', class_name)
- cap_path = '%s/text_c10/%s.txt' %\
- (self.workdir, name)
+ cap_path = '%s/text_c10/%s.txt' % (self.workdir, name)
with open(cap_path, "r") as f:
captions = f.read().split('\n')
captions = [cap for cap in captions if len(cap) > 0]
@@ -68,14 +64,13 @@ def readCaptions(self, filenames, class_id):
def transform(self, images):
if self._aug_flag:
- transformed_images =\
- np.zeros([images.shape[0], self._imsize, self._imsize, 3])
+ transformed_images = np.zeros([images.shape[0], self._imsize, self._imsize, 3])
ori_size = images.shape[1]
for i in range(images.shape[0]):
h1 = np.floor((ori_size - self._imsize) * np.random.random())
w1 = np.floor((ori_size - self._imsize) * np.random.random())
- cropped_image =\
- images[i][w1: w1 + self._imsize, h1: h1 + self._imsize, :]
+ cropped_image = images[i][int(w1): int(w1 + self._imsize), int(h1): int(h1 + self._imsize), :]
+ # cropped_image = images[i][w1: w1 + self._imsize, h1: h1 + self._imsize, :]
if random.random() > 0.5:
transformed_images[i] = np.fliplr(cropped_image)
else:
@@ -93,12 +88,10 @@ def sample_embeddings(self, embeddings, filenames, class_id, sample_num):
sampled_embeddings = []
sampled_captions = []
for i in range(batch_size):
- randix = np.random.choice(embedding_num,
- sample_num, replace=False)
+ randix = np.random.choice(embedding_num, sample_num, replace=False)
if sample_num == 1:
randix = int(randix)
- captions = self.readCaptions(filenames[i],
- class_id[i])
+ captions = self.readCaptions(filenames[i], class_id[i])
sampled_captions.append(captions[randix])
sampled_embeddings.append(embeddings[i, randix, :])
else:
@@ -128,11 +121,8 @@ def next_batch(self, batch_size, window):
current_ids = self._perm[start:end]
fake_ids = np.random.randint(self._num_examples, size=batch_size)
- collision_flag =\
- (self._class_id[current_ids] == self._class_id[fake_ids])
- fake_ids[collision_flag] =\
- (fake_ids[collision_flag] +
- np.random.randint(100, 200)) % self._num_examples
+ collision_flag = (self._class_id[current_ids] == self._class_id[fake_ids])
+ fake_ids[collision_flag] = (fake_ids[collision_flag] + np.random.randint(100, 200)) % self._num_examples
sampled_images = self._images[current_ids]
sampled_wrong_images = self._images[fake_ids, :, :, :]
@@ -148,9 +138,8 @@ def next_batch(self, batch_size, window):
if self._embeddings is not None:
filenames = [self._filenames[i] for i in current_ids]
class_id = [self._class_id[i] for i in current_ids]
- sampled_embeddings, sampled_captions = \
- self.sample_embeddings(self._embeddings[current_ids],
- filenames, class_id, window)
+ sampled_embeddings, sampled_captions = self.sample_embeddings(self._embeddings[current_ids], filenames,
+ class_id, window)
ret_list.append(sampled_embeddings)
ret_list.append(sampled_captions)
else:
@@ -185,8 +174,7 @@ def next_batch_test(self, batch_size, start, max_captions):
sampled_filenames = self._filenames[start:end]
sampled_class_id = self._class_id[start:end]
for i in range(len(sampled_filenames)):
- captions = self.readCaptions(sampled_filenames[i],
- sampled_class_id[i])
+ captions = self.readCaptions(sampled_filenames[i], sampled_class_id[i])
# print(captions)
sampled_captions.append(captions)
@@ -194,8 +182,7 @@ def next_batch_test(self, batch_size, start, max_captions):
batch = sampled_embeddings[:, i, :]
sampled_embeddings_batchs.append(np.squeeze(batch))
- return [sampled_images, sampled_embeddings_batchs,
- self._saveIDs[start:end], sampled_captions]
+ return [sampled_images, sampled_embeddings_batchs, self._saveIDs[start:end], sampled_captions]
class TextDataset(object):
@@ -207,8 +194,7 @@ def __init__(self, workdir, embedding_type, hr_lr_ratio):
elif self.hr_lr_ratio == 4:
self.image_filename = '/304images.pickle'
- self.image_shape = [lr_imsize * self.hr_lr_ratio,
- lr_imsize * self.hr_lr_ratio, 3]
+ self.image_shape = [lr_imsize * self.hr_lr_ratio, lr_imsize * self.hr_lr_ratio, 3]
self.image_dim = self.image_shape[0] * self.image_shape[1] * 3
self.embedding_shape = None
self.train = None
@@ -226,7 +212,7 @@ def get_data(self, pickle_path, aug_flag=True):
print('images: ', images.shape)
with open(pickle_path + self.embedding_filename, 'rb') as f:
- embeddings = pickle.load(f)
+ embeddings = pickle.load(f, encoding="latin-1")
embeddings = np.array(embeddings)
self.embedding_shape = [embeddings.shape[-1]]
print('embeddings: ', embeddings.shape)
@@ -234,8 +220,6 @@ def get_data(self, pickle_path, aug_flag=True):
list_filenames = pickle.load(f)
print('list_filenames: ', len(list_filenames), list_filenames[0])
with open(pickle_path + '/class_info.pickle', 'rb') as f:
- class_id = pickle.load(f)
+ class_id = pickle.load(f, encoding="latin-1")
- return Dataset(images, self.image_shape[0], embeddings,
- list_filenames, self.workdir, None,
- aug_flag, class_id)
+ return Dataset(images, self.image_shape[0], embeddings, list_filenames, self.workdir, None, aug_flag, class_id)
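The `encoding="latin-1"` arguments are the standard way to read Python 2 pickles (raw byte strings plus NumPy arrays) under Python 3, since latin-1 round-trips every byte value. A minimal sketch, with the path a hypothetical example following the `Data/` layout described in the README:

```python
import pickle

# Hypothetical path; adjust to your Data/ layout.
with open('Data/birds/train/char-CNN-RNN-embeddings.pickle', 'rb') as f:
    embeddings = pickle.load(f, encoding='latin-1')  # pickle written by Python 2
```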
diff --git a/misc/preprocess_birds.py b/misc/preprocess_birds.py
index c93736a..717f703 100644
--- a/misc/preprocess_birds.py
+++ b/misc/preprocess_birds.py
@@ -7,10 +7,12 @@
import numpy as np
import os
import pickle
-from misc.utils import get_image
-import scipy.misc
+from utils import get_image
import pandas as pd
+# Python 3: scipy.misc.imresize was removed; use skimage.transform.resize instead
+from skimage.transform import resize
+
# from glob import glob
# TODO: 1. current label is temporary, need to change according to real label
@@ -43,7 +45,7 @@ def load_bbox(data_dir):
#
filename_bbox = {img_file[:-4]: [] for img_file in filenames}
numImgs = len(filenames)
- for i in xrange(0, numImgs):
+ for i in range(numImgs):
# bbox = [x-left, y-top, width, height]
bbox = df_bounding_boxes.iloc[i][1:].tolist()
@@ -64,7 +66,8 @@ def save_data_list(inpath, outpath, filenames, filename_bbox):
img = get_image(f_name, LOAD_SIZE, is_crop=True, bbox=bbox)
img = img.astype('uint8')
hr_images.append(img)
- lr_img = scipy.misc.imresize(img, [lr_size, lr_size], 'bicubic')
+ lr_img = resize(img, [lr_size, lr_size], order=3)
+ #lr_img = scipy.misc.imresize(img, [lr_size, lr_size], 'bicubic')
lr_images.append(lr_img)
cnt += 1
if cnt % 100 == 0:
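One caveat with swapping `scipy.misc.imresize` for `skimage.transform.resize`: `imresize` returned `uint8` in [0, 255], while `resize` returns floats rescaled to [0, 1] by default, so downstream code that assumes the old range may need `preserve_range=True`. A hedged equivalence sketch:

```python
import numpy as np
from skimage.transform import resize

img = np.random.randint(0, 256, (304, 304, 3), dtype=np.uint8)

# Closest analogue of scipy.misc.imresize(img, [76, 76], 'bicubic'):
lr = resize(img, (76, 76), order=3, preserve_range=True).astype(np.uint8)
assert lr.dtype == np.uint8 and lr.max() <= 255
```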
diff --git a/misc/preprocess_flowers.py b/misc/preprocess_flowers.py
index e6f9789..b273c4e 100644
--- a/misc/preprocess_flowers.py
+++ b/misc/preprocess_flowers.py
@@ -7,9 +7,12 @@
import numpy as np
import os
import pickle
-from misc.utils import get_image
+from utils import get_image
import scipy.misc
+# Python 3: scipy.misc.imresize was removed; use skimage.transform.resize instead
+from skimage.transform import resize
+
# from glob import glob
# TODO: 1. current label is temporary, need to change according to real label
@@ -39,7 +42,7 @@ def save_data_list(inpath, outpath, filenames):
img = get_image(f_name, LOAD_SIZE, is_crop=False)
img = img.astype('uint8')
hr_images.append(img)
- lr_img = scipy.misc.imresize(img, [lr_size, lr_size], 'bicubic')
+        lr_img = resize(img, [lr_size, lr_size], order=3)
lr_images.append(lr_img)
cnt += 1
if cnt % 100 == 0:
diff --git a/misc/skipthoughts.py b/misc/skipthoughts.py
index d8be946..6a00f45 100644
--- a/misc/skipthoughts.py
+++ b/misc/skipthoughts.py
@@ -2,14 +2,14 @@
Skip-thought vectors
https://github.com/ryankiros/skip-thoughts
'''
-import os
+# import os
+import warnings
import theano
import theano.tensor as tensor
-import cPickle as pkl
+import pickle as pkl
import numpy
-import copy
import nltk
from collections import OrderedDict, defaultdict
@@ -34,10 +34,10 @@ def load_model():
Load the model with saved tables
"""
# Load model options
- print 'Loading model parameters...'
- with open('%s.pkl'%path_to_umodel, 'rb') as f:
+ print('Loading model parameters...')
+ with open('%s.pkl' % path_to_umodel, 'rb') as f:
uoptions = pkl.load(f)
- with open('%s.pkl'%path_to_bmodel, 'rb') as f:
+ with open('%s.pkl' % path_to_bmodel, 'rb') as f:
boptions = pkl.load(f)
# Load parameters
@@ -49,19 +49,19 @@ def load_model():
btparams = init_tparams(bparams)
# Extractor functions
- print 'Compiling encoders...'
+ print('Compiling encoders...')
embedding, x_mask, ctxw2v = build_encoder(utparams, uoptions)
f_w2v = theano.function([embedding, x_mask], ctxw2v, name='f_w2v')
embedding, x_mask, ctxw2v = build_encoder_bi(btparams, boptions)
f_w2v2 = theano.function([embedding, x_mask], ctxw2v, name='f_w2v2')
# Tables
- print 'Loading tables...'
+ print('Loading tables...')
utable, btable = load_tables()
# Store everything we need in a dictionary
- print 'Packing up...'
- model = {}
+ print('Packing up...')
+ model = dict()
model['uoptions'] = uoptions
model['boptions'] = boptions
model['utable'] = utable
@@ -96,7 +96,7 @@ def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False)
X = preprocess(X)
# word dictionary and init
- d = defaultdict(lambda : 0)
+ d = defaultdict(lambda: 0)
for w in model['utable'].keys():
d[w] = 1
ufeatures = numpy.zeros((len(X), model['uoptions']['dim']), dtype='float32')
@@ -105,14 +105,14 @@ def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False)
# length dictionary
ds = defaultdict(list)
captions = [s.split() for s in X]
- for i,s in enumerate(captions):
+ for i, s in enumerate(captions):
ds[len(s)].append(i)
# Get features. This encodes by length, in order to avoid wasting computation
for k in ds.keys():
if verbose:
- print k
- numbatches = len(ds[k]) / batch_size + 1
+ print(k)
+ numbatches = len(ds[k]) // batch_size + 1
for minibatch in range(numbatches):
caps = ds[k][minibatch::numbatches]
@@ -126,20 +126,20 @@ def encode(model, X, use_norm=True, verbose=True, batch_size=128, use_eos=False)
caption = captions[c]
for j in range(len(caption)):
if d[caption[j]] > 0:
- uembedding[j,ind] = model['utable'][caption[j]]
- bembedding[j,ind] = model['btable'][caption[j]]
+ uembedding[j, ind] = model['utable'][caption[j]]
+ bembedding[j, ind] = model['btable'][caption[j]]
else:
- uembedding[j,ind] = model['utable']['UNK']
- bembedding[j,ind] = model['btable']['UNK']
+ uembedding[j, ind] = model['utable']['UNK']
+ bembedding[j, ind] = model['btable']['UNK']
if use_eos:
-            uembedding[-1,ind] = model['utable']['<eos>']
-            bembedding[-1,ind] = model['btable']['<eos>']
+            uembedding[-1, ind] = model['utable']['<eos>']
+            bembedding[-1, ind] = model['btable']['<eos>']
if use_eos:
- uff = model['f_w2v'](uembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32'))
- bff = model['f_w2v2'](bembedding, numpy.ones((len(caption)+1,len(caps)), dtype='float32'))
+ uff = model['f_w2v'](uembedding, numpy.ones((len(caption)+1, len(caps)), dtype='float32'))
+ bff = model['f_w2v2'](bembedding, numpy.ones((len(caption)+1, len(caps)), dtype='float32'))
else:
- uff = model['f_w2v'](uembedding, numpy.ones((len(caption),len(caps)), dtype='float32'))
- bff = model['f_w2v2'](bembedding, numpy.ones((len(caption),len(caps)), dtype='float32'))
+ uff = model['f_w2v'](uembedding, numpy.ones((len(caption), len(caps)), dtype='float32'))
+ bff = model['f_w2v2'](bembedding, numpy.ones((len(caption), len(caps)), dtype='float32'))
if use_norm:
for j in range(len(uff)):
uff[j] /= norm(uff[j])
@@ -180,10 +180,10 @@ def nn(model, text, vectors, query, k=5):
scores = numpy.dot(qf, vectors.T).flatten()
sorted_args = numpy.argsort(scores)[::-1]
sentences = [text[a] for a in sorted_args[:k]]
- print 'QUERY: ' + query
- print 'NEAREST: '
+ print('QUERY: ' + query)
+ print('NEAREST: ')
for i, s in enumerate(sentences):
- print s, sorted_args[i]
+ print(s, sorted_args[i])
def word_features(table):
@@ -207,17 +207,17 @@ def nn_words(table, wordvecs, query, k=10):
scores = numpy.dot(qf, wordvecs.T).flatten()
sorted_args = numpy.argsort(scores)[::-1]
words = [keys[a] for a in sorted_args[:k]]
- print 'QUERY: ' + query
- print 'NEAREST: '
+ print('QUERY: ' + query)
+ print('NEAREST: ')
for i, w in enumerate(words):
- print w
+ print(w)
def _p(pp, name):
"""
make prefix-appended name
"""
- return '%s_%s'%(pp, name)
+ return '%s_%s' % (pp, name)
def init_tparams(params):
@@ -225,7 +225,7 @@ def init_tparams(params):
initialize Theano shared variables according to the initial parameters
"""
tparams = OrderedDict()
- for kk, pp in params.iteritems():
+ for kk, pp in params.items():
tparams[kk] = theano.shared(params[kk], name=kk)
return tparams
@@ -235,9 +235,9 @@ def load_params(path, params):
load parameters
"""
pp = numpy.load(path)
- for kk, vv in params.iteritems():
+ for kk, vv in params.items():
if kk not in pp:
- warnings.warn('%s is not in the archive'%kk)
+ warnings.warn('%s is not in the archive' % kk)
continue
params[kk] = pp[kk]
return params
@@ -246,6 +246,7 @@ def load_params(path, params):
# layers: 'name': ('parameter initializer', 'feedforward')
layers = {'gru': ('param_init_gru', 'gru_layer')}
+
def get_layer(name):
fns = layers[name]
return (eval(fns[0]), eval(fns[1]))
@@ -261,8 +262,8 @@ def init_params(options):
params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word'])
# encoder: GRU
- params = get_layer(options['encoder'])[0](options, params, prefix='encoder',
- nin=options['dim_word'], dim=options['dim'])
+ params = get_layer(options['encoder'])[0](options, params, prefix='encoder', nin=options['dim_word'],
+ dim=options['dim'])
return params
@@ -276,10 +277,10 @@ def init_params_bi(options):
params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word'])
# encoder: GRU
- params = get_layer(options['encoder'])[0](options, params, prefix='encoder',
- nin=options['dim_word'], dim=options['dim'])
- params = get_layer(options['encoder'])[0](options, params, prefix='encoder_r',
- nin=options['dim_word'], dim=options['dim'])
+ params = get_layer(options['encoder'])[0](options, params, prefix='encoder', nin=options['dim_word'],
+ dim=options['dim'])
+ params = get_layer(options['encoder'])[0](options, params, prefix='encoder_r', nin=options['dim_word'],
+ dim=options['dim'])
return params
@@ -292,9 +293,7 @@ def build_encoder(tparams, options):
x_mask = tensor.matrix('x_mask', dtype='float32')
# encoder
- proj = get_layer(options['encoder'])[1](tparams, embedding, options,
- prefix='encoder',
- mask=x_mask)
+ proj = get_layer(options['encoder'])[1](tparams, embedding, options, prefix='encoder', mask=x_mask)
ctx = proj[0][-1]
return embedding, x_mask, ctx
@@ -311,12 +310,8 @@ def build_encoder_bi(tparams, options):
xr_mask = x_mask[::-1]
# encoder
- proj = get_layer(options['encoder'])[1](tparams, embedding, options,
- prefix='encoder',
- mask=x_mask)
- projr = get_layer(options['encoder'])[1](tparams, embeddingr, options,
- prefix='encoder_r',
- mask=xr_mask)
+ proj = get_layer(options['encoder'])[1](tparams, embedding, options, prefix='encoder', mask=x_mask)
+ projr = get_layer(options['encoder'])[1](tparams, embeddingr, options, prefix='encoder_r', mask=xr_mask)
ctx = tensor.concatenate([proj[0][-1], projr[0][-1]], axis=1)
@@ -330,10 +325,10 @@ def ortho_weight(ndim):
return u.astype('float32')
-def norm_weight(nin,nout=None, scale=0.1, ortho=True):
- if nout == None:
+def norm_weight(nin, nout=None, scale=0.1, ortho=True):
+ if nout is None:
nout = nin
- if nout == nin and ortho:
+    if nout == nin and ortho:  # '==', not 'is': compare values, not identity
W = ortho_weight(nin)
else:
W = numpy.random.uniform(low=-scale, high=scale, size=(nin, nout))
@@ -344,23 +339,21 @@ def param_init_gru(options, params, prefix='gru', nin=None, dim=None):
"""
parameter init for GRU
"""
- if nin == None:
+ if nin is None:
nin = options['dim_proj']
- if dim == None:
+ if dim is None:
dim = options['dim_proj']
- W = numpy.concatenate([norm_weight(nin,dim),
- norm_weight(nin,dim)], axis=1)
- params[_p(prefix,'W')] = W
- params[_p(prefix,'b')] = numpy.zeros((2 * dim,)).astype('float32')
- U = numpy.concatenate([ortho_weight(dim),
- ortho_weight(dim)], axis=1)
- params[_p(prefix,'U')] = U
+ W = numpy.concatenate([norm_weight(nin, dim), norm_weight(nin, dim)], axis=1)
+ params[_p(prefix, 'W')] = W
+ params[_p(prefix, 'b')] = numpy.zeros((2 * dim,)).astype('float32')
+ U = numpy.concatenate([ortho_weight(dim), ortho_weight(dim)], axis=1)
+ params[_p(prefix, 'U')] = U
Wx = norm_weight(nin, dim)
- params[_p(prefix,'Wx')] = Wx
+ params[_p(prefix, 'Wx')] = Wx
Ux = ortho_weight(dim)
- params[_p(prefix,'Ux')] = Ux
- params[_p(prefix,'bx')] = numpy.zeros((dim,)).astype('float32')
+ params[_p(prefix, 'Ux')] = Ux
+ params[_p(prefix, 'bx')] = numpy.zeros((dim,)).astype('float32')
return params
@@ -375,9 +368,9 @@ def gru_layer(tparams, state_below, options, prefix='gru', mask=None, **kwargs):
else:
n_samples = 1
- dim = tparams[_p(prefix,'Ux')].shape[1]
+ dim = tparams[_p(prefix, 'Ux')].shape[1]
- if mask == None:
+ if mask is None:
mask = tensor.alloc(1., state_below.shape[0], 1)
def _slice(_x, n, dim):
@@ -404,21 +397,18 @@ def _step_slice(m_, x_, xx_, h_, U, Ux):
h = tensor.tanh(preactx)
h = u * h_ + (1. - u) * h
- h = m_[:,None] * h + (1. - m_)[:,None] * h_
+ h = m_[:, None] * h + (1. - m_)[:, None] * h_
return h
seqs = [mask, state_below_, state_belowx]
_step = _step_slice
- rval, updates = theano.scan(_step,
- sequences=seqs,
- outputs_info = [tensor.alloc(0., n_samples, dim)],
- non_sequences = [tparams[_p(prefix, 'U')],
- tparams[_p(prefix, 'Ux')]],
+ rval, updates = theano.scan(_step, sequences=seqs, outputs_info=[tensor.alloc(0., n_samples, dim)],
+ non_sequences=[tparams[_p(prefix, 'U')], tparams[_p(prefix, 'Ux')]],
name=_p(prefix, '_layers'),
n_steps=nsteps,
profile=profile,
strict=True)
- rval = [rval]
+    rval = [rval]
return rval
diff --git a/misc/tf_upgrade.py b/misc/tf_upgrade.py
new file mode 100644
index 0000000..e0a8dcd
--- /dev/null
+++ b/misc/tf_upgrade.py
@@ -0,0 +1,255 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+
+from tensorflow.tools.compatibility import ast_edits
+
+
+class TFAPIChangeSpec(ast_edits.APIChangeSpec):
+ """List of maps that describe what changed in the API."""
+
+ def __init__(self):
+ # Maps from a function name to a dictionary that describes how to
+ # map from an old argument keyword to the new argument keyword.
+ self.function_keyword_renames = {
+ "tf.batch_matmul": {
+ "adj_x": "adjoint_a",
+ "adj_y": "adjoint_b",
+ },
+ "tf.count_nonzero": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_all": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_any": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_max": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_mean": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_min": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_prod": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_sum": {
+ "reduction_indices": "axis"
+ },
+ "tf.reduce_logsumexp": {
+ "reduction_indices": "axis"
+ },
+ "tf.expand_dims": {
+ "dim": "axis"
+ },
+ "tf.argmax": {
+ "dimension": "axis"
+ },
+ "tf.argmin": {
+ "dimension": "axis"
+ },
+ "tf.reduce_join": {
+ "reduction_indices": "axis"
+ },
+ "tf.sparse_concat": {
+ "concat_dim": "axis"
+ },
+ "tf.sparse_split": {
+ "split_dim": "axis"
+ },
+ "tf.sparse_reduce_sum": {
+ "reduction_axes": "axis"
+ },
+ "tf.reverse_sequence": {
+ "seq_dim": "seq_axis",
+ "batch_dim": "batch_axis"
+ },
+ "tf.sparse_reduce_sum_sparse": {
+ "reduction_axes": "axis"
+ },
+ "tf.squeeze": {
+ "squeeze_dims": "axis"
+ },
+ "tf.split": {
+ "split_dim": "axis",
+ "num_split": "num_or_size_splits"
+ },
+ "tf.concat": {
+ "concat_dim": "axis"
+ },
+ }
+
+ # Mapping from function to the new name of the function
+ self.symbol_renames = {
+ "tf.inv": "tf.reciprocal",
+ "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
+ "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
+ "tf.listdiff": "tf.setdiff1d",
+ "tf.list_diff": "tf.setdiff1d",
+ "tf.mul": "tf.multiply",
+ "tf.neg": "tf.negative",
+ "tf.sub": "tf.subtract",
+ "tf.train.SummaryWriter": "tf.summary.FileWriter",
+ "tf.scalar_summary": "tf.summary.scalar",
+ "tf.histogram_summary": "tf.summary.histogram",
+ "tf.audio_summary": "tf.summary.audio",
+ "tf.image_summary": "tf.summary.image",
+ "tf.merge_summary": "tf.summary.merge",
+ "tf.merge_all_summaries": "tf.summary.merge_all",
+ "tf.image.per_image_whitening": "tf.image.per_image_standardization",
+ "tf.all_variables": "tf.global_variables",
+ "tf.VARIABLES": "tf.GLOBAL_VARIABLES",
+ "tf.initialize_all_variables": "tf.global_variables_initializer",
+ "tf.initialize_variables": "tf.variables_initializer",
+ "tf.initialize_local_variables": "tf.local_variables_initializer",
+ "tf.batch_matrix_diag": "tf.matrix_diag",
+ "tf.batch_band_part": "tf.band_part",
+ "tf.batch_set_diag": "tf.set_diag",
+ "tf.batch_matrix_transpose": "tf.matrix_transpose",
+ "tf.batch_matrix_determinant": "tf.matrix_determinant",
+ "tf.batch_matrix_inverse": "tf.matrix_inverse",
+ "tf.batch_cholesky": "tf.cholesky",
+ "tf.batch_cholesky_solve": "tf.cholesky_solve",
+ "tf.batch_matrix_solve": "tf.matrix_solve",
+ "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve",
+ "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls",
+ "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig",
+ "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals",
+ "tf.batch_svd": "tf.svd",
+ "tf.batch_fft": "tf.fft",
+ "tf.batch_ifft": "tf.ifft",
+ "tf.batch_fft2d": "tf.fft2d",
+ "tf.batch_ifft2d": "tf.ifft2d",
+ "tf.batch_fft3d": "tf.fft3d",
+ "tf.batch_ifft3d": "tf.ifft3d",
+ "tf.select": "tf.where",
+ "tf.complex_abs": "tf.abs",
+ "tf.batch_matmul": "tf.matmul",
+ "tf.pack": "tf.stack",
+ "tf.unpack": "tf.unstack",
+ "tf.op_scope": "tf.name_scope",
+ }
+
+ self.change_to_function = {
+ "tf.ones_initializer",
+ "tf.zeros_initializer",
+ }
+
+ # Functions that were reordered should be changed to the new keyword args
+ # for safety, if positional arguments are used. If you have reversed the
+ # positional arguments yourself, this could do the wrong thing.
+ self.function_reorders = {
+ "tf.split": ["axis", "num_or_size_splits", "value", "name"],
+ "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"],
+ "tf.concat": ["concat_dim", "values", "name"],
+ "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
+ "tf.nn.softmax_cross_entropy_with_logits": [
+ "logits", "labels", "dim", "name"
+ ],
+ "tf.nn.sparse_softmax_cross_entropy_with_logits": [
+ "logits", "labels", "name"
+ ],
+ "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"],
+ "tf.op_scope": ["values", "name", "default_name"],
+ }
+
+ # Warnings that should be printed if corresponding functions are used.
+ self.function_warnings = {
+ "tf.reverse": (
+ ast_edits.ERROR,
+ "tf.reverse has had its argument semantics changed "
+ "significantly. The converter cannot detect this reliably, so "
+ "you need to inspect this usage manually.\n"),
+ }
+
+ self.module_deprecations = {}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ description="""Convert a TensorFlow Python file to 1.0
+
+Simple usage:
+ tf_convert.py --infile foo.py --outfile bar.py
+ tf_convert.py --intree ~/code/old --outtree ~/code/new
+""")
+
+ parser.add_argument(
+ "--infile",
+ dest="input_file",
+ help="If converting a single file, the name of the file "
+ "to convert")
+ parser.add_argument(
+ "--outfile",
+ dest="output_file",
+ help="If converting a single file, the output filename.")
+ parser.add_argument(
+ "--intree",
+ dest="input_tree",
+ help="If converting a whole tree of files, the directory "
+ "to read from (relative or absolute).")
+ parser.add_argument(
+ "--outtree",
+ dest="output_tree",
+ help="If converting a whole tree of files, the output "
+ "directory (relative or absolute).")
+ parser.add_argument(
+ "--copyotherfiles",
+ dest="copy_other_files",
+ help=("If converting a whole tree of files, whether to "
+ "copy the other files."),
+ type=bool,
+ default=False)
+ parser.add_argument(
+ "--reportfile",
+ dest="report_filename",
+ help=("The name of the file where the report log is "
+ "stored."
+ "(default: %(default)s)"),
+ default="report.txt")
+ args = parser.parse_args()
+
+ upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
+ report_text = None
+ report_filename = args.report_filename
+ files_processed = 0
+ if args.input_file:
+ files_processed, report_text, errors = upgrade.process_file(args.input_file, args.output_file)
+ files_processed = 1
+ elif args.input_tree:
+ files_processed, report_text, errors = upgrade.process_tree(args.input_tree, args.output_tree,
+ args.copy_other_files)
+ else:
+ parser.print_help()
+ if report_text:
+ open(report_filename, "w").write(report_text)
+ print("TensorFlow 1.0 Upgrade Script")
+ print("-----------------------------")
+ print("Converted %d files\n" % files_processed)
+ print("Detected %d errors that require attention" % len(errors))
+ print("-" * 80)
+ print("\n".join(errors))
+ print("\nMake sure to read the detailed log %r\n" % report_filename)
diff --git a/misc/utils.py b/misc/utils.py
index 961954d..540dc72 100644
--- a/misc/utils.py
+++ b/misc/utils.py
@@ -6,9 +6,11 @@
from __future__ import print_function
import numpy as np
-import scipy.misc
import os
import errno
+import imageio
+
+from skimage.transform import resize
def get_image(image_path, image_size, is_crop=False, bbox=None):
@@ -43,14 +45,12 @@ def transform(image, image_size, is_crop, bbox):
image = colorize(image)
if is_crop:
image = custom_crop(image, bbox)
- #
- transformed_image =\
- scipy.misc.imresize(image, [image_size, image_size], 'bicubic')
- return np.array(transformed_image)
+ transformed_image = resize(image, [image_size, image_size], order=3)
+ return transformed_image
def imread(path):
- img = scipy.misc.imread(path)
+ img = imageio.imread(path)
if len(img.shape) == 0:
raise ValueError(path + " got loaded as a dimensionless array!")
return img.astype(np.float)
@@ -65,6 +65,16 @@ def colorize(img):
return img
+def convert_to_uint8(img):
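+    # Map images from the generator's tanh output range [-1, 1] to uint8 [0, 255].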
+ img = (img + 1.) * (255 / 2.)
+ img = img.astype(np.uint8)
+ return img
+
+
+def caption_convert(caption):
+ return caption.decode("utf-8")
+
+
def mkdir_p(path):
try:
os.makedirs(path)
diff --git a/models/README.md b/models/README.md
index a38141b..71ad6cb 100644
--- a/models/README.md
+++ b/models/README.md
@@ -1,9 +1,9 @@
**Pretrained StackGAN Models**
-- [StackGAN for birds]() trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
-- [StackGAN for flowers](https://drive.google.com/open?id=0B3y_msrWZaXLX01FMC1JQW9vaFk) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
-- [StackGAN for birds](https://drive.google.com/open?id=0B3y_msrWZaXLZVNRNFg4d055Q1E) trained from skip-thought text embeddings. Download and save it to `models/` (Just use the same setting as the char-CNN-RNN, we assume better results can be achieved by playing with the hyper-parameters).
+- [StackGAN for birds](https://drive.google.com/open?id=1O1JHIoYO3h_qB5o27Td8KklvuLgTgpdV) trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
+- [StackGAN for flowers]() trained from char-CNN-RNN text embeddings. Download and save it to `models/`.
+- [StackGAN for birds](https://drive.google.com/open?id=1O1JHIoYO3h_qB5o27Td8KklvuLgTgpdV) trained from skip-thought text embeddings. Download and save it to `models/` (we used the same settings as for the char-CNN-RNN embeddings; better results can likely be achieved by tuning the hyper-parameters).
**char-CNN-RNN text encoders**
- [Download](https://drive.google.com/file/d/0B0ywwgffWnLLZUt0UmQ1LU1oWlU/view) the char-CNN-RNN text encoder for flowers to `models/text_encoder/`.
-- [Download](https://drive.google.com/file/d/0B0ywwgffWnLLU0F3UHA3NzFTNEE/view) the char-CNN-RNN text encoder for birds to `models/text_encoder/`.
+- [Download](https://drive.google.com/open?id=1a11TUAQKrHxRWnzWBTLpK9FkZdZqhKlT) the char-CNN-RNN text encoder for birds to `models/text_encoder/`.
diff --git a/stageI/cfg/birds.yml b/stageI/cfg/birds.yml
index af02386..7537341 100644
--- a/stageI/cfg/birds.yml
+++ b/stageI/cfg/birds.yml
@@ -6,8 +6,8 @@ GPU_ID: 0
Z_DIM: 100
TRAIN:
- FLAG: True
- PRETRAINED_MODEL: ''
+  FLAG: False # True to train; False to evaluate with PRETRAINED_MODEL
+  PRETRAINED_MODEL: './ckt_logs/birds/stageI_2019_07_10_09_33_08/model_82000.ckpt' # '' when training from scratch
BATCH_SIZE: 64
NUM_COPY: 4
MAX_EPOCH: 600
diff --git a/stageI/model.py b/stageI/model.py
index 15d8d9c..20e2652 100644
--- a/stageI/model.py
+++ b/stageI/model.py
@@ -1,11 +1,13 @@
from __future__ import division
from __future__ import print_function
-import prettytensor as pt
import tensorflow as tf
-import misc.custom_ops
-from misc.custom_ops import leaky_rectify
-from misc.config import cfg
+import sys
+
+sys.path.append('misc')
+
+from custom_ops import fc, conv_batch_normalization, fc_batch_normalization, reshape, Conv2d, Deconv2d, UpSample, add
+from config import cfg
class CondGAN(object):
@@ -17,208 +19,171 @@ def __init__(self, image_shape):
self.df_dim = cfg.GAN.DF_DIM
self.ef_dim = cfg.GAN.EMBEDDING_DIM
         self.image_shape = image_shape
self.s = image_shape[0]
- self.s2, self.s4, self.s8, self.s16 =\
- int(self.s / 2), int(self.s / 4), int(self.s / 8), int(self.s / 16)
-
- # Since D is only used during training, we build a template
- # for safe reuse the variables during computing loss for fake/real/wrong images
- # We do not do this for G,
- # because batch_norm needs different options for training and testing
- if cfg.GAN.NETWORK_TYPE == "default":
- with tf.variable_scope("d_net"):
- self.d_encode_img_template = self.d_encode_image()
- self.d_context_template = self.context_embedding()
- self.discriminator_template = self.discriminator()
- elif cfg.GAN.NETWORK_TYPE == "simple":
- with tf.variable_scope("d_net"):
- self.d_encode_img_template = self.d_encode_image_simple()
- self.d_context_template = self.context_embedding()
- self.discriminator_template = self.discriminator()
- else:
- raise NotImplementedError
+ self.s2, self.s4, self.s8, self.s16 = int(self.s / 2), int(self.s / 4), int(self.s / 8), int(self.s / 16)
# g-net
def generate_condition(self, c_var):
- conditions =\
- (pt.wrap(c_var).
- flatten().
- custom_fully_connected(self.ef_dim * 2).
- apply(leaky_rectify, leakiness=0.2))
+ conditions = fc(c_var, self.ef_dim * 2, 'gen_cond/fc', activation_fn=tf.nn.leaky_relu)
mean = conditions[:, :self.ef_dim]
log_sigma = conditions[:, self.ef_dim:]
return [mean, log_sigma]
- def generator(self, z_var):
- node1_0 =\
- (pt.wrap(z_var).
- flatten().
- custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
- fc_batch_norm().
- reshape([-1, self.s16, self.s16, self.gf_dim * 8]))
- node1_1 = \
- (node1_0.
- custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
- node1 = \
- (node1_0.
- apply(tf.add, node1_1).
- apply(tf.nn.relu))
-
- node2_0 = \
- (node1.
- # custom_deconv2d([0, self.s8, self.s8, self.gf_dim * 4], k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s8, self.s8]).
- custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
- node2_1 = \
- (node2_0.
- custom_conv2d(self.gf_dim * 1, k_h=1, k_w=1, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 1, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
- node2 = \
- (node2_0.
- apply(tf.add, node2_1).
- apply(tf.nn.relu))
-
- output_tensor = \
- (node2.
- # custom_deconv2d([0, self.s4, self.s4, self.gf_dim * 2], k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s4, self.s4]).
- custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- # custom_deconv2d([0, self.s2, self.s2, self.gf_dim], k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
- custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- # custom_deconv2d([0] + list(self.image_shape), k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
- custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
- apply(tf.nn.tanh))
+ def generator(self, z_var, training=True):
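+        # FC + reshape to an s16 x s16 x (gf_dim*8) tensor, refined by
+        # residual blocks, then nearest-neighbour upsampling + conv stages
+        # (s16 -> s8 -> s4 -> s2 -> s) ending in a 3-channel tanh image.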
+ node1_0 = fc(z_var, self.s16 * self.s16 * self.gf_dim * 8, 'g_n1.0/fc')
+ node1_0 = fc_batch_normalization(node1_0, 'g_n1.0/batch_norm')
+ node1_0 = reshape(node1_0, [-1, self.s16, self.s16, self.gf_dim * 8], name='g_n1.0/reshape')
+
+ node1_1 = Conv2d(node1_0, 1, 1, self.gf_dim * 2, 1, 1, name='g_n1.1/conv2d')
+ node1_1 = conv_batch_normalization(node1_1, 'g_n1.1/batch_norm_1', activation_fn=tf.nn.relu,
+ is_training=training)
+ node1_1 = Conv2d(node1_1, 3, 3, self.gf_dim * 2, 1, 1, name='g_n1.1/conv2d2')
+ node1_1 = conv_batch_normalization(node1_1, 'g_n1.1/batch_norm_2', activation_fn=tf.nn.relu,
+ is_training=training)
+ node1_1 = Conv2d(node1_1, 3, 3, self.gf_dim * 8, 1, 1, name='g_n1.1/conv2d3')
+        # No activation before the residual add: ReLU is applied after the add,
+        # matching the original chain (cf. node2_1 below).
+        node1_1 = conv_batch_normalization(node1_1, 'g_n1.1/batch_norm_3', is_training=training)
+
+ node1 = add([node1_0, node1_1], name='g_n1_res/add')
+ node1_output = tf.nn.relu(node1)
+
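+        # method=1 is tf.image.ResizeMethod.NEAREST_NEIGHBOR, matching the
+        # tf.image.resize_nearest_neighbor calls this replaces.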
+ node2_0 = UpSample(node1_output, size=[self.s8, self.s8], method=1, align_corners=False, name='g_n2.0/upsample')
+ node2_0 = Conv2d(node2_0, 3, 3, self.gf_dim * 4, 1, 1, name='g_n2.0/conv2d')
+ node2_0 = conv_batch_normalization(node2_0, 'g_n2.0/batch_norm', is_training=training)
+
+ node2_1 = Conv2d(node2_0, 1, 1, self.gf_dim * 1, 1, 1, name='g_n2.1/conv2d')
+ node2_1 = conv_batch_normalization(node2_1, 'g_n2.1/batch_norm', activation_fn=tf.nn.relu, is_training=training)
+ node2_1 = Conv2d(node2_1, 3, 3, self.gf_dim * 1, 1, 1, name='g_n2.1/conv2d2')
+ node2_1 = conv_batch_normalization(node2_1, 'g_n2.1/batch_norm2', activation_fn=tf.nn.relu,
+ is_training=training)
+ node2_1 = Conv2d(node2_1, 3, 3, self.gf_dim * 4, 1, 1, name='g_n2.1/conv2d3')
+ node2_1 = conv_batch_normalization(node2_1, 'g_n2.1/batch_norm3', is_training=training)
+
+ node2 = add([node2_0, node2_1], name='g_n2_res/add')
+ node2_output = tf.nn.relu(node2)
+
+ output_tensor = UpSample(node2_output, size=[self.s4, self.s4], method=1, align_corners=False,
+ name='g_OT/upsample')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim * 2, 1, 1, name='g_OT/conv2d')
+ output_tensor = conv_batch_normalization(output_tensor, 'g_OT/batch_norm', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = UpSample(output_tensor, size=[self.s2, self.s2], method=1, align_corners=False,
+ name='g_OT/upsample2')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim, 1, 1, name='g_OT/conv2d2')
+ output_tensor = conv_batch_normalization(output_tensor, 'g_OT/batch_norm2', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = UpSample(output_tensor, size=[self.s, self.s], method=1, align_corners=False,
+ name='g_OT/upsample3')
+ output_tensor = Conv2d(output_tensor, 3, 3, 3, 1, 1, activation_fn=tf.nn.tanh, name='g_OT/conv2d3')
return output_tensor
- def generator_simple(self, z_var):
- output_tensor =\
- (pt.wrap(z_var).
- flatten().
- custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
- reshape([-1, self.s16, self.s16, self.gf_dim * 8]).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_deconv2d([0, self.s8, self.s8, self.gf_dim * 4], k_h=4, k_w=4).
- # apply(tf.image.resize_nearest_neighbor, [self.s8, self.s8]).
- # custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_deconv2d([0, self.s4, self.s4, self.gf_dim * 2], k_h=4, k_w=4).
- # apply(tf.image.resize_nearest_neighbor, [self.s4, self.s4]).
- # custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_deconv2d([0, self.s2, self.s2, self.gf_dim], k_h=4, k_w=4).
- # apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
- # custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_deconv2d([0] + list(self.image_shape), k_h=4, k_w=4).
- # apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
- # custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
- apply(tf.nn.tanh))
+ def generator_simple(self, z_var, training=True):
+ output_tensor = fc(z_var, self.s16 * self.s16 * self.gf_dim * 8, 'g_simple_OT/fc')
+ output_tensor = reshape(output_tensor, [-1, self.s16, self.s16, self.gf_dim * 8], name='g_simple_OT/reshape')
+ output_tensor = conv_batch_normalization(output_tensor, 'g_simple_OT/batch_norm', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = Deconv2d(output_tensor, [0, self.s8, self.s8, self.gf_dim * 4], name='g_simple_OT/deconv2d',
+ k_h=4, k_w=4)
+ output_tensor = conv_batch_normalization(output_tensor, 'g_simple_OT/batch_norm2', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = Deconv2d(output_tensor, [0, self.s4, self.s4, self.gf_dim * 2], name='g_simple_OT/deconv2d2',
+ k_h=4, k_w=4)
+ output_tensor = conv_batch_normalization(output_tensor, 'g_simple_OT/batch_norm3', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = Deconv2d(output_tensor, [0, self.s2, self.s2, self.gf_dim], name='g_simple_OT/deconv2d3',
+ k_h=4, k_w=4)
+ output_tensor = conv_batch_normalization(output_tensor, 'g_simple_OT/batch_norm4', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = Deconv2d(output_tensor, [0] + list(self.image_shape), name='g_simple_OT/deconv2d4',
+ k_h=4, k_w=4, activation_fn=tf.nn.tanh)
+
return output_tensor
- def get_generator(self, z_var):
+
+ def get_generator(self, z_var, is_training):
if cfg.GAN.NETWORK_TYPE == "default":
- return self.generator(z_var)
+ return self.generator(z_var, training=is_training)
elif cfg.GAN.NETWORK_TYPE == "simple":
- return self.generator_simple(z_var)
+ return self.generator_simple(z_var, training=is_training)
else:
raise NotImplementedError
# d-net
- def context_embedding(self):
- template = (pt.template("input").
- custom_fully_connected(self.ef_dim).
- apply(leaky_rectify, leakiness=0.2))
+ def context_embedding(self, inputs=None, if_reuse=None):
+ template = fc(inputs, self.ef_dim, 'd_embedd/fc', activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
return template
- def d_encode_image(self):
- node1_0 = \
- (pt.template("input").
- custom_conv2d(self.df_dim, k_h=4, k_w=4).
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
- conv_batch_norm().
- custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
- conv_batch_norm())
- node1_1 = \
- (node1_0.
- custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
-
- node1 = \
- (node1_0.
- apply(tf.add, node1_1).
- apply(leaky_rectify, leakiness=0.2))
+ def d_encode_image(self, training=True, inputs=None, if_reuse=None):
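+        # Four stride-2 4x4 convolutions downsample the s x s x 3 input to
+        # s16 x s16 x (df_dim*8); a 1x1/3x3/3x3 residual branch then refines it.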
+ node1_0 = Conv2d(inputs, 4, 4, self.df_dim, 2, 2, name='d_n1.0/conv2d', activation_fn=tf.nn.leaky_relu,
+ reuse=if_reuse)
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 2, 2, 2, name='d_n1.0/conv2d2', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'd_n1.0/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 4, 2, 2, name='d_n1.0/conv2d3', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'd_n1.0/batch_norm2', is_training=training, reuse=if_reuse)
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 8, 2, 2, name='d_n1.0/conv2d4', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'd_n1.0/batch_norm3', is_training=training, reuse=if_reuse)
+
+ node1_1 = Conv2d(node1_0, 1, 1, self.df_dim * 2, 1, 1, name='d_n1.1/conv2d', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'd_n1.1/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ node1_1 = Conv2d(node1_1, 3, 3, self.df_dim * 2, 1, 1, name='d_n1.1/conv2d2', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'd_n1.1/batch_norm2', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ node1_1 = Conv2d(node1_1, 3, 3, self.df_dim * 8, 1, 1, name='d_n1.1/conv2d3', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'd_n1.1/batch_norm3', is_training=training, reuse=if_reuse)
+
+ node1 = add([node1_0, node1_1], name='d_n1_res/add')
+ node1 = tf.nn.leaky_relu(node1)
return node1
- def d_encode_image_simple(self):
- template = \
- (pt.template("input").
- custom_conv2d(self.df_dim, k_h=4, k_w=4).
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2))
+    def d_encode_image_simple(self, training=True, inputs=None, if_reuse=None):
+        # Use a distinct scope prefix so these variables do not collide with
+        # the 'd_template' names used by discriminator() below.
+        template = Conv2d(inputs, 4, 4, self.df_dim, 2, 2, activation_fn=tf.nn.leaky_relu,
+                          name='d_simple/conv2d', reuse=if_reuse)
+        template = Conv2d(template, 4, 4, self.df_dim * 2, 2, 2, name='d_simple/conv2d2', reuse=if_reuse)
+        template = conv_batch_normalization(template, 'd_simple/batch_norm', is_training=training,
+                                            activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+        template = Conv2d(template, 4, 4, self.df_dim * 4, 2, 2, name='d_simple/conv2d3', reuse=if_reuse)
+        template = conv_batch_normalization(template, 'd_simple/batch_norm2', is_training=training,
+                                            activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+        template = Conv2d(template, 4, 4, self.df_dim * 8, 2, 2, name='d_simple/conv2d4', reuse=if_reuse)
+        template = conv_batch_normalization(template, 'd_simple/batch_norm3', is_training=training,
+                                            activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
return template
- def discriminator(self):
- template = \
- (pt.template("input"). # 128*9*4*4
- custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1). # 128*8*4*4
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- # custom_fully_connected(1))
- custom_conv2d(1, k_h=self.s16, k_w=self.s16, d_h=self.s16, d_w=self.s16))
+ def discriminator(self, training=True, inputs=None, if_reuse=None):
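+        # The final conv uses kernel and stride s16, collapsing the s16 x s16
+        # map to one logit per image (equivalent to a fully connected layer).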
+ template = Conv2d(inputs, 1, 1, self.df_dim * 8, 1, 1, name='d_template/conv2d', reuse=if_reuse)
+ template = conv_batch_normalization(template, 'd_template/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ template = Conv2d(template, self.s16, self.s16, 1, self.s16, self.s16, name='d_template/conv2d2',
+ reuse=if_reuse)
return template
- def get_discriminator(self, x_var, c_var):
- x_code = self.d_encode_img_template.construct(input=x_var)
- c_code = self.d_context_template.construct(input=c_var)
- c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
- c_code = tf.tile(c_code, [1, self.s16, self.s16, 1])
+    # Since D is only used during training, we build it once and reuse its
+    # variables when computing the loss for real/wrong/fake images.
+    # We do not do this for G, because batch_norm needs different options
+    # for training and testing.
+ def get_discriminator(self, x_var, c_var, is_training, no_reuse=None):
+ if cfg.GAN.NETWORK_TYPE == "default":
+ x_code = self.d_encode_image(training=is_training, inputs=x_var, if_reuse=no_reuse)
+ c_code = self.context_embedding(inputs=c_var, if_reuse=no_reuse)
+ c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
+ c_code = tf.tile(c_code, [1, self.s16, self.s16, 1])
+ x_c_code = tf.concat([x_code, c_code], 3)
+
+ return self.discriminator(training=is_training, inputs=x_c_code, if_reuse=no_reuse)
+
+ elif cfg.GAN.NETWORK_TYPE == "simple":
+ x_code = self.d_encode_image_simple(training=is_training, inputs=x_var, if_reuse=no_reuse)
+ c_code = self.context_embedding(inputs=c_var, if_reuse=no_reuse)
+ c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
+ c_code = tf.tile(c_code, [1, self.s16, self.s16, 1])
+ x_c_code = tf.concat([x_code, c_code], 3)
- x_c_code = tf.concat(3, [x_code, c_code])
- return self.discriminator_template.construct(input=x_c_code)
+ return self.discriminator(training=is_training, inputs=x_c_code, if_reuse=no_reuse)
+ else:
+ raise NotImplementedError
diff --git a/stageI/run_exp.py b/stageI/run_exp.py
index 8535fea..53d0b8e 100644
--- a/stageI/run_exp.py
+++ b/stageI/run_exp.py
@@ -1,27 +1,26 @@
from __future__ import division
from __future__ import print_function
-import dateutil
import dateutil.tz
import datetime
import argparse
import pprint
-from misc.datasets import TextDataset
-from stageI.model import CondGAN
-from stageI.trainer import CondGANTrainer
-from misc.utils import mkdir_p
-from misc.config import cfg, cfg_from_file
+import sys
+sys.path.append('misc')
+sys.path.append('stageI')
+
+from datasets import TextDataset
+from utils import mkdir_p
+from config import cfg, cfg_from_file
+from model import CondGAN
+from trainer import CondGANTrainer
def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN network')
- parser.add_argument('--cfg', dest='cfg_file',
- help='optional config file',
- default=None, type=str)
- parser.add_argument('--gpu', dest='gpu_id',
- help='GPU device id to use [0]',
- default=-1, type=int)
+ parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str)
+ parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=-1, type=int)
# if len(sys.argv) == 1:
# parser.print_help()
# sys.exit(1)
@@ -48,22 +47,16 @@ def parse_args():
filename_train = '%s/train' % (datadir)
dataset.train = dataset.get_data(filename_train)
- ckt_logs_dir = "ckt_logs/%s/%s_%s" % \
- (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
+ ckt_logs_dir = "ckt_logs/%s/%s_%s" % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
mkdir_p(ckt_logs_dir)
else:
s_tmp = cfg.TRAIN.PRETRAINED_MODEL
ckt_logs_dir = s_tmp[:s_tmp.find('.ckpt')]
- model = CondGAN(
- image_shape=dataset.image_shape
- )
+ model = CondGAN(image_shape=dataset.image_shape)
+
+ algo = CondGANTrainer(model=model, dataset=dataset, ckt_logs_dir=ckt_logs_dir)
- algo = CondGANTrainer(
- model=model,
- dataset=dataset,
- ckt_logs_dir=ckt_logs_dir
- )
if cfg.TRAIN.FLAG:
algo.train()
else:
diff --git a/stageI/trainer.py b/stageI/trainer.py
index 001666a..981bc2e 100644
--- a/stageI/trainer.py
+++ b/stageI/trainer.py
@@ -1,18 +1,18 @@
from __future__ import division
from __future__ import print_function
+from six.moves import range
+from progressbar import ETA, Bar, Percentage, ProgressBar
-import prettytensor as pt
import tensorflow as tf
import numpy as np
-import scipy.misc
import os
-import sys
-from six.moves import range
-from progressbar import ETA, Bar, Percentage, ProgressBar
+import imageio
+import sys
+sys.path.append('misc')
-from misc.config import cfg
-from misc.utils import mkdir_p
+from config import cfg
+from utils import mkdir_p
TINY = 1e-8
@@ -26,12 +26,7 @@ def KL_loss(mu, log_sigma):
class CondGANTrainer(object):
- def __init__(self,
- model,
- dataset=None,
- exp_name="model",
- ckt_logs_dir="ckt_logs",
- ):
+    def __init__(self, model, dataset=None, exp_name="model", ckt_logs_dir="ckt_logs"):
"""
:type model: RegularizedGAN
"""
@@ -48,28 +43,18 @@ def __init__(self,
self.log_vars = []
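+
+        # Start from a clean default graph so building the model again in the
+        # same process does not collide with previously created variables.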
+ tf.reset_default_graph()
+
def build_placeholder(self):
- '''Helper function for init_opt'''
- self.images = tf.placeholder(
- tf.float32, [self.batch_size] + self.dataset.image_shape,
- name='real_images')
- self.wrong_images = tf.placeholder(
- tf.float32, [self.batch_size] + self.dataset.image_shape,
- name='wrong_images'
- )
- self.embeddings = tf.placeholder(
- tf.float32, [self.batch_size] + self.dataset.embedding_shape,
- name='conditional_embeddings'
- )
-
- self.generator_lr = tf.placeholder(
- tf.float32, [],
- name='generator_learning_rate'
- )
- self.discriminator_lr = tf.placeholder(
- tf.float32, [],
- name='discriminator_learning_rate'
- )
+ ''' Helper function for init_opt '''
+ self.images = tf.placeholder(tf.float32, [self.batch_size] + self.dataset.image_shape, name='real_images')
+ self.wrong_images = tf.placeholder(tf.float32, [self.batch_size] + self.dataset.image_shape,
+ name='wrong_images')
+ self.embeddings = tf.placeholder(tf.float32, [self.batch_size] + self.dataset.embedding_shape,
+ name='conditional_embeddings')
+
+ self.generator_lr = tf.placeholder(tf.float32, [], name='generator_learning_rate')
+ self.discriminator_lr = tf.placeholder(tf.float32, [], name='discriminator_learning_rate')
def sample_encoded_context(self, embeddings):
'''Helper function for init_opt'''
@@ -91,21 +76,19 @@ def sample_encoded_context(self, embeddings):
def init_opt(self):
self.build_placeholder()
- with pt.defaults_scope(phase=pt.Phase.train):
- with tf.variable_scope("g_net"):
- # ####get output from G network################################
- c, kl_loss = self.sample_encoded_context(self.embeddings)
- z = tf.random_normal([self.batch_size, cfg.Z_DIM])
- self.log_vars.append(("hist_c", c))
- self.log_vars.append(("hist_z", z))
- fake_images = self.model.get_generator(tf.concat(1, [c, z]))
+ with tf.variable_scope("g_net"): # For training
+ # ####get output from G network################################
+ c, kl_loss = self.sample_encoded_context(self.embeddings)
+ z = tf.random_normal([self.batch_size, cfg.Z_DIM])
+ self.log_vars.append(("hist_c", c))
+ self.log_vars.append(("hist_z", z))
+ fake_images = self.model.get_generator(tf.concat([c, z], 1), True) # set training to be True
+
+ with tf.variable_scope("d_net"): # For training
# ####get discriminator_loss and generator_loss ###################
- discriminator_loss, generator_loss =\
- self.compute_losses(self.images,
- self.wrong_images,
- fake_images,
- self.embeddings)
+ discriminator_loss, generator_loss = self.compute_losses(self.images, self.wrong_images, fake_images,
+ self.embeddings)
generator_loss += kl_loss
self.log_vars.append(("g_loss_kl_loss", kl_loss))
self.log_vars.append(("g_loss", generator_loss))
@@ -116,11 +99,11 @@ def init_opt(self):
# #######define self.g_sum, self.d_sum,....########################
self.define_summaries()
- with pt.defaults_scope(phase=pt.Phase.test):
- with tf.variable_scope("g_net", reuse=True):
- self.sampler()
- self.visualization(cfg.TRAIN.NUM_COPY)
- print("success")
+ with tf.variable_scope("g_net", reuse=True): # For testing
+ self.sampler()
+ self.visualization(cfg.TRAIN.NUM_COPY)
+ print("success")
+
def sampler(self):
c, _ = self.sample_encoded_context(self.embeddings)
@@ -128,62 +111,52 @@ def sampler(self):
z = tf.zeros([self.batch_size, cfg.Z_DIM]) # Expect similar BGs
else:
z = tf.random_normal([self.batch_size, cfg.Z_DIM])
- self.fake_images = self.model.get_generator(tf.concat(1, [c, z]))
+ self.fake_images = self.model.get_generator(tf.concat([c, z], 1), False) # for testing
def compute_losses(self, images, wrong_images, fake_images, embeddings):
- real_logit = self.model.get_discriminator(images, embeddings)
- wrong_logit = self.model.get_discriminator(wrong_images, embeddings)
- fake_logit = self.model.get_discriminator(fake_images, embeddings)
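+        # Conditional GAN losses: D should score (real image, matching text)
+        # pairs as 1, and both mismatched ("wrong") pairs and generated
+        # ("fake") images as 0; G is trained to make D score fakes as 1.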
+ real_logit = self.model.get_discriminator(images, embeddings, True)
+ # Reuse the weights
+ wrong_logit = self.model.get_discriminator(wrong_images, embeddings, True, no_reuse=tf.AUTO_REUSE)
+ fake_logit = self.model.get_discriminator(fake_images, embeddings, True, no_reuse=tf.AUTO_REUSE)
- real_d_loss =\
- tf.nn.sigmoid_cross_entropy_with_logits(real_logit,
- tf.ones_like(real_logit))
+ real_d_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logit, labels=tf.ones_like(real_logit))
real_d_loss = tf.reduce_mean(real_d_loss)
- wrong_d_loss =\
- tf.nn.sigmoid_cross_entropy_with_logits(wrong_logit,
- tf.zeros_like(wrong_logit))
+ wrong_d_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=wrong_logit, labels=tf.zeros_like(wrong_logit))
wrong_d_loss = tf.reduce_mean(wrong_d_loss)
- fake_d_loss =\
- tf.nn.sigmoid_cross_entropy_with_logits(fake_logit,
- tf.zeros_like(fake_logit))
+ fake_d_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit, labels=tf.zeros_like(fake_logit))
fake_d_loss = tf.reduce_mean(fake_d_loss)
if cfg.TRAIN.B_WRONG:
- discriminator_loss =\
- real_d_loss + (wrong_d_loss + fake_d_loss) / 2.
+ discriminator_loss = real_d_loss + (wrong_d_loss + fake_d_loss) / 2.
self.log_vars.append(("d_loss_wrong", wrong_d_loss))
else:
discriminator_loss = real_d_loss + fake_d_loss
self.log_vars.append(("d_loss_real", real_d_loss))
self.log_vars.append(("d_loss_fake", fake_d_loss))
- generator_loss = \
- tf.nn.sigmoid_cross_entropy_with_logits(fake_logit,
- tf.ones_like(fake_logit))
+ generator_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit, labels=tf.ones_like(fake_logit))
generator_loss = tf.reduce_mean(generator_loss)
return discriminator_loss, generator_loss
def prepare_trainer(self, generator_loss, discriminator_loss):
- '''Helper function for init_opt'''
+ ''' Helper function for init_opt '''
all_vars = tf.trainable_variables()
- g_vars = [var for var in all_vars if
- var.name.startswith('g_')]
- d_vars = [var for var in all_vars if
- var.name.startswith('d_')]
-
- generator_opt = tf.train.AdamOptimizer(self.generator_lr,
- beta1=0.5)
- self.generator_trainer =\
- pt.apply_optimizer(generator_opt,
- losses=[generator_loss],
- var_list=g_vars)
- discriminator_opt = tf.train.AdamOptimizer(self.discriminator_lr,
- beta1=0.5)
- self.discriminator_trainer =\
- pt.apply_optimizer(discriminator_opt,
- losses=[discriminator_loss],
- var_list=d_vars)
+ g_vars = [var for var in all_vars if var.name.startswith('g_')]
+ d_vars = [var for var in all_vars if var.name.startswith('d_')]
+
+        # Collect the batch-norm moving-average update ops for each sub-network
+        update_ops_D = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if op.name.startswith('d_')]
+        update_ops_G = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if op.name.startswith('g_')]
+
+ with tf.control_dependencies(update_ops_G): # Update the moving mean and variance from the batch normalization
+ generator_opt = tf.train.AdamOptimizer(self.generator_lr, beta1=0.5)
+ self.generator_trainer = generator_opt.minimize(generator_loss, var_list=g_vars)
+
+ with tf.control_dependencies(update_ops_D): # Update the moving mean and variance from the batch normalization
+ discriminator_opt = tf.train.AdamOptimizer(self.discriminator_lr, beta1=0.5)
+ self.discriminator_trainer = discriminator_opt.minimize(discriminator_loss, var_list=d_vars)
+
self.log_vars.append(("g_learning_rate", self.generator_lr))
self.log_vars.append(("d_learning_rate", self.discriminator_lr))
@@ -192,15 +165,15 @@ def define_summaries(self):
all_sum = {'g': [], 'd': [], 'hist': []}
for k, v in self.log_vars:
if k.startswith('g'):
- all_sum['g'].append(tf.scalar_summary(k, v))
+ all_sum['g'].append(tf.summary.scalar(k, v))
elif k.startswith('d'):
- all_sum['d'].append(tf.scalar_summary(k, v))
+ all_sum['d'].append(tf.summary.scalar(k, v))
elif k.startswith('hist'):
- all_sum['hist'].append(tf.histogram_summary(k, v))
+ all_sum['hist'].append(tf.summary.histogram(k, v))
- self.g_sum = tf.merge_summary(all_sum['g'])
- self.d_sum = tf.merge_summary(all_sum['d'])
- self.hist_sum = tf.merge_summary(all_sum['hist'])
+ self.g_sum = tf.summary.merge(all_sum['g'])
+ self.d_sum = tf.summary.merge(all_sum['d'])
+ self.hist_sum = tf.summary.merge(all_sum['hist'])
def visualize_one_superimage(self, img_var, images, rows, filename):
stacked_img = []
@@ -210,22 +183,18 @@ def visualize_one_superimage(self, img_var, images, rows, filename):
for col in range(rows):
row_img.append(img_var[row * rows + col, :, :, :])
# each rows is 1realimage +10_fakeimage
- stacked_img.append(tf.concat(1, row_img))
- imgs = tf.expand_dims(tf.concat(0, stacked_img), 0)
- current_img_summary = tf.image_summary(filename, imgs)
+ stacked_img.append(tf.concat(row_img, 1))
+ imgs = tf.expand_dims(tf.concat(stacked_img, 0), 0)
+ current_img_summary = tf.summary.image(filename, imgs)
return current_img_summary, imgs
def visualization(self, n):
- fake_sum_train, superimage_train = \
- self.visualize_one_superimage(self.fake_images[:n * n],
- self.images[:n * n],
- n, "train")
- fake_sum_test, superimage_test = \
- self.visualize_one_superimage(self.fake_images[n * n:2 * n * n],
- self.images[n * n:2 * n * n],
- n, "test")
- self.superimages = tf.concat(0, [superimage_train, superimage_test])
- self.image_summary = tf.merge_summary([fake_sum_train, fake_sum_test])
+ fake_sum_train, superimage_train = self.visualize_one_superimage(self.fake_images[:n * n], self.images[:n * n],
+ n, "train")
+ fake_sum_test, superimage_test = self.visualize_one_superimage(self.fake_images[n * n:2 * n * n],
+ self.images[n * n:2 * n * n], n, "test")
+ self.superimages = tf.concat([superimage_train, superimage_test], 0)
+ self.image_summary = tf.summary.merge([fake_sum_train, fake_sum_test])
def preprocess(self, x, n):
# make sure every row with n column have the same embeddings
@@ -235,33 +204,29 @@ def preprocess(self, x, n):
return x
def epoch_sum_images(self, sess, n):
- images_train, _, embeddings_train, captions_train, _ =\
- self.dataset.train.next_batch(n * n, cfg.TRAIN.NUM_EMBEDDING)
+ images_train, _, embeddings_train, captions_train, _ = self.dataset.train.next_batch(n * n,
+ cfg.TRAIN.NUM_EMBEDDING)
images_train = self.preprocess(images_train, n)
embeddings_train = self.preprocess(embeddings_train, n)
- images_test, _, embeddings_test, captions_test, _ = \
- self.dataset.test.next_batch(n * n, 1)
+ images_test, _, embeddings_test, captions_test, _ = self.dataset.test.next_batch(n * n, 1)
images_test = self.preprocess(images_test, n)
embeddings_test = self.preprocess(embeddings_test, n)
images = np.concatenate([images_train, images_test], axis=0)
- embeddings =\
- np.concatenate([embeddings_train, embeddings_test], axis=0)
+ embeddings = np.concatenate([embeddings_train, embeddings_test], axis=0)
if self.batch_size > 2 * n * n:
- images_pad, _, embeddings_pad, _, _ =\
- self.dataset.test.next_batch(self.batch_size - 2 * n * n, 1)
+ images_pad, _, embeddings_pad, _, _ = self.dataset.test.next_batch(self.batch_size - 2 * n * n, 1)
images = np.concatenate([images, images_pad], axis=0)
embeddings = np.concatenate([embeddings, embeddings_pad], axis=0)
feed_dict = {self.images: images,
self.embeddings: embeddings}
- gen_samples, img_summary =\
- sess.run([self.superimages, self.image_summary], feed_dict)
+ gen_samples, img_summary = sess.run([self.superimages, self.image_summary], feed_dict)
# save images generated for train and test captions
- scipy.misc.imsave('%s/train.jpg' % (self.log_dir), gen_samples[0])
- scipy.misc.imsave('%s/test.jpg' % (self.log_dir), gen_samples[1])
+ imageio.imwrite('%s/train.jpg' % (self.log_dir), gen_samples[0])
+ imageio.imwrite('%s/test.jpg' % (self.log_dir), gen_samples[1])
# pfi_train = open(self.log_dir + "/train.txt", "w")
pfi_test = open(self.log_dir + "/test.txt", "w")
@@ -278,11 +243,11 @@ def epoch_sum_images(self, sess, n):
def build_model(self, sess):
self.init_opt()
- sess.run(tf.initialize_all_variables())
+ sess.run(tf.global_variables_initializer())
if len(self.model_path) > 0:
print("Reading model parameters from %s" % self.model_path)
- restore_vars = tf.all_variables()
+ restore_vars = tf.global_variables()
# all_vars = tf.all_variables()
# restore_vars = [var for var in all_vars if
# var.name.startswith('g_') or
@@ -301,15 +266,14 @@ def build_model(self, sess):
def train(self):
config = tf.ConfigProto(allow_soft_placement=True)
+ config.gpu_options.per_process_gpu_memory_fraction = 0.7
with tf.Session(config=config) as sess:
with tf.device("/gpu:%d" % cfg.GPU_ID):
counter = self.build_model(sess)
- saver = tf.train.Saver(tf.all_variables(),
- keep_checkpoint_every_n_hours=2)
+ saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=2)
# summary_op = tf.merge_all_summaries()
- summary_writer = tf.train.SummaryWriter(self.log_dir,
- sess.graph)
+ summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
keys = ["d_loss", "g_loss"]
log_vars = []
@@ -327,10 +291,8 @@ def train(self):
updates_per_epoch = int(number_example / self.batch_size)
epoch_start = int(counter / updates_per_epoch)
for epoch in range(epoch_start, self.max_epoch):
- widgets = ["epoch #%d|" % epoch,
- Percentage(), Bar(), ETA()]
- pbar = ProgressBar(maxval=updates_per_epoch,
- widgets=widgets)
+ widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
+ pbar = ProgressBar(maxval=updates_per_epoch, widgets=widgets)
pbar.start()
if epoch % lr_decay_step == 0 and epoch != 0:
@@ -341,8 +303,7 @@ def train(self):
for i in range(updates_per_epoch):
pbar.update(i)
# training d
- images, wrong_images, embeddings, _, _ =\
- self.dataset.train.next_batch(self.batch_size,
+ images, wrong_images, embeddings, _, _ = self.dataset.train.next_batch(self.batch_size,
num_embedding)
feed_dict = {self.images: images,
self.wrong_images: wrong_images,
@@ -351,28 +312,19 @@ def train(self):
self.discriminator_lr: discriminator_lr
}
# train d
- feed_out = [self.discriminator_trainer,
- self.d_sum,
- self.hist_sum,
- log_vars]
- _, d_sum, hist_sum, log_vals = sess.run(feed_out,
- feed_dict)
+ feed_out = [self.discriminator_trainer, self.d_sum, self.hist_sum, log_vars]
+ _, d_sum, hist_sum, log_vals = sess.run(feed_out, feed_dict)
summary_writer.add_summary(d_sum, counter)
summary_writer.add_summary(hist_sum, counter)
all_log_vals.append(log_vals)
# train g
- feed_out = [self.generator_trainer,
- self.g_sum]
- _, g_sum = sess.run(feed_out,
- feed_dict)
+ feed_out = [self.generator_trainer, self.g_sum]
+ _, g_sum = sess.run(feed_out, feed_dict)
summary_writer.add_summary(g_sum, counter)
# save checkpoint
counter += 1
if counter % self.snapshot_interval == 0:
- snapshot_path = "%s/%s_%s.ckpt" %\
- (self.checkpoint_dir,
- self.exp_name,
- str(counter))
+ snapshot_path = "%s/%s_%s.ckpt" % (self.checkpoint_dir, self.exp_name, str(counter))
fn = saver.save(sess, snapshot_path)
print("Model saved in file: %s" % fn)
@@ -385,21 +337,17 @@ def train(self):
dic_logs[k] = v
# print(k, v)
- log_line = "; ".join("%s: %s" %
- (str(k), str(dic_logs[k]))
- for k in dic_logs)
+ log_line = "; ".join("%s: %s" % (str(k), str(dic_logs[k])) for k in dic_logs)
print("Epoch %d | " % (epoch) + log_line)
sys.stdout.flush()
if np.any(np.isnan(avg_log_vals)):
raise ValueError("NaN detected!")
- def save_super_images(self, images, sample_batchs, filenames,
- sentenceID, save_dir, subset):
+ def save_super_images(self, images, sample_batchs, filenames, sentenceID, save_dir, subset):
# batch_size samples for each embedding
numSamples = len(sample_batchs)
for j in range(len(filenames)):
- s_tmp = '%s-1real-%dsamples/%s/%s' %\
- (save_dir, numSamples, subset, filenames[j])
+ s_tmp = '%s-1real-%dsamples/%s/%s' % (save_dir, numSamples, subset, filenames[j])
folder = s_tmp[:s_tmp.rfind('/')]
if not os.path.isdir(folder):
print('Make a new folder: ', folder)
@@ -411,42 +359,40 @@ def save_super_images(self, images, sample_batchs, filenames,
superimage = np.concatenate(superimage, axis=1)
fullpath = '%s_sentence%d.jpg' % (s_tmp, sentenceID)
- scipy.misc.imsave(fullpath, superimage)
+ imageio.imwrite(fullpath, superimage)
def eval_one_dataset(self, sess, dataset, save_dir, subset='train'):
count = 0
print('num_examples:', dataset._num_examples)
while count < dataset._num_examples:
start = count % dataset._num_examples
- images, embeddings_batchs, filenames, _ =\
- dataset.next_batch_test(self.batch_size, start, 1)
+ images, embeddings_batchs, filenames, _ = dataset.next_batch_test(self.batch_size, start, 1)
print('count = ', count, 'start = ', start)
for i in range(len(embeddings_batchs)):
samples_batchs = []
# Generate up to 16 images for each sentence,
# with randomness from noise z and conditioning augmentation.
for j in range(np.minimum(16, cfg.TRAIN.NUM_COPY)):
- samples = sess.run(self.fake_images,
- {self.embeddings: embeddings_batchs[i]})
+ samples = sess.run(self.fake_images, {self.embeddings: embeddings_batchs[i]})
samples_batchs.append(samples)
- self.save_super_images(images, samples_batchs,
- filenames, i, save_dir,
- subset)
+ self.save_super_images(images, samples_batchs, filenames, i, save_dir, subset)
count += self.batch_size
def evaluate(self):
config = tf.ConfigProto(allow_soft_placement=True)
+ config.gpu_options.per_process_gpu_memory_fraction = 0.7
with tf.Session(config=config) as sess:
with tf.device("/gpu:%d" % cfg.GPU_ID):
if self.model_path.find('.ckpt') != -1:
self.init_opt()
print("Reading model parameters from %s" % self.model_path)
- saver = tf.train.Saver(tf.all_variables())
+ saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, self.model_path)
- # self.eval_one_dataset(sess, self.dataset.train,
- # self.log_dir, subset='train')
- self.eval_one_dataset(sess, self.dataset.test,
- self.log_dir, subset='test')
+
+ # self.eval_one_dataset(sess, self.dataset.train, self.log_dir, subset='train')
+
+ self.eval_one_dataset(sess, self.dataset.test, self.log_dir, subset='test')
else:
print("Input a valid model path.")
diff --git a/stageII/__init__.py b/stageII/__init__.py
index f78a8b1..008827b 100644
--- a/stageII/__init__.py
+++ b/stageII/__init__.py
@@ -1,2 +1,3 @@
from __future__ import division
from __future__ import print_function
+
diff --git a/stageII/cfg/birds.yml b/stageII/cfg/birds.yml
index 4e3ce8d..e0c805f 100644
--- a/stageII/cfg/birds.yml
+++ b/stageII/cfg/birds.yml
@@ -7,9 +7,9 @@ Z_DIM: 100
TRAIN:
FLAG: True
- PRETRAINED_MODEL: './ckt_logs/birds/stageI/model_82000.ckpt'
+ PRETRAINED_MODEL: './models/stageI/model_82000.ckpt'
PRETRAINED_EPOCH: 600
- BATCH_SIZE: 64
+  BATCH_SIZE: 64 # reduce to 32 if you run out of GPU memory
NUM_COPY: 4
MAX_EPOCH: 1200
SNAPSHOT_INTERVAL: 2000
@@ -19,6 +19,7 @@ TRAIN:
NUM_EMBEDDING: 4
COEFF:
KL: 2.0
+ FINETUNE_LR: True
GAN:
EMBEDDING_DIM: 128
diff --git a/stageII/model.py b/stageII/model.py
index 28aee30..a4bddfa 100644
--- a/stageII/model.py
+++ b/stageII/model.py
@@ -1,14 +1,12 @@
from __future__ import division
from __future__ import print_function
-import prettytensor as pt
import tensorflow as tf
-import misc.custom_ops
-from misc.custom_ops import leaky_rectify
-from misc.config import cfg
+import sys
+sys.path.append('misc')
-# TODO: Does template.constrct() really shared the computation
-# when multipel times of construct are done
+from custom_ops import fc, conv_batch_normalization, fc_batch_normalization, reshape, Conv2d, UpSample, add
+from config import cfg
class CondGAN(object):
@@ -22,300 +20,279 @@ def __init__(self, lr_imsize, hr_lr_ratio):
self.s = lr_imsize
print('lr_imsize: ', lr_imsize)
- self.s2, self.s4, self.s8, self.s16 = \
- int(self.s / 2), int(self.s / 4), int(self.s / 8), int(self.s / 16)
- if cfg.GAN.NETWORK_TYPE == "default":
- with tf.variable_scope("d_net"):
- self.d_context_template = self.context_embedding()
- self.d_image_template = self.d_encode_image()
- self.d_discriminator_template = self.discriminator()
-
- with tf.variable_scope("hr_d_net"):
- self.hr_d_context_template = self.context_embedding()
- self.hr_d_image_template = self.hr_d_encode_image()
- self.hr_discriminator_template = self.discriminator()
- else:
- raise NotImplementedError
+ self.s2, self.s4, self.s8, self.s16 = int(self.s / 2), int(self.s / 4), int(self.s / 8), int(self.s / 16)
# conditioning augmentation structure for text embedding
# are shared by g and hr_g
# g and hr_g build this structure separately and do not share parameters
def generate_condition(self, c_var):
- conditions =\
- (pt.wrap(c_var).
- flatten().
- custom_fully_connected(self.ef_dim * 2).
- apply(leaky_rectify, leakiness=0.2))
+ conditions = fc(c_var, self.ef_dim * 2, 'gen_cond/fc', activation_fn=tf.nn.leaky_relu)
mean = conditions[:, :self.ef_dim]
log_sigma = conditions[:, self.ef_dim:]
return [mean, log_sigma]
# stage I generator (g)
- def generator(self, z_var):
- node1_0 =\
- (pt.wrap(z_var).
- flatten().
- custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
- fc_batch_norm().
- reshape([-1, self.s16, self.s16, self.gf_dim * 8]))
- node1_1 = \
- (node1_0.
- custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
- node1 = \
- (node1_0.
- apply(tf.add, node1_1).
- apply(tf.nn.relu))
-
- node2_0 = \
- (node1.
- # custom_deconv2d([0, self.s8, self.s8, self.gf_dim * 4], k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s8, self.s8]).
- custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
- node2_1 = \
- (node2_0.
- custom_conv2d(self.gf_dim * 1, k_h=1, k_w=1, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 1, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
- node2 = \
- (node2_0.
- apply(tf.add, node2_1).
- apply(tf.nn.relu))
-
- output_tensor = \
- (node2.
- # custom_deconv2d([0, self.s4, self.s4, self.gf_dim * 2], k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s4, self.s4]).
- custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- # custom_deconv2d([0, self.s2, self.s2, self.gf_dim], k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
- custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- # custom_deconv2d([0] + list(self.image_shape), k_h=4, k_w=4).
- apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
- custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
- apply(tf.nn.tanh))
+ def generator(self, z_var, training=True):
+ node1_0 = fc(z_var, self.s16 * self.s16 * self.gf_dim * 8, 'g_n1.0/fc')
+ node1_0 = fc_batch_normalization(node1_0, 'g_n1.0/batch_norm')
+ node1_0 = reshape(node1_0, [-1, self.s16, self.s16, self.gf_dim * 8], name='g_n1.0/reshape')
+
+ node1_1 = Conv2d(node1_0, 1, 1, self.gf_dim * 2, 1, 1, name='g_n1.1/conv2d')
+ node1_1 = conv_batch_normalization(node1_1, 'g_n1.1/batch_norm_1', activation_fn=tf.nn.relu,
+ is_training=training)
+ node1_1 = Conv2d(node1_1, 3, 3, self.gf_dim * 2, 1, 1, name='g_n1.1/conv2d2')
+ node1_1 = conv_batch_normalization(node1_1, 'g_n1.1/batch_norm_2', activation_fn=tf.nn.relu,
+ is_training=training)
+ node1_1 = Conv2d(node1_1, 3, 3, self.gf_dim * 8, 1, 1, name='g_n1.1/conv2d3')
+        # No activation before the residual add: ReLU is applied after the add,
+        # matching the original chain (cf. node2_1 below).
+        node1_1 = conv_batch_normalization(node1_1, 'g_n1.1/batch_norm_3', is_training=training)
+
+ node1 = add([node1_0, node1_1], name='g_n1_res/add')
+ node1_output = tf.nn.relu(node1)
+
+ node2_0 = UpSample(node1_output, size=[self.s8, self.s8], method=1, align_corners=False, name='g_n2.0/upsample')
+ node2_0 = Conv2d(node2_0, 3, 3, self.gf_dim * 4, 1, 1, name='g_n2.0/conv2d')
+ node2_0 = conv_batch_normalization(node2_0, 'g_n2.0/batch_norm', is_training=training)
+
+ node2_1 = Conv2d(node2_0, 1, 1, self.gf_dim * 1, 1, 1, name='g_n2.1/conv2d')
+ node2_1 = conv_batch_normalization(node2_1, 'g_n2.1/batch_norm', activation_fn=tf.nn.relu,
+ is_training=training)
+ node2_1 = Conv2d(node2_1, 3, 3, self.gf_dim * 1, 1, 1, name='g_n2.1/conv2d2')
+ node2_1 = conv_batch_normalization(node2_1, 'g_n2.1/batch_norm2', activation_fn=tf.nn.relu,
+ is_training=training)
+ node2_1 = Conv2d(node2_1, 3, 3, self.gf_dim * 4, 1, 1, name='g_n2.1/conv2d3')
+ node2_1 = conv_batch_normalization(node2_1, 'g_n2.1/batch_norm3', is_training=training)
+
+ node2 = add([node2_0, node2_1], name='g_n2_res/add')
+ node2_output = tf.nn.relu(node2)
+
+ output_tensor = UpSample(node2_output, size=[self.s4, self.s4], method=1, align_corners=False,
+ name='g_OT/upsample')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim * 2, 1, 1, name='g_OT/conv2d')
+ output_tensor = conv_batch_normalization(output_tensor, 'g_OT/batch_norm', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = UpSample(output_tensor, size=[self.s2, self.s2], method=1, align_corners=False,
+ name='g_OT/upsample2')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim, 1, 1, name='g_OT/conv2d2')
+ output_tensor = conv_batch_normalization(output_tensor, 'g_OT/batch_norm2', activation_fn=tf.nn.relu,
+ is_training=training)
+ output_tensor = UpSample(output_tensor, size=[self.s, self.s], method=1, align_corners=False,
+ name='g_OT/upsample3')
+ output_tensor = Conv2d(output_tensor, 3, 3, 3, 1, 1, activation_fn=tf.nn.tanh, name='g_OT/conv2d3')
+
return output_tensor
- def get_generator(self, z_var):
+ def get_generator(self, z_var, is_training):
if cfg.GAN.NETWORK_TYPE == "default":
- return self.generator(z_var)
+ return self.generator(z_var, training=is_training)
else:
raise NotImplementedError
# stage II generator (hr_g)
- def residual_block(self, x_c_code):
- node0_0 = pt.wrap(x_c_code) # -->s4 * s4 * gf_dim * 4
- node0_1 = \
- (pt.wrap(x_c_code). # -->s4 * s4 * gf_dim * 4
- custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
- output_tensor = \
- (node0_0.
- apply(tf.add, node0_1).
- apply(tf.nn.relu))
+ def residual_block(self, x_c_code, name, training=True):
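+        # Identity skip connection around two 3x3 conv + batch-norm layers;
+        # ReLU is applied only after the add.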
+ node0_0 = x_c_code # -->s4 * s4 * gf_dim * 4
+
+ node0_1 = Conv2d(x_c_code, 3, 3, self.gf_dim * 4, 1, 1, name=name+'/conv2d')
+ node0_1 = conv_batch_normalization(node0_1, name+'/batch_norm', is_training=training,
+ activation_fn=tf.nn.relu)
+ node0_1 = Conv2d(node0_1, 3, 3, self.gf_dim * 4, 1, 1, name=name+'/conv2d2')
+ node0_1 = conv_batch_normalization(node0_1, name+'/batch_norm2', is_training=training)
+
+        output_tensor = add([node0_0, node0_1], name=name + '/add')
+ output_tensor = tf.nn.relu(output_tensor)
+
return output_tensor
- def hr_g_encode_image(self, x_var):
- output_tensor = \
- (pt.wrap(x_var). # -->s * s * 3
- custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1). # s * s * gf_dim
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 2, k_h=4, k_w=4). # s2 * s2 * gf_dim * 2
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(self.gf_dim * 4, k_h=4, k_w=4). # s4 * s4 * gf_dim * 4
- conv_batch_norm().
- apply(tf.nn.relu))
+ def hr_g_encode_image(self, x_var, training=True): # input: x_var --> s * s * 3
+ # s * s * gf_dim
+ output_tensor = Conv2d(x_var, 3, 3, self.gf_dim, 1, 1, activation_fn=tf.nn.relu, name='hr_g_OT/conv2d')
+
+ # s2 * s2 * gf_dim * 2
+ output_tensor = Conv2d(output_tensor, 4, 4, self.gf_dim * 2, 2, 2, name='hr_g_OT/conv2d2')
+ output_tensor = conv_batch_normalization(output_tensor, 'hr_g_OT/batch_norm', is_training=training,
+ activation_fn=tf.nn.relu)
+ # s4 * s4 * gf_dim * 4
+ output_tensor = Conv2d(output_tensor, 4, 4, self.gf_dim * 4, 2, 2, name='hr_g_OT/conv2d3')
+ output_tensor = conv_batch_normalization(output_tensor, 'hr_g_OT/batch_norm2', is_training=training,
+ activation_fn=tf.nn.relu)
return output_tensor
- def hr_g_joint_img_text(self, x_c_code):
- output_tensor = \
- (pt.wrap(x_c_code). # -->s4 * s4 * (ef_dim+gf_dim*4)
- custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1). # s4 * s4 * gf_dim * 4
- conv_batch_norm().
- apply(tf.nn.relu))
+    def hr_g_joint_img_text(self, x_c_code, training=True):  # input: x_c_code --> s4 * s4 * (ef_dim+gf_dim*4)
+ # s4 * s4 * gf_dim * 4
+ output_tensor = Conv2d(x_c_code, 3, 3, self.gf_dim * 4, 1, 1, name='hr_g_joint_OT/conv2d')
+ output_tensor = conv_batch_normalization(output_tensor, 'hr_g_joint_OT/batch_norm', is_training=training,
+ activation_fn=tf.nn.relu)
return output_tensor
- def hr_generator(self, x_c_code):
- output_tensor = \
- (pt.wrap(x_c_code). # -->s4 * s4 * gf_dim*4
- # custom_deconv2d([0, self.s2, self.s2, self.gf_dim * 2], k_h=4, k_w=4). # -->s2 * s2 * gf_dim*2
- apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
- custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- # custom_deconv2d([0, self.s, self.s, self.gf_dim], k_h=4, k_w=4). # -->s * s * gf_dim
- apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
- custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- # custom_deconv2d([0, self.s * 2, self.s * 2, self.gf_dim // 2], k_h=4, k_w=4). # -->2s * 2s * gf_dim/2
- apply(tf.image.resize_nearest_neighbor, [self.s * 2, self.s * 2]).
- custom_conv2d(self.gf_dim // 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- # custom_deconv2d([0, self.s * 4, self.s * 4, self.gf_dim // 4], k_h=4, k_w=4). # -->4s * 4s * gf_dim//4
- apply(tf.image.resize_nearest_neighbor, [self.s * 4, self.s * 4]).
- custom_conv2d(self.gf_dim // 4, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(tf.nn.relu).
- custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1). # -->4s * 4s * 3
- apply(tf.nn.tanh))
+ def hr_generator(self, x_c_code, training=True): # Input: x_c_code -->s4 * s4 * gf_dim*4
+ output_tensor = UpSample(x_c_code, size=[self.s2, self.s2], method=1, align_corners=False,
+ name='hr_gen/upsample')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim * 2, 1, 1, name='hr_gen/conv2d')
+ output_tensor = conv_batch_normalization(output_tensor, 'hr_gen/batch_norm', is_training=training,
+ activation_fn=tf.nn.relu)
+ output_tensor = UpSample(output_tensor, size=[self.s, self.s], method=1, align_corners=False,
+ name='hr_gen/upsample2')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim, 1, 1, name='hr_gen/conv2d2')
+ output_tensor = conv_batch_normalization(output_tensor, 'hr_gen/batch_norm2', is_training=training,
+ activation_fn=tf.nn.relu)
+ output_tensor = UpSample(output_tensor, size=[self.s * 2, self.s * 2], method=1, align_corners=False,
+ name='hr_gen/upsample3')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim//2, 1, 1, name='hr_gen/conv2d3')
+ output_tensor = conv_batch_normalization(output_tensor, 'hr_gen/batch_norm3', is_training=training,
+ activation_fn=tf.nn.relu)
+ output_tensor = UpSample(output_tensor, size=[self.s * 4, self.s * 4], method=1, align_corners=False,
+                                 name='hr_gen/upsample4')
+ output_tensor = Conv2d(output_tensor, 3, 3, self.gf_dim//4, 1, 1, name='hr_gen/conv2d4')
+ output_tensor = conv_batch_normalization(output_tensor, 'hr_gen/batch_norm4', is_training=training,
+ activation_fn=tf.nn.relu)
+ # -->4s * 4s * 3
+ output_tensor = Conv2d(output_tensor, 3, 3, 3, 1, 1, name='hr_gen/conv2d5', activation_fn=tf.nn.tanh)
return output_tensor
- def hr_get_generator(self, x_var, c_code):
+ def hr_get_generator(self, x_var, c_code, is_training):
if cfg.GAN.NETWORK_TYPE == "default":
# image x_var: self.s * self.s *3
- x_code = self.hr_g_encode_image(x_var) # -->s4 * s4 * gf_dim * 4
+ x_code = self.hr_g_encode_image(x_var, training=is_training) # -->s4 * s4 * gf_dim * 4
# text c_code: ef_dim
c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
c_code = tf.tile(c_code, [1, self.s4, self.s4, 1])
# combine both --> s4 * s4 * (ef_dim+gf_dim*4)
- x_c_code = tf.concat(3, [x_code, c_code])
+ x_c_code = tf.concat([x_code, c_code], 3)
# Joint learning from text and image -->s4 * s4 * gf_dim * 4
-            node0 = self.hr_g_joint_img_text(x_c_code)
+            node0 = self.hr_g_joint_img_text(x_c_code, training=is_training)
- node1 = self.residual_block(node0)
- node2 = self.residual_block(node1)
- node3 = self.residual_block(node2)
- node4 = self.residual_block(node3)
+ node1 = self.residual_block(node0, 'node1_resid_block', training=is_training)
+ node2 = self.residual_block(node1, 'node2_resid_block', training=is_training)
+ node3 = self.residual_block(node2, 'node3_resid_block', training=is_training)
+ node4 = self.residual_block(node3, 'node4_resid_block', training=is_training)
# Up-sampling
- return self.hr_generator(node4) # -->4s * 4s * 3
+ return self.hr_generator(node4, training=is_training) # -->4s * 4s * 3
else:
raise NotImplementedError
# structure shared by d and hr_d
# d and hr_d build this structure separately and do not share parameters
- def context_embedding(self):
- template = (pt.template("input").
- custom_fully_connected(self.ef_dim).
- apply(leaky_rectify, leakiness=0.2))
+ def context_embedding(self, inputs=None, if_reuse=None):
+ template = fc(inputs, self.ef_dim, 'd_embedd/fc', activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+
return template
- def discriminator(self):
- template = \
- (pt.template("input"). # s16 * s16 * 128*9
- custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1). # s16 * s16 * 128*8
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- # custom_fully_connected(1))
- custom_conv2d(1, k_h=self.s16, k_w=self.s16, d_h=self.s16, d_w=self.s16))
+ def discriminator(self, training=True, inputs=None, if_reuse=None):
+ template = Conv2d(inputs, 1, 1, self.df_dim * 8, 1, 1, name='d_template/conv2d', reuse=if_reuse)
+ template = conv_batch_normalization(template, 'd_template/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ template = Conv2d(template, self.s16, self.s16, 1, self.s16, self.s16, name='d_template/conv2d2',
+ reuse=if_reuse)
return template
# d-net
- def d_encode_image(self):
- node1_0 = \
- (pt.template("input"). # s * s * 3
- custom_conv2d(self.df_dim, k_h=4, k_w=4). # s2 * s2 * df_dim
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 2, k_h=4, k_w=4). # s4 * s4 * df_dim*2
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 4, k_h=4, k_w=4). # s8 * s8 * df_dim*4
- conv_batch_norm().
- custom_conv2d(self.df_dim * 8, k_h=4, k_w=4). # s16 * s16 * df_dim*8
- conv_batch_norm())
- node1_1 = \
- (node1_0.
- custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
-
- node1 = \
- (node1_0.
- apply(tf.add, node1_1).
- apply(leaky_rectify, leakiness=0.2))
+ def d_encode_image(self, inputs=None, training=True, if_reuse=None):
+ # input: s * s * 3
+ node1_0 = Conv2d(inputs, 4, 4, self.df_dim, 2, 2, activation_fn=tf.nn.leaky_relu, name='d_n1.0/conv2d',
+ reuse=if_reuse) # s2 * s2 * df_dim
+
+ # s4 * s4 * df_dim*2
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 2, 2, 2, name='d_n1.0/conv2d2', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'd_n1.0/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ # s8 * s8 * df_dim*4
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 4, 2, 2, name='d_n1.0/conv2d3', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'd_n1.0/batch_norm2', is_training=training, reuse=if_reuse)
+ # s16 * s16 * df_dim*8
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 8, 2, 2, name='d_n1.0/conv2d4', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'd_n1.0/batch_norm3', is_training=training, reuse=if_reuse)
+
+ node1_1 = Conv2d(node1_0, 1, 1, self.df_dim * 2, 1, 1, name='d_n1.1/conv2d', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'd_n1.1/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ node1_1 = Conv2d(node1_1, 3, 3, self.df_dim * 2, 1, 1, name='d_n1.1/conv2d2', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'd_n1.1/batch_norm2', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ node1_1 = Conv2d(node1_1, 3, 3, self.df_dim * 8, 1, 1, name='d_n1.1/conv2d3', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'd_n1.1/batch_norm3', is_training=training, reuse=if_reuse)
+
+ node1 = add([node1_0, node1_1], name='d_n1/add')
+ node1 = tf.nn.leaky_relu(node1)
return node1
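`node1_1` is a bottleneck branch (1x1 down to df_dim*2, a 3x3, then a 3x3 back up to df_dim*8) added onto the trunk `node1_0` before a final leaky ReLU. The same pattern, sketched with stock tf.layers under the assumption that the input already carries `channels` feature maps:

```python
import tensorflow as tf

def residual_head(x, channels, training, name):
    # Trunk + bottleneck branch + add, as in node1_0/node1_1 above.
    with tf.variable_scope(name):
        b = tf.layers.conv2d(x, channels // 4, 1, padding='same', name='conv1')
        b = tf.nn.leaky_relu(tf.layers.batch_normalization(b, training=training, name='bn1'))
        b = tf.layers.conv2d(b, channels // 4, 3, padding='same', name='conv2')
        b = tf.nn.leaky_relu(tf.layers.batch_normalization(b, training=training, name='bn2'))
        b = tf.layers.conv2d(b, channels, 3, padding='same', name='conv3')
        b = tf.layers.batch_normalization(b, training=training, name='bn3')
        return tf.nn.leaky_relu(x + b)  # identity shortcut, then activation
```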
- def get_discriminator(self, x_var, c_var):
- x_code = self.d_image_template.construct(input=x_var) # s16 * s16 * df_dim*8
+ def get_discriminator(self, x_var, c_var, is_training, no_reuse=None):
+ if cfg.GAN.NETWORK_TYPE == "default":
+ x_code = self.d_encode_image(training=is_training, inputs=x_var, if_reuse=no_reuse) # s16 * s16 * df_dim*8
- c_code = self.d_context_template.construct(input=c_var)
- c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
- c_code = tf.tile(c_code, [1, self.s16, self.s16, 1]) # s16 * s16 * ef_dim
+ c_code = self.context_embedding(inputs=c_var, if_reuse=no_reuse)
+ c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
+ c_code = tf.tile(c_code, [1, self.s16, self.s16, 1]) # s16 * s16 * ef_dim
- x_c_code = tf.concat(3, [x_code, c_code])
- return self.d_discriminator_template.construct(input=x_c_code)
+ x_c_code = tf.concat([x_code, c_code], 3)
+ return self.discriminator(training=is_training, inputs=x_c_code, if_reuse=no_reuse)
+ else:
+ raise NotImplementedError
# hr_d_net
- def hr_d_encode_image(self):
- node1_0 = \
- (pt.template("input"). # 4s * 4s * 3
- custom_conv2d(self.df_dim, k_h=4, k_w=4). # 2s * 2s * df_dim
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 2, k_h=4, k_w=4). # s * s * df_dim*2
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 4, k_h=4, k_w=4). # s2 * s2 * df_dim*4
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 8, k_h=4, k_w=4). # s4 * s4 * df_dim*8
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 16, k_h=4, k_w=4). # s8 * s8 * df_dim*16
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 32, k_h=4, k_w=4). # s16 * s16 * df_dim*32
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 16, k_h=1, k_w=1, d_h=1, d_w=1). # s16 * s16 * df_dim*16
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1). # s16 * s16 * df_dim*8
- conv_batch_norm())
- node1_1 = \
- (node1_0.
- custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm().
- apply(leaky_rectify, leakiness=0.2).
- custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
- conv_batch_norm())
-
- node1 = \
- (node1_0.
- apply(tf.add, node1_1).
- apply(leaky_rectify, leakiness=0.2))
+ def hr_d_encode_image(self, inputs=None, training=True, if_reuse=None):
+ # input: 4s * 4s * 3
+ node1_0 = Conv2d(inputs, 4, 4, self.df_dim, 2, 2, activation_fn=tf.nn.leaky_relu,
+ name='hr_d_encode_n1.0/conv2d1', reuse=if_reuse) # 2s * 2s * df_dim
+
+ # s * s * df_dim*2
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 2, 2, 2, name='hr_d_encode_n1.0/conv2d2', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'hr_d_encode_n1.0/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ # s2 * s2 * df_dim*4
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 4, 2, 2, name='hr_d_encode_n1.0/conv2d3', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'hr_d_encode_n1.0/batch_norm2', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ # s4 * s4 * df_dim*8
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 8, 2, 2, name='hr_d_encode_n1.0/conv2d4', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'hr_d_encode_n1.0/batch_norm3', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ # s8 * s8 * df_dim*16
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 16, 2, 2, name='hr_d_encode_n1.0/conv2d5', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'hr_d_encode_n1.0/batch_norm4', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ # s16 * s16 * df_dim*32
+ node1_0 = Conv2d(node1_0, 4, 4, self.df_dim * 32, 2, 2, name='hr_d_encode_n1.0/conv2d6', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'hr_d_encode_n1.0/batch_norm5', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ # s16 * s16 * df_dim*16
+ node1_0 = Conv2d(node1_0, 1, 1, self.df_dim * 16, 1, 1, name='hr_d_encode_n1.0/conv2d7', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'hr_d_encode_n1.0/batch_norm6', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ # s16 * s16 * df_dim*8
+ node1_0 = Conv2d(node1_0, 1, 1, self.df_dim * 8, 1, 1, name='hr_d_encode_n1.0/conv2d8', reuse=if_reuse)
+ node1_0 = conv_batch_normalization(node1_0, 'hr_d_encode_n1.0/batch_norm7', is_training=training,
+ reuse=if_reuse)
+
+ node1_1 = Conv2d(node1_0, 1, 1, self.df_dim * 2, 1, 1, name='hr_d_encode_n1.1/conv2d', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'hr_d_encode_n1.1/batch_norm', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ node1_1 = Conv2d(node1_1, 3, 3, self.df_dim * 2, 1, 1, name='hr_d_encode_n1.1/conv2d2', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'hr_d_encode_n1.1/batch_norm2', is_training=training,
+ activation_fn=tf.nn.leaky_relu, reuse=if_reuse)
+ node1_1 = Conv2d(node1_1, 3, 3, self.df_dim * 8, 1, 1, name='hr_d_encode_n1.1/conv2d3', reuse=if_reuse)
+ node1_1 = conv_batch_normalization(node1_1, 'hr_d_encode_n1.1/batch_norm3', is_training=training,
+ reuse=if_reuse)
+
+ node1 = add([node1_0, node1_1], name='hr_d_encode_n1/add')
+ node1 = tf.nn.leaky_relu(node1)
return node1
- def hr_get_discriminator(self, x_var, c_var):
- x_code = self.hr_d_image_template.construct(input=x_var) # s16 * s16 * df_dim*8
+ def hr_get_discriminator(self, x_var, c_var, is_training, no_reuse=None):
+ if cfg.GAN.NETWORK_TYPE == "default":
+ # s16 * s16 * df_dim*8
+ x_code = self.hr_d_encode_image(training=is_training, inputs=x_var, if_reuse=no_reuse)
- c_code = self.hr_d_context_template.construct(input=c_var)
- c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
- c_code = tf.tile(c_code, [1, self.s16, self.s16, 1]) # s16 * s16 * ef_dim
+ c_code = self.context_embedding(inputs=c_var, if_reuse=no_reuse)
+ c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
+ c_code = tf.tile(c_code, [1, self.s16, self.s16, 1]) # s16 * s16 * ef_dim
- x_c_code = tf.concat(3, [x_code, c_code])
- return self.hr_discriminator_template.construct(input=x_c_code)
+ x_c_code = tf.concat([x_code, c_code], 3)
+ return self.discriminator(training=is_training, inputs=x_c_code, if_reuse=no_reuse)
+ else:
+ raise NotImplementedError
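Two changes recur throughout model.py. First, `tf.concat(axis, values)` from TF 0.12 became `tf.concat(values, axis)` in TF 1.0, hence every rewritten concat call. Second, the text conditioning in both discriminators: the embedding vector is broadcast over the spatial grid before being concatenated channel-wise with the image features. In isolation, with illustrative sizes:

```python
import tensorflow as tf

batch, s16, ef_dim, df8 = 16, 4, 128, 512
c_code = tf.random_normal([batch, ef_dim])
c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)  # [batch, 1, 1, ef_dim]
c_code = tf.tile(c_code, [1, s16, s16, 1])             # [batch, s16, s16, ef_dim]
x_code = tf.random_normal([batch, s16, s16, df8])
x_c_code = tf.concat([x_code, c_code], 3)              # [batch, s16, s16, df8 + ef_dim]
```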
diff --git a/stageII/run_exp.py b/stageII/run_exp.py
index 1ab9a7d..1285112 100644
--- a/stageII/run_exp.py
+++ b/stageII/run_exp.py
@@ -1,28 +1,26 @@
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
-import dateutil
import dateutil.tz
import datetime
import argparse
import pprint
-from misc.datasets import TextDataset
-from stageII.model import CondGAN
-from stageII.trainer import CondGANTrainer
-from misc.utils import mkdir_p
-from misc.config import cfg, cfg_from_file
+import sys
+sys.path.append('misc')
+sys.path.append('stageII')
+
+from datasets import TextDataset
+from utils import mkdir_p
+from config import cfg, cfg_from_file
+from model import CondGAN
+from trainer import CondGANTrainer
def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN network')
- parser.add_argument('--cfg', dest='cfg_file',
- help='optional config file',
- default=None, type=str)
- parser.add_argument('--gpu', dest='gpu_id',
- help='GPU device id to use [0]',
- default=-1, type=int)
+ parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str)
+ parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=-1, type=int)
# if len(sys.argv) == 1:
# parser.print_help()
# sys.exit(1)
@@ -49,23 +47,14 @@ def parse_args():
if cfg.TRAIN.FLAG:
filename_train = '%s/train' % (datadir)
dataset.train = dataset.get_data(filename_train)
- ckt_logs_dir = "ckt_logs/%s/%s_%s" % \
- (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
+ ckt_logs_dir = "ckt_logs/%s/%s_%s" % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
mkdir_p(ckt_logs_dir)
else:
s_tmp = cfg.TRAIN.PRETRAINED_MODEL
ckt_logs_dir = s_tmp[:s_tmp.find('.ckpt')]
+ model = CondGAN(lr_imsize=int(dataset.image_shape[0] / dataset.hr_lr_ratio), hr_lr_ratio=dataset.hr_lr_ratio)
- model = CondGAN(
- lr_imsize=int(dataset.image_shape[0] / dataset.hr_lr_ratio),
- hr_lr_ratio=dataset.hr_lr_ratio
- )
-
- algo = CondGANTrainer(
- model=model,
- dataset=dataset,
- ckt_logs_dir=ckt_logs_dir
- )
+ algo = CondGANTrainer(model=model, dataset=dataset, ckt_logs_dir=ckt_logs_dir)
if cfg.TRAIN.FLAG:
algo.train()
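One caveat with the new flat imports: `sys.path.append('misc')` resolves the relative path against the current working directory, so the script has to be launched from the repository root. A working-directory-independent variant (a suggestion, not part of this patch) would anchor the paths to the script's own location:

```python
import os
import sys

# Resolve the package folders relative to this file instead of the CWD.
_here = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(_here, '..', 'misc'))
sys.path.append(_here)
```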
diff --git a/stageII/trainer.py b/stageII/trainer.py
index 457405c..3899f01 100644
--- a/stageII/trainer.py
+++ b/stageII/trainer.py
@@ -1,19 +1,20 @@
from __future__ import division
from __future__ import print_function
-import prettytensor as pt
import tensorflow as tf
import numpy as np
-import scipy.misc
+import imageio
import os
-import sys
from six.moves import range
from progressbar import ETA, Bar, Percentage, ProgressBar
from PIL import Image, ImageDraw, ImageFont
+import sys
+sys.path.append('misc')
-from misc.config import cfg
-from misc.utils import mkdir_p
+from config import cfg
+from utils import mkdir_p
+from skimage.transform import resize
TINY = 1e-8
@@ -27,12 +28,7 @@ def KL_loss(mu, log_sigma):
class CondGANTrainer(object):
- def __init__(self,
- model,
- dataset=None,
- exp_name="model",
- ckt_logs_dir="ckt_logs",
- ):
+ def __init__(self, model, dataset=None, exp_name="model", ckt_logs_dir="ckt_logs"):
"""
:type model: RegularizedGAN
"""
@@ -51,39 +47,24 @@ def __init__(self,
self.hr_image_shape = self.dataset.image_shape
ratio = self.dataset.hr_lr_ratio
- self.lr_image_shape = [int(self.hr_image_shape[0] / ratio),
- int(self.hr_image_shape[1] / ratio),
+ self.lr_image_shape = [int(self.hr_image_shape[0] / ratio), int(self.hr_image_shape[1] / ratio),
self.hr_image_shape[2]]
print('hr_image_shape', self.hr_image_shape)
print('lr_image_shape', self.lr_image_shape)
def build_placeholder(self):
'''Helper function for init_opt'''
- self.hr_images = tf.placeholder(
- tf.float32, [self.batch_size] + self.hr_image_shape,
- name='real_hr_images')
- self.hr_wrong_images = tf.placeholder(
- tf.float32, [self.batch_size] + self.hr_image_shape,
- name='wrong_hr_images'
- )
- self.embeddings = tf.placeholder(
- tf.float32, [self.batch_size] + self.dataset.embedding_shape,
- name='conditional_embeddings'
- )
-
- self.generator_lr = tf.placeholder(
- tf.float32, [],
- name='generator_learning_rate'
- )
- self.discriminator_lr = tf.placeholder(
- tf.float32, [],
- name='discriminator_learning_rate'
- )
+ self.hr_images = tf.placeholder(tf.float32, [self.batch_size] + self.hr_image_shape, name='real_hr_images')
+ self.hr_wrong_images = tf.placeholder(tf.float32, [self.batch_size] + self.hr_image_shape,
+ name='wrong_hr_images')
+ self.embeddings = tf.placeholder(tf.float32, [self.batch_size] + self.dataset.embedding_shape,
+ name='conditional_embeddings')
+
+ self.generator_lr = tf.placeholder(tf.float32, [], name='generator_learning_rate')
+ self.discriminator_lr = tf.placeholder(tf.float32, [], name='discriminator_learning_rate')
#
- self.images = tf.image.resize_bilinear(self.hr_images,
- self.lr_image_shape[:2])
- self.wrong_images = tf.image.resize_bilinear(self.hr_wrong_images,
- self.lr_image_shape[:2])
+ self.images = tf.image.resize_bilinear(self.hr_images, self.lr_image_shape[:2])
+ self.wrong_images = tf.image.resize_bilinear(self.hr_wrong_images, self.lr_image_shape[:2])
def sample_encoded_context(self, embeddings):
'''Helper function for init_opt'''
@@ -107,95 +88,73 @@ def sample_encoded_context(self, embeddings):
def init_opt(self):
self.build_placeholder()
- with pt.defaults_scope(phase=pt.Phase.train):
- # ####get output from G network####################################
- with tf.variable_scope("g_net"):
- c, kl_loss = self.sample_encoded_context(self.embeddings)
- z = tf.random_normal([self.batch_size, cfg.Z_DIM])
- self.log_vars.append(("hist_c", c))
- self.log_vars.append(("hist_z", z))
- fake_images = self.model.get_generator(tf.concat(1, [c, z]))
-
+ # ####get output from G network####################################
+ with tf.variable_scope("g_net"): # For training
+ c, kl_loss = self.sample_encoded_context(self.embeddings)
+ z = tf.random_normal([self.batch_size, cfg.Z_DIM])
+ self.log_vars.append(("hist_c", c))
+ self.log_vars.append(("hist_z", z))
+ fake_images = self.model.get_generator(tf.concat([c, z], 1), True)
+ with tf.variable_scope("d_net"): # For training
# ####get discriminator_loss and generator_loss ###################
- discriminator_loss, generator_loss =\
- self.compute_losses(self.images,
- self.wrong_images,
- fake_images,
- self.embeddings,
- flag='lr')
+ discriminator_loss, generator_loss = self.compute_losses(self.images, self.wrong_images, fake_images,
+ self.embeddings, flag='lr')
generator_loss += kl_loss
self.log_vars.append(("g_loss_kl_loss", kl_loss))
self.log_vars.append(("g_loss", generator_loss))
self.log_vars.append(("d_loss", discriminator_loss))
- # #### For hr_g and hr_d #########################################
- with tf.variable_scope("hr_g_net"):
- hr_c, hr_kl_loss = self.sample_encoded_context(self.embeddings)
- self.log_vars.append(("hist_hr_c", hr_c))
- hr_fake_images = self.model.hr_get_generator(fake_images, hr_c)
+ # #### For hr_g and hr_d #########################################
+ with tf.variable_scope("hr_g_net"): # For training
+ hr_c, hr_kl_loss = self.sample_encoded_context(self.embeddings)
+ self.log_vars.append(("hist_hr_c", hr_c))
+ hr_fake_images = self.model.hr_get_generator(fake_images, hr_c, True)
+
+ with tf.variable_scope("hr_d_net"): # For training
# get losses
- hr_discriminator_loss, hr_generator_loss =\
- self.compute_losses(self.hr_images,
- self.hr_wrong_images,
- hr_fake_images,
- self.embeddings,
- flag='hr')
+ hr_discriminator_loss, hr_generator_loss = self.compute_losses(self.hr_images, self.hr_wrong_images,
+ hr_fake_images, self.embeddings, flag='hr')
hr_generator_loss += hr_kl_loss
self.log_vars.append(("hr_g_loss", hr_generator_loss))
self.log_vars.append(("hr_d_loss", hr_discriminator_loss))
# #######define self.g_sum, self.d_sum,....########################
- self.prepare_trainer(discriminator_loss, generator_loss,
- hr_discriminator_loss, hr_generator_loss)
+ self.prepare_trainer(discriminator_loss, generator_loss, hr_discriminator_loss, hr_generator_loss)
self.define_summaries()
- with pt.defaults_scope(phase=pt.Phase.test):
- self.sampler()
- self.visualization(cfg.TRAIN.NUM_COPY)
- print("success")
+ self.sampler()
+ self.visualization(cfg.TRAIN.NUM_COPY)
+ print("success")
def sampler(self):
- with tf.variable_scope("g_net", reuse=True):
+ with tf.variable_scope("g_net", reuse=True): # For testing
c, _ = self.sample_encoded_context(self.embeddings)
z = tf.random_normal([self.batch_size, cfg.Z_DIM])
- self.fake_images = self.model.get_generator(tf.concat(1, [c, z]))
- with tf.variable_scope("hr_g_net", reuse=True):
+ self.fake_images = self.model.get_generator(tf.concat([c, z], 1), False)
+ with tf.variable_scope("hr_g_net", reuse=True): # For testing
hr_c, _ = self.sample_encoded_context(self.embeddings)
- self.hr_fake_images =\
- self.model.hr_get_generator(self.fake_images, hr_c)
+ self.hr_fake_images = self.model.hr_get_generator(self.fake_images, hr_c, False)
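`init_opt` builds the training graph inside the `g_net`/`hr_g_net` scopes, and `sampler` re-enters the same scopes with `reuse=True` while passing `is_training=False`, so the test-time generators share every weight but run batch normalization on the accumulated moving statistics. The pattern in miniature:

```python
import tensorflow as tf

def g(z, training):
    h = tf.layers.dense(z, 32, name='fc')
    h = tf.layers.batch_normalization(h, training=training, name='bn')
    return tf.nn.relu(h)

z = tf.random_normal([4, 10])
with tf.variable_scope('g_net'):              # training graph: creates the weights
    train_out = g(z, training=True)
with tf.variable_scope('g_net', reuse=True):  # test graph: same weights,
    test_out = g(z, training=False)           # BN uses its moving averages
```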
- def compute_losses(self, images, wrong_images,
- fake_images, embeddings, flag='lr'):
+ def compute_losses(self, images, wrong_images, fake_images, embeddings, flag='lr'):
if flag == 'lr':
- real_logit =\
- self.model.get_discriminator(images, embeddings)
- wrong_logit =\
- self.model.get_discriminator(wrong_images, embeddings)
- fake_logit =\
- self.model.get_discriminator(fake_images, embeddings)
+ real_logit = self.model.get_discriminator(images, embeddings, True)
+            # Reuse the same discriminator weights for the wrong and fake pairs
+ wrong_logit = self.model.get_discriminator(wrong_images, embeddings, True, no_reuse=tf.AUTO_REUSE)
+ fake_logit = self.model.get_discriminator(fake_images, embeddings, True, no_reuse=tf.AUTO_REUSE)
else:
- real_logit =\
- self.model.hr_get_discriminator(images, embeddings)
- wrong_logit =\
- self.model.hr_get_discriminator(wrong_images, embeddings)
- fake_logit =\
- self.model.hr_get_discriminator(fake_images, embeddings)
-
- real_d_loss =\
- tf.nn.sigmoid_cross_entropy_with_logits(real_logit,
- tf.ones_like(real_logit))
+ real_logit = self.model.hr_get_discriminator(images, embeddings, True)
+            # Reuse the same discriminator weights for the wrong and fake pairs
+ wrong_logit = self.model.hr_get_discriminator(wrong_images, embeddings, True, no_reuse=tf.AUTO_REUSE)
+ fake_logit = self.model.hr_get_discriminator(fake_images, embeddings, True, no_reuse=tf.AUTO_REUSE)
+
+ real_d_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logit, labels=tf.ones_like(real_logit))
real_d_loss = tf.reduce_mean(real_d_loss)
- wrong_d_loss =\
- tf.nn.sigmoid_cross_entropy_with_logits(wrong_logit,
- tf.zeros_like(wrong_logit))
+ wrong_d_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=wrong_logit, labels=tf.zeros_like(wrong_logit))
wrong_d_loss = tf.reduce_mean(wrong_d_loss)
- fake_d_loss =\
- tf.nn.sigmoid_cross_entropy_with_logits(fake_logit,
- tf.zeros_like(fake_logit))
+ fake_d_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit, labels=tf.zeros_like(fake_logit))
fake_d_loss = tf.reduce_mean(fake_d_loss)
if cfg.TRAIN.B_WRONG:
- discriminator_loss =\
- real_d_loss + (wrong_d_loss + fake_d_loss) / 2.
+ discriminator_loss = real_d_loss + (wrong_d_loss + fake_d_loss) / 2.
else:
discriminator_loss = real_d_loss + fake_d_loss
if flag == 'lr':
@@ -209,9 +168,7 @@ def compute_losses(self, images, wrong_images,
if cfg.TRAIN.B_WRONG:
self.log_vars.append(("hr_d_loss_wrong", wrong_d_loss))
- generator_loss = \
- tf.nn.sigmoid_cross_entropy_with_logits(fake_logit,
- tf.ones_like(fake_logit))
+ generator_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logit, labels=tf.ones_like(fake_logit))
generator_loss = tf.reduce_mean(generator_loss)
if flag == 'lr':
self.log_vars.append(("g_loss_fake", generator_loss))
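The TF 1.x signature of `sigmoid_cross_entropy_with_logits` requires the `logits=`/`labels=` keywords, hence the rewrites above. Numerically nothing changes: with all-ones labels the loss is softplus(-logit), with all-zeros labels softplus(logit), the usual cross-entropy GAN objective. A quick NumPy check:

```python
import numpy as np

def softplus(v):
    return np.log1p(np.exp(v))

logit = 0.7
sigmoid = 1.0 / (1.0 + np.exp(-logit))
assert np.isclose(-np.log(sigmoid), softplus(-logit))       # labels = 1
assert np.isclose(-np.log(1.0 - sigmoid), softplus(logit))  # labels = 0
```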
@@ -223,37 +180,25 @@ def compute_losses(self, images, wrong_images,
def define_one_trainer(self, loss, learning_rate, key_word):
'''Helper function for init_opt'''
all_vars = tf.trainable_variables()
- tarin_vars = [var for var in all_vars if
- var.name.startswith(key_word)]
+        train_vars = [var for var in all_vars if var.name.startswith(key_word)]
+
+        # Collect the update ops (batch-norm moving mean/variance) that belong to this sub-network
+        update_ops_vars = [var for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if var.name.startswith(key_word)]
+        # Run those statistics updates together with every optimizer step
+        with tf.control_dependencies(update_ops_vars):
+            opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
+            trainer = opt.minimize(loss, var_list=train_vars)
- opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
- trainer = pt.apply_optimizer(opt, losses=[loss], var_list=tarin_vars)
return trainer
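The control-dependency wrapper covers a standard TF 1.x requirement: batch normalization's moving mean/variance updates live in the `UPDATE_OPS` collection and never run unless the train op depends on them, and filtering by `key_word` keeps each optimizer from triggering another sub-network's statistics updates. The general pattern:

```python
import tensorflow as tf

x = tf.random_normal([8, 16])
y = tf.layers.batch_normalization(x, training=True, name='g_bn')
loss = tf.reduce_mean(tf.square(y))

# Without the dependency the moving averages would stay at their initial
# values and inference-mode batch norm would be wrong.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
```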
- def prepare_trainer(self, discriminator_loss, generator_loss,
- hr_discriminator_loss, hr_generator_loss):
+ def prepare_trainer(self, discriminator_loss, generator_loss, hr_discriminator_loss, hr_generator_loss):
ft_lr_retio = cfg.TRAIN.FT_LR_RETIO
- self.discriminator_trainer =\
- self.define_one_trainer(discriminator_loss,
- self.discriminator_lr * ft_lr_retio,
- 'd_')
- self.generator_trainer =\
- self.define_one_trainer(generator_loss,
- self.generator_lr * ft_lr_retio,
- 'g_')
- self.hr_discriminator_trainer =\
- self.define_one_trainer(hr_discriminator_loss,
- self.discriminator_lr,
- 'hr_d_')
- self.hr_generator_trainer =\
- self.define_one_trainer(hr_generator_loss,
- self.generator_lr,
- 'hr_g_')
-
- self.ft_generator_trainer = \
- self.define_one_trainer(hr_generator_loss,
- self.generator_lr * cfg.TRAIN.FT_LR_RETIO,
- 'g_')
+        self.discriminator_trainer = self.define_one_trainer(discriminator_loss,
+                                                             self.discriminator_lr * ft_lr_retio, 'd_')
+        self.generator_trainer = self.define_one_trainer(generator_loss, self.generator_lr * ft_lr_retio, 'g_')
+        self.hr_discriminator_trainer = self.define_one_trainer(hr_discriminator_loss, self.discriminator_lr, 'hr_d_')
+        self.hr_generator_trainer = self.define_one_trainer(hr_generator_loss, self.generator_lr, 'hr_g_')
+
+        self.ft_generator_trainer = self.define_one_trainer(hr_generator_loss,
+                                                            self.generator_lr * cfg.TRAIN.FT_LR_RETIO, 'g_')
self.log_vars.append(("hr_d_learning_rate", self.discriminator_lr))
self.log_vars.append(("hr_g_learning_rate", self.generator_lr))
@@ -263,21 +208,21 @@ def define_summaries(self):
all_sum = {'g': [], 'd': [], 'hr_g': [], 'hr_d': [], 'hist': []}
for k, v in self.log_vars:
if k.startswith('g'):
- all_sum['g'].append(tf.scalar_summary(k, v))
+ all_sum['g'].append(tf.summary.scalar(k, v))
elif k.startswith('d'):
- all_sum['d'].append(tf.scalar_summary(k, v))
+ all_sum['d'].append(tf.summary.scalar(k, v))
elif k.startswith('hr_g'):
- all_sum['hr_g'].append(tf.scalar_summary(k, v))
+ all_sum['hr_g'].append(tf.summary.scalar(k, v))
elif k.startswith('hr_d'):
- all_sum['hr_d'].append(tf.scalar_summary(k, v))
+ all_sum['hr_d'].append(tf.summary.scalar(k, v))
elif k.startswith('hist'):
- all_sum['hist'].append(tf.histogram_summary(k, v))
+ all_sum['hist'].append(tf.summary.histogram(k, v))
- self.g_sum = tf.merge_summary(all_sum['g'])
- self.d_sum = tf.merge_summary(all_sum['d'])
- self.hr_g_sum = tf.merge_summary(all_sum['hr_g'])
- self.hr_d_sum = tf.merge_summary(all_sum['hr_d'])
- self.hist_sum = tf.merge_summary(all_sum['hist'])
+ self.g_sum = tf.summary.merge(all_sum['g'])
+ self.d_sum = tf.summary.merge(all_sum['d'])
+ self.hr_g_sum = tf.summary.merge(all_sum['hr_g'])
+ self.hr_d_sum = tf.summary.merge(all_sum['hr_d'])
+ self.hist_sum = tf.summary.merge(all_sum['hist'])
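For reference, the one-for-one TF 1.x renames applied in this method (and to the writer in `train` below):

```python
# tf.scalar_summary       -> tf.summary.scalar
# tf.histogram_summary    -> tf.summary.histogram
# tf.image_summary        -> tf.summary.image
# tf.merge_summary        -> tf.summary.merge
# tf.merge_all_summaries  -> tf.summary.merge_all
# tf.train.SummaryWriter  -> tf.summary.FileWriter
```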
def visualize_one_superimage(self, img_var, images, rows, filename):
stacked_img = []
@@ -287,35 +232,27 @@ def visualize_one_superimage(self, img_var, images, rows, filename):
for col in range(rows):
row_img.append(img_var[row * rows + col, :, :, :])
# each rows is 1realimage +10_fakeimage
- stacked_img.append(tf.concat(1, row_img))
- imgs = tf.expand_dims(tf.concat(0, stacked_img), 0)
- current_img_summary = tf.image_summary(filename, imgs)
+ stacked_img.append(tf.concat(row_img, 1))
+ imgs = tf.expand_dims(tf.concat(stacked_img, 0), 0)
+ current_img_summary = tf.summary.image(filename, imgs)
return current_img_summary, imgs
def visualization(self, n):
- fake_sum_train, superimage_train =\
- self.visualize_one_superimage(self.fake_images[:n * n],
- self.images[:n * n],
- n, "train")
- fake_sum_test, superimage_test =\
- self.visualize_one_superimage(self.fake_images[n * n:2 * n * n],
- self.images[n * n:2 * n * n],
- n, "test")
- self.superimages = tf.concat(0, [superimage_train, superimage_test])
- self.image_summary = tf.merge_summary([fake_sum_train, fake_sum_test])
-
- hr_fake_sum_train, hr_superimage_train =\
- self.visualize_one_superimage(self.hr_fake_images[:n * n],
- self.hr_images[:n * n, :, :, :],
- n, "hr_train")
- hr_fake_sum_test, hr_superimage_test =\
- self.visualize_one_superimage(self.hr_fake_images[n * n:2 * n * n],
- self.hr_images[n * n:2 * n * n],
- n, "hr_test")
- self.hr_superimages =\
- tf.concat(0, [hr_superimage_train, hr_superimage_test])
- self.hr_image_summary =\
- tf.merge_summary([hr_fake_sum_train, hr_fake_sum_test])
+ fake_sum_train, superimage_train = self.visualize_one_superimage(self.fake_images[:n * n], self.images[:n * n],
+ n, "train")
+ fake_sum_test, superimage_test = self.visualize_one_superimage(self.fake_images[n * n:2 * n * n],
+ self.images[n * n:2 * n * n], n, "test")
+ self.superimages = tf.concat([superimage_train, superimage_test], 0)
+ self.image_summary = tf.summary.merge([fake_sum_train, fake_sum_test])
+
+ hr_fake_sum_train, hr_superimage_train = self.visualize_one_superimage(self.hr_fake_images[:n * n],
+ self.hr_images[:n * n, :, :, :], n,
+ "hr_train")
+ hr_fake_sum_test, hr_superimage_test = self.visualize_one_superimage(self.hr_fake_images[n * n:2 * n * n],
+ self.hr_images[n * n:2 * n * n], n,
+ "hr_test")
+ self.hr_superimages = tf.concat([hr_superimage_train, hr_superimage_test], 0)
+ self.hr_image_summary = tf.summary.merge([hr_fake_sum_train, hr_fake_sum_test])
def preprocess(self, x, n):
# make sure every row with n column have the same embeddings
@@ -325,43 +262,33 @@ def preprocess(self, x, n):
return x
def epoch_sum_images(self, sess, n):
- images_train, _, embeddings_train, captions_train, _ =\
- self.dataset.train.next_batch(n * n, cfg.TRAIN.NUM_EMBEDDING)
+ images_train, _, embeddings_train, captions_train, _ = self.dataset.train.next_batch(n * n,
+ cfg.TRAIN.NUM_EMBEDDING)
images_train = self.preprocess(images_train, n)
embeddings_train = self.preprocess(embeddings_train, n)
- images_test, _, embeddings_test, captions_test, _ =\
- self.dataset.test.next_batch(n * n, 1)
+ images_test, _, embeddings_test, captions_test, _ = self.dataset.test.next_batch(n * n, 1)
images_test = self.preprocess(images_test, n)
embeddings_test = self.preprocess(embeddings_test, n)
images = np.concatenate([images_train, images_test], axis=0)
- embeddings =\
- np.concatenate([embeddings_train, embeddings_test], axis=0)
+ embeddings = np.concatenate([embeddings_train, embeddings_test], axis=0)
if self.batch_size > 2 * n * n:
- images_pad, _, embeddings_pad, _, _ =\
- self.dataset.test.next_batch(self.batch_size - 2 * n * n, 1)
+ images_pad, _, embeddings_pad, _, _ = self.dataset.test.next_batch(self.batch_size - 2 * n * n, 1)
images = np.concatenate([images, images_pad], axis=0)
embeddings = np.concatenate([embeddings, embeddings_pad], axis=0)
- feed_out = [self.superimages, self.image_summary,
- self.hr_superimages, self.hr_image_summary]
- feed_dict = {self.hr_images: images,
- self.embeddings: embeddings}
- gen_samples, img_summary, hr_gen_samples, hr_img_summary =\
- sess.run(feed_out, feed_dict)
+ feed_out = [self.superimages, self.image_summary, self.hr_superimages, self.hr_image_summary]
+ feed_dict = {self.hr_images: images, self.embeddings: embeddings}
+ gen_samples, img_summary, hr_gen_samples, hr_img_summary = sess.run(feed_out, feed_dict)
# save images generated for train and test captions
- scipy.misc.imsave('%s/lr_fake_train.jpg' %
- (self.log_dir), gen_samples[0])
- scipy.misc.imsave('%s/lr_fake_test.jpg' %
- (self.log_dir), gen_samples[1])
+ imageio.imwrite('%s/lr_fake_train.jpg' % (self.log_dir), gen_samples[0])
+ imageio.imwrite('%s/lr_fake_test.jpg' % (self.log_dir), gen_samples[1])
#
- scipy.misc.imsave('%s/hr_fake_train.jpg' %
- (self.log_dir), hr_gen_samples[0])
- scipy.misc.imsave('%s/hr_fake_test.jpg' %
- (self.log_dir), hr_gen_samples[1])
+ imageio.imwrite('%s/hr_fake_train.jpg' % (self.log_dir), hr_gen_samples[0])
+ imageio.imwrite('%s/hr_fake_test.jpg' % (self.log_dir), hr_gen_samples[1])
# pfi_train = open(self.log_dir + "/train.txt", "w")
pfi_test = open(self.log_dir + "/test.txt", "w")
@@ -378,17 +305,11 @@ def epoch_sum_images(self, sess, n):
def build_model(self, sess):
self.init_opt()
-
- sess.run(tf.initialize_all_variables())
+ sess.run(tf.global_variables_initializer())
if len(self.model_path) > 0:
print("Reading model parameters from %s" % self.model_path)
all_vars = tf.trainable_variables()
- # all_vars = tf.all_variables()
- restore_vars = []
- for var in all_vars:
- if var.name.startswith('g_') or var.name.startswith('d_'):
- restore_vars.append(var)
- # print(var.name)
+ restore_vars = [var for var in all_vars if var.name.startswith('g_') or var.name.startswith('d_')]
saver = tf.train.Saver(restore_vars)
saver.restore(sess, self.model_path)
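Only the stage-I variables (`g_`/`d_` prefixes) are restored from the pretrained checkpoint; the `hr_g_`/`hr_d_` variables keep the fresh values from `global_variables_initializer`. The same filter in isolation, using the tuple form of `str.startswith` (`model_path` stands in for an existing checkpoint prefix):

```python
import tensorflow as tf

restore_vars = [v for v in tf.trainable_variables()
                if v.name.startswith(('g_', 'd_'))]
saver = tf.train.Saver(restore_vars)
# saver.restore(sess, model_path)
```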
@@ -401,56 +322,46 @@ def build_model(self, sess):
counter = 0
return counter
- def train_one_step(self, generator_lr,
- discriminator_lr,
- counter, summary_writer, log_vars, sess):
+ def train_one_step(self, generator_lr, discriminator_lr, counter, summary_writer, log_vars, sess):
# training d
- hr_images, hr_wrong_images, embeddings, _, _ =\
- self.dataset.train.next_batch(self.batch_size,
- cfg.TRAIN.NUM_EMBEDDING)
+ hr_images, hr_wrong_images, embeddings, _, _ = self.dataset.train.next_batch(self.batch_size,
+ cfg.TRAIN.NUM_EMBEDDING)
feed_dict = {self.hr_images: hr_images,
self.hr_wrong_images: hr_wrong_images,
self.embeddings: embeddings,
self.generator_lr: generator_lr,
- self.discriminator_lr: discriminator_lr
- }
+ self.discriminator_lr: discriminator_lr}
if cfg.TRAIN.FINETUNE_LR:
# train d1
- feed_out_d = [self.hr_discriminator_trainer,
- self.hr_d_sum,
- log_vars,
- self.hist_sum]
+ feed_out_d = [self.hr_discriminator_trainer, self.hr_d_sum, log_vars, self.hist_sum]
ret_list = sess.run(feed_out_d, feed_dict)
summary_writer.add_summary(ret_list[1], counter)
log_vals = ret_list[2]
summary_writer.add_summary(ret_list[3], counter)
+
# train g1 and finetune g0 with the loss of g1
- feed_out_g = [self.hr_generator_trainer,
- self.ft_generator_trainer,
- self.hr_g_sum]
+ feed_out_g = [self.hr_generator_trainer, self.ft_generator_trainer, self.hr_g_sum]
_, _, hr_g_sum = sess.run(feed_out_g, feed_dict)
summary_writer.add_summary(hr_g_sum, counter)
+
# finetune d0 with the loss of d0
feed_out_d = [self.discriminator_trainer, self.d_sum]
_, d_sum = sess.run(feed_out_d, feed_dict)
summary_writer.add_summary(d_sum, counter)
+
# finetune g0 with the loss of g0
feed_out_g = [self.generator_trainer, self.g_sum]
_, g_sum = sess.run(feed_out_g, feed_dict)
summary_writer.add_summary(g_sum, counter)
else:
# train d1
- feed_out_d = [self.hr_discriminator_trainer,
- self.hr_d_sum,
- log_vars,
- self.hist_sum]
+ feed_out_d = [self.hr_discriminator_trainer, self.hr_d_sum, log_vars, self.hist_sum]
ret_list = sess.run(feed_out_d, feed_dict)
summary_writer.add_summary(ret_list[1], counter)
log_vals = ret_list[2]
summary_writer.add_summary(ret_list[3], counter)
# train g1
- feed_out_g = [self.hr_generator_trainer,
- self.hr_g_sum]
+ feed_out_g = [self.hr_generator_trainer, self.hr_g_sum]
_, hr_g_sum = sess.run(feed_out_g, feed_dict)
summary_writer.add_summary(hr_g_sum, counter)
@@ -461,12 +372,10 @@ def train(self):
with tf.Session(config=config) as sess:
with tf.device("/gpu:%d" % cfg.GPU_ID):
counter = self.build_model(sess)
- saver = tf.train.Saver(tf.all_variables(),
- keep_checkpoint_every_n_hours=5)
+ saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=5)
# summary_op = tf.merge_all_summaries()
- summary_writer = tf.train.SummaryWriter(self.log_dir,
- sess.graph)
+ summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
if cfg.TRAIN.FINETUNE_LR:
keys = ["hr_d_loss", "hr_g_loss", "d_loss", "g_loss"]
@@ -487,10 +396,8 @@ def train(self):
decay_start = cfg.TRAIN.PRETRAINED_EPOCH
epoch_start = int(counter / updates_per_epoch)
for epoch in range(epoch_start, self.max_epoch):
- widgets = ["epoch #%d|" % epoch,
- Percentage(), Bar(), ETA()]
- pbar = ProgressBar(maxval=updates_per_epoch,
- widgets=widgets)
+ widgets = ["epoch #%d|" % epoch, Percentage(), Bar(), ETA()]
+ pbar = ProgressBar(maxval=updates_per_epoch, widgets=widgets)
pbar.start()
if epoch % lr_decay_step == 0 and epoch > decay_start:
@@ -500,23 +407,17 @@ def train(self):
all_log_vals = []
for i in range(updates_per_epoch):
pbar.update(i)
- log_vals = self.train_one_step(generator_lr,
- discriminator_lr,
- counter, summary_writer,
+ log_vals = self.train_one_step(generator_lr, discriminator_lr, counter, summary_writer,
log_vars, sess)
all_log_vals.append(log_vals)
# save checkpoint
counter += 1
if counter % self.snapshot_interval == 0:
- snapshot_path = "%s/%s_%s.ckpt" %\
- (self.checkpoint_dir,
- self.exp_name,
- str(counter))
+ snapshot_path = "%s/%s_%s.ckpt" % (self.checkpoint_dir, self.exp_name, str(counter))
fn = saver.save(sess, snapshot_path)
print("Model saved in file: %s" % fn)
- img_summary, img_summary2 =\
- self.epoch_sum_images(sess, cfg.TRAIN.NUM_COPY)
+ img_summary, img_summary2 = self.epoch_sum_images(sess, cfg.TRAIN.NUM_COPY)
summary_writer.add_summary(img_summary, counter)
summary_writer.add_summary(img_summary2, counter)
@@ -526,9 +427,7 @@ def train(self):
dic_logs[k] = v
# print(k, v)
- log_line = "; ".join("%s: %s" %
- (str(k), str(dic_logs[k]))
- for k in dic_logs)
+ log_line = "; ".join("%s: %s" % (str(k), str(dic_logs[k])) for k in dic_logs)
print("Epoch %d | " % (epoch) + log_line)
sys.stdout.flush()
if np.any(np.isnan(avg_log_vals)):
@@ -559,15 +458,13 @@ def drawCaption(self, img, caption):
return img_txt
- def save_super_images(self, images, sample_batchs, hr_sample_batchs,
- savenames, captions_batchs,
- sentenceID, save_dir, subset):
+ def save_super_images(self, images, sample_batchs, hr_sample_batchs, savenames, captions_batchs, sentenceID,
+ save_dir, subset):
# batch_size samples for each embedding
# Up to 16 samples for each text embedding/sentence
numSamples = len(sample_batchs)
for j in range(len(savenames)):
- s_tmp = '%s-1real-%dsamples/%s/%s' %\
- (save_dir, numSamples, subset, savenames[j])
+ s_tmp = '%s-1real-%dsamples/%s/%s' % (save_dir, numSamples, subset, savenames[j])
folder = s_tmp[:s_tmp.rfind('/')]
if not os.path.isdir(folder):
print('Make a new folder: ', folder)
@@ -583,9 +480,10 @@ def save_super_images(self, images, sample_batchs, hr_sample_batchs,
row2 = [padding0, real_img, padding]
for i in range(np.minimum(8, numSamples)):
lr_img = sample_batchs[i][j]
+ lr_img = (lr_img + 1.0) * 127.5
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
- re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
+ re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
@@ -598,38 +496,34 @@ def save_super_images(self, images, sample_batchs, hr_sample_batchs,
row2 = [padding0, real_img, padding]
for i in range(8, len(sample_batchs)):
lr_img = sample_batchs[i][j]
+ lr_img = (lr_img + 1.0) * 127.5
hr_img = hr_sample_batchs[i][j]
hr_img = (hr_img + 1.0) * 127.5
- re_sample = scipy.misc.imresize(lr_img, hr_img.shape[:2])
+ re_sample = resize(lr_img, hr_img.shape[:2])
row1.append(re_sample)
row2.append(hr_img)
row1 = np.concatenate(row1, axis=1)
row2 = np.concatenate(row2, axis=1)
super_row = np.concatenate([row1, row2], axis=0)
superimage2 = np.zeros_like(superimage)
- superimage2[:super_row.shape[0],
- :super_row.shape[1],
- :super_row.shape[2]] = super_row
+ superimage2[:super_row.shape[0], :super_row.shape[1], :super_row.shape[2]] = super_row
mid_padding = np.zeros((64, superimage.shape[1], 3))
- superimage = np.concatenate([superimage, mid_padding,
- superimage2], axis=0)
+ superimage = np.concatenate([superimage, mid_padding, superimage2], axis=0)
top_padding = np.zeros((128, superimage.shape[1], 3))
- superimage =\
- np.concatenate([top_padding, superimage], axis=0)
+ superimage = np.concatenate([top_padding, superimage], axis=0)
captions = captions_batchs[j][sentenceID]
fullpath = '%s_sentence%d.jpg' % (s_tmp, sentenceID)
superimage = self.drawCaption(np.uint8(superimage), captions)
- scipy.misc.imsave(fullpath, superimage)
+ imageio.imwrite(fullpath, superimage)
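`scipy.misc.imresize` (since removed from SciPy) returned `uint8` and rescaled internally; `skimage.transform.resize` returns floats and, for float input, leaves the value range to the caller, which is why the tanh-scaled sample is now mapped from [-1, 1] to [0, 255] before resizing. Passing `preserve_range=True` makes that intent explicit and guards against version-dependent rescaling:

```python
import numpy as np
from skimage.transform import resize

lr_img = np.random.uniform(-1.0, 1.0, (64, 64, 3))  # stand-in for a generator sample
lr_img = (lr_img + 1.0) * 127.5                     # [-1, 1] -> [0, 255]
re_sample = resize(lr_img, (256, 256), preserve_range=True)
re_sample = np.uint8(re_sample)
```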
def eval_one_dataset(self, sess, dataset, save_dir, subset='train'):
count = 0
print('num_examples:', dataset._num_examples)
while count < dataset._num_examples:
start = count % dataset._num_examples
- images, embeddings_batchs, savenames, captions_batchs =\
- dataset.next_batch_test(self.batch_size, start, 1)
+ images, embeddings_batchs, savenames, captions_batchs = dataset.next_batch_test(self.batch_size, start, 1)
print('count = ', count, 'start = ', start)
# the i-th sentence/caption
@@ -640,15 +534,12 @@ def eval_one_dataset(self, sess, dataset, save_dir, subset='train'):
# with randomness from noise z and conditioning augmentation.
numSamples = np.minimum(16, cfg.TRAIN.NUM_COPY)
for j in range(numSamples):
- hr_samples, samples =\
- sess.run([self.hr_fake_images, self.fake_images],
- {self.embeddings: embeddings_batchs[i]})
+ hr_samples, samples = sess.run([self.hr_fake_images, self.fake_images],
+ {self.embeddings: embeddings_batchs[i]})
samples_batchs.append(samples)
hr_samples_batchs.append(hr_samples)
- self.save_super_images(images, samples_batchs,
- hr_samples_batchs,
- savenames, captions_batchs,
- i, save_dir, subset)
+ self.save_super_images(images, samples_batchs, hr_samples_batchs, savenames, captions_batchs, i,
+ save_dir, subset)
count += self.batch_size
@@ -659,11 +550,10 @@ def evaluate(self):
if self.model_path.find('.ckpt') != -1:
self.init_opt()
print("Reading model parameters from %s" % self.model_path)
- saver = tf.train.Saver(tf.all_variables())
+ saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, self.model_path)
# self.eval_one_dataset(sess, self.dataset.train,
# self.log_dir, subset='train')
- self.eval_one_dataset(sess, self.dataset.test,
- self.log_dir, subset='test')
+ self.eval_one_dataset(sess, self.dataset.test, self.log_dir, subset='test')
else:
print("Input a valid model path.")