diff --git a/README.md b/README.md index 979f0ca1..ad3a7ca4 100644 --- a/README.md +++ b/README.md @@ -11,15 +11,15 @@ Pytorch implementation for reproducing StackGAN_v2 results in the paper [StackGA - +Note: Code has been updated for Python3 usage. Thank you [David Stap](https://github.com/davidstap/AttnGAN) for your help upgrading the original StackGAN-v2 file. Also, sometimes during training my computer randomly shut down. I think was because the GPU was pulling in too much power, but be aware of this. ### Dependencies -python 2.7 +python 3.6+ -Pytorch +Pytorch 1.1.0+ In addition, please add the project folder to PYTHONPATH and `pip install` the following packages: -- `tensorboard` +- `tensorboardX` - `python-dateutil` - `easydict` - `pandas` @@ -56,11 +56,11 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f **Pretrained Model** -- [StackGAN-v2 for bird](https://drive.google.com/open?id=1s5Yf3nFiXx0lltMFOiJWB6s1LP24RcwH). Download and save it to `models/` (The [inception score](https://github.com/hanzhanggit/StackGAN-inception-model) for this Model is 4.04±0.05) -- [StackGAN-v2 for dog](https://drive.google.com/open?id=1zcwYfvhsKqb8svQDecTbx_mdYy3TG3F0). Download and save it to `models/` (The [inception score](https://github.com/openai/improved-gan/tree/master/inception_score) for this Model is 9.55±0.11) -- [StackGAN-v2 for cat](https://drive.google.com/open?id=1yPX62c-eCLCNxpziGX9qF_V6Verom3v9). Download and save it to `models/` -- [StackGAN-v2 for bedroom](https://drive.google.com/open?id=1Kqowg0ZLZbN1ek5N-YqEw9TlZeI3XV-K). Download and save it to `models/` -- [StackGAN-v2 for church](https://drive.google.com/open?id=13Pw4PZOkiAM5y_KoOwBzlXK9eQ2hHLfT). Download and save it to `models/` +- [StackGAN-v2 for bird](). Download and save it to `models/` (The [inception score](https://github.com/hanzhanggit/StackGAN-inception-model) for this Model is 4.04±0.05) +- [StackGAN-v2 for dog](). Download and save it to `models/` (The [inception score](https://github.com/openai/improved-gan/tree/master/inception_score) for this Model is 9.55±0.11) +- [StackGAN-v2 for cat](). Download and save it to `models/` +- [StackGAN-v2 for bedroom](). Download and save it to `models/` +- [StackGAN-v2 for church](). Download and save it to `models/` diff --git a/code/cfg/birds_3stages.yml b/code/cfg/birds_3stages.yml index 4a1322a1..8f82671f 100755 --- a/code/cfg/birds_3stages.yml +++ b/code/cfg/birds_3stages.yml @@ -2,7 +2,7 @@ CONFIG_NAME: '3stages' DATASET_NAME: 'birds' EMBEDDING_TYPE: 'cnn-rnn' -DATA_DIR: '../data/birds' +DATA_DIR: 'data/birds' GPU_ID: '0' WORKERS: 4 @@ -13,11 +13,11 @@ TREE: TRAIN: FLAG: True - NET_G: '' # '../output/birds_3stages/Model/netG_epoch_700.pth' - NET_D: '' # '../output/birds_3stages/Model/netD' - BATCH_SIZE: 24 + NET_G: '' # 'output/birds_3stages/Model/netG_epoch_700.pth' + NET_D: '' # 'output/birds_3stages/Model/netD' + BATCH_SIZE: 9 #24 MAX_EPOCH: 600 - SNAPSHOT_INTERVAL: 2000 + SNAPSHOT_INTERVAL: 1000 #2000 DISCRIMINATOR_LR: 0.0002 GENERATOR_LR: 0.0002 COEFF: diff --git a/code/datasets.py b/code/datasets.py index 04abc4cb..683c1dbf 100755 --- a/code/datasets.py +++ b/code/datasets.py @@ -3,27 +3,18 @@ from __future__ import print_function from __future__ import unicode_literals - -import torch.utils.data as data import torchvision.transforms as transforms -from PIL import Image -import PIL +import torch.utils.data as data import os -import os.path -import pickle import random import numpy as np import pandas as pd -from miscc.config import cfg - -import torch.utils.data as data -from PIL import Image -import os -import os.path import six -import string import sys -import torch + +from miscc.config import cfg +from PIL import Image + if sys.version_info[0] == 2: import cPickle as pickle else: @@ -37,8 +28,7 @@ def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) -def get_imgs(img_path, imsize, bbox=None, - transform=None, normalize=None): +def get_imgs(img_path, imsize, bbox=None, transform=None, normalize=None): img = Image.open(img_path).convert('RGB') width, height = img.size if bbox is not None: @@ -57,7 +47,7 @@ def get_imgs(img_path, imsize, bbox=None, ret = [] for i in range(cfg.TREE.BRANCH_NUM): if i < (cfg.TREE.BRANCH_NUM - 1): - re_img = transforms.Scale(imsize[i])(img) + re_img = transforms.Resize(imsize[i])(img) else: re_img = img ret.append(normalize(re_img)) @@ -71,7 +61,7 @@ def __init__(self, root, split_dir='train', custom_classes=None, root = os.path.join(root, split_dir) classes, class_to_idx = self.find_classes(root, custom_classes) imgs = self.make_dataset(classes, class_to_idx) - if len(imgs) == 0: + if imgs: raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) @@ -93,13 +83,13 @@ def __init__(self, root, split_dir='train', custom_classes=None, base_size = base_size * 2 print('num_classes', self.num_classes) - def find_classes(self, dir, custom_classes): + def find_classes(self, directory, custom_classes): classes = [] - for d in os.listdir(dir): + for d in os.listdir(directory): if os.path.isdir: if custom_classes is None or d in custom_classes: - classes.append(os.path.join(dir, d)) + classes.append(os.path.join(directory, d)) print('Valid classes: ', len(classes), classes) classes.sort() @@ -120,10 +110,7 @@ def make_dataset(self, classes, class_to_idx): def __getitem__(self, index): path, target = self.imgs[index] - imgs_list = get_imgs(path, self.imsize, - transform=self.transform, - normalize=self.norm) - + imgs_list = get_imgs(path, self.imsize, transform=self.transform, normalize=self.norm) return imgs_list def __len__(self): @@ -131,8 +118,7 @@ def __len__(self): class LSUNClass(data.Dataset): - def __init__(self, db_path, base_size=64, - transform=None, target_transform=None): + def __init__(self, db_path, base_size=64, transform=None, target_transform=None): import lmdb self.db_path = db_path self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False, @@ -168,9 +154,7 @@ def __getitem__(self, index): buf = six.BytesIO() buf.write(imgbuf) buf.seek(0) - imgs = get_imgs(buf, self.imsize, - transform=self.transform, - normalize=self.norm) + imgs = get_imgs(buf, self.imsize, transform=self.transform, normalize=self.norm) return imgs def __len__(self): @@ -215,19 +199,15 @@ def __init__(self, data_dir, split='train', embedding_type='cnn-rnn', def load_bbox(self): data_dir = self.data_dir bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt') - df_bounding_boxes = pd.read_csv(bbox_path, - delim_whitespace=True, - header=None).astype(int) + df_bounding_boxes = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int) # filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt') - df_filenames = \ - pd.read_csv(filepath, delim_whitespace=True, header=None) + df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None) filenames = df_filenames[1].tolist() print('Total filenames: ', len(filenames), filenames[0]) # filename_bbox = {img_file[:-4]: [] for img_file in filenames} - numImgs = len(filenames) - for i in xrange(0, numImgs): + for i, item in enumerate(filenames): # this is the range of the number of images # bbox = [x-left, y-top, width, height] bbox = df_bounding_boxes.iloc[i][1:].tolist() @@ -240,14 +220,13 @@ def load_all_captions(self): def load_captions(caption_name): # self, cap_path = caption_name with open(cap_path, "r") as f: - captions = f.read().decode('utf8').split('\n') - captions = [cap.replace("\ufffd\ufffd", " ") - for cap in captions if len(cap) > 0] + captions = f.read().split('\n') + captions = [cap.replace("\ufffd\ufffd", " ") for cap in captions if len(cap) > 0] return captions caption_dict = {} for key in self.filenames: - caption_name = '%s/text/%s.txt' % (self.data_dir, key) + caption_name = '%s/text_c10/%s.txt' % (self.data_dir, key) captions = load_captions(caption_name) caption_dict[key] = captions return caption_dict @@ -261,7 +240,7 @@ def load_embedding(self, data_dir, embedding_type): embedding_filename = '/skip-thought-embeddings.pickle' with open(data_dir + embedding_filename, 'rb') as f: - embeddings = pickle.load(f) + embeddings = pickle.load(f, encoding="bytes") embeddings = np.array(embeddings) # embedding_shape = [embeddings.shape[-1]] print('embeddings: ', embeddings.shape) @@ -270,7 +249,7 @@ def load_embedding(self, data_dir, embedding_type): def load_class_id(self, data_dir, total_num): if os.path.isfile(data_dir + '/class_info.pickle'): with open(data_dir + '/class_info.pickle', 'rb') as f: - class_id = pickle.load(f) + class_id = pickle.load(f, encoding="bytes") else: class_id = np.arange(total_num) return class_id @@ -293,21 +272,18 @@ def prepair_training_pairs(self, index): # captions = self.captions[key] embeddings = self.embeddings[index, :, :] img_name = '%s/images/%s.jpg' % (data_dir, key) - imgs = get_imgs(img_name, self.imsize, - bbox, self.transform, normalize=self.norm) + imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm) wrong_ix = random.randint(0, len(self.filenames) - 1) - if(self.class_id[index] == self.class_id[wrong_ix]): + if self.class_id[index] == self.class_id[wrong_ix]: wrong_ix = random.randint(0, len(self.filenames) - 1) wrong_key = self.filenames[wrong_ix] if self.bbox is not None: wrong_bbox = self.bbox[wrong_key] else: wrong_bbox = None - wrong_img_name = '%s/images/%s.jpg' % \ - (data_dir, wrong_key) - wrong_imgs = get_imgs(wrong_img_name, self.imsize, - wrong_bbox, self.transform, normalize=self.norm) + wrong_img_name = '%s/images/%s.jpg' % (data_dir, wrong_key) + wrong_imgs = get_imgs(wrong_img_name, self.imsize, wrong_bbox, self.transform, normalize=self.norm) embedding_ix = random.randint(0, embeddings.shape[0] - 1) embedding = embeddings[embedding_ix, :] @@ -327,8 +303,7 @@ def prepair_test_pairs(self, index): # captions = self.captions[key] embeddings = self.embeddings[index, :, :] img_name = '%s/images/%s.jpg' % (data_dir, key) - imgs = get_imgs(img_name, self.imsize, - bbox, self.transform, normalize=self.norm) + imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm) if self.target_transform is not None: embeddings = self.target_transform(embeddings) diff --git a/code/main.py b/code/main.py index 3be8a85d..9f134ecf 100755 --- a/code/main.py +++ b/code/main.py @@ -1,4 +1,6 @@ from __future__ import print_function +from miscc.config import cfg, cfg_from_file + import torch import torchvision.transforms as transforms @@ -16,9 +18,6 @@ sys.path.append(dir_path) -from miscc.config import cfg, cfg_from_file - - # 19 classes --> 7 valid classes with 8,555 images DOG_LESS = ['n02084071', 'n01322604', 'n02112497', 'n02113335', 'n02111277', 'n02084732', 'n02111129', 'n02103406', 'n02112826', 'n02111626', @@ -95,8 +94,7 @@ def parse_args(): now = datetime.datetime.now(dateutil.tz.tzlocal()) timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') - output_dir = '../output/%s_%s_%s' % \ - (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) + output_dir = '../output/%s_%s_%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) split_dir, bshuffle = 'train', True if not cfg.TRAIN.FLAG: @@ -107,7 +105,7 @@ def parse_args(): # Get data loader imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) image_transform = transforms.Compose([ - transforms.Scale(int(imsize * 76 / 64)), + transforms.Resize(int(imsize * 76 / 64)), transforms.RandomCrop(imsize), transforms.RandomHorizontalFlip()]) if cfg.DATA_DIR.find('lsun') != -1: @@ -128,9 +126,8 @@ def parse_args(): transform=image_transform) assert dataset num_gpu = len(cfg.GPU_ID.split(',')) - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu, - drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu, + drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) # Define models and go to train/evaluate if not cfg.GAN.B_CONDITION: @@ -146,9 +143,8 @@ def parse_args(): algo.evaluate(split_dir) end_t = time.time() print('Total time for training:', end_t - start_t) - ''' Running time comparison for 10epoch with batch_size 24 on birds dataset - T(1gpu) = 1.383 T(2gpus) - - gpu 2: 2426.228544 -> 4min/epoch - - gpu 2 & 3: 1754.12295008 -> 2.9min/epoch - - gpu 3: 2514.02744293 - ''' + # Running time comparison for 10epoch with batch_size 24 on birds dataset + # T(1gpu) = 1.383 T(2gpus) + # - gpu 2: 2426.228544 -> 4min/epoch + # - gpu 2 & 3: 1754.12295008 -> 2.9min/epoch + # - gpu 3: 2514.02744293 diff --git a/code/miscc/config.py b/code/miscc/config.py index e3ab5fbf..b8dcec64 100755 --- a/code/miscc/config.py +++ b/code/miscc/config.py @@ -68,11 +68,11 @@ def _merge_a_into_b(a, b): options in b whenever they are also specified in a. """ if type(a) is not edict: - return + raise TypeError('{} is not a valid edict type'.format(a)) - for k, v in a.iteritems(): + for k, v in a.items(): # a must specify keys that are in b - if not b.has_key(k): + if k not in b: raise KeyError('{} is not a valid config key'.format(k)) # the types must match, too @@ -81,17 +81,14 @@ def _merge_a_into_b(a, b): if isinstance(b[k], np.ndarray): v = np.array(v, dtype=b[k].dtype) else: - raise ValueError(('Type mismatch ({} vs. {}) ' - 'for config key: {}').format(type(b[k]), - type(v), k)) + raise TypeError(('Type mismatch ({} vs. {}) for config key: {}'.format(type(b[k]), type(v), k))) # recursively merge dicts if type(v) is edict: try: _merge_a_into_b(a[k], b[k]) except: - print('Error under config key: {}'.format(k)) - raise + raise KeyError('Error under config key: {}'.format(k)) else: b[k] = v diff --git a/code/model.py b/code/model.py index a698bb1d..c0ec245b 100755 --- a/code/model.py +++ b/code/model.py @@ -1,12 +1,11 @@ - import torch import torch.nn as nn import torch.nn.parallel +import torch.utils.model_zoo as model_zoo + from miscc.config import cfg from torch.autograd import Variable -import torch.nn.functional as F from torchvision import models -import torch.utils.model_zoo as model_zoo # ############################## For Compute inception score ############################## @@ -18,8 +17,7 @@ def __init__(self): self.model = models.inception_v3() url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' # print(next(model.parameters()).data) - state_dict = \ - model_zoo.load_url(url, map_location=lambda storage, loc: storage) + state_dict = model_zoo.load_url(url, map_location=lambda storage, loc: storage) self.model.load_state_dict(state_dict) for param in self.model.parameters(): param.requires_grad = False @@ -27,9 +25,9 @@ def __init__(self): # print(next(self.model.parameters()).data) # print(self.model) - def forward(self, input): + def forward(self, the_input): # [-1.0, 1.0] --> [0, 1.0] - x = input * 0.5 + 0.5 + x = the_input * 0.5 + 0.5 # mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] # --> mean = 0, std = 1 x[:, 0] = (x[:, 0] - 0.485) / 0.229 @@ -37,10 +35,24 @@ def forward(self, input): x[:, 2] = (x[:, 2] - 0.406) / 0.225 # # --> fixed-size input: batch x 3 x 299 x 299 - x = nn.Upsample(size=(299, 299), mode='bilinear')(x) + x = nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) + # x = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)(x) # 299 x 299 x 3 x = self.model(x) - x = nn.Softmax()(x) + x = nn.Softmax(dim=1)(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, size=None): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.size = size + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, size=self.size) return x @@ -50,22 +62,22 @@ def __init__(self): def forward(self, x): nc = x.size(1) - assert nc % 2 == 0, 'channels dont divide 2!' + assert nc % 2 == 0, 'channels do not divide 2!' nc = int(nc/2) - return x[:, :nc] * F.sigmoid(x[:, nc:]) + return x[:, :nc] * torch.sigmoid(x[:, nc:]) def conv3x3(in_planes, out_planes): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, - padding=1, bias=False) + """ 3x3 convolution with padding """ + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) # ############## G networks ################################################ # Upsale the spatial size by a factor of 2 def upBlock(in_planes, out_planes): block = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), + Interpolate(scale_factor=2, mode='nearest'), + # nn.Upsample(scale_factor=2, mode='nearest'), conv3x3(in_planes, out_planes * 2), nn.BatchNorm2d(out_planes * 2), GLU() @@ -94,7 +106,6 @@ def __init__(self, channel_num): nn.BatchNorm2d(channel_num) ) - def forward(self, x): residual = x out = self.block(x) @@ -151,7 +162,6 @@ def define_module(self): nn.BatchNorm1d(ngf * 4 * 4 * 2), GLU()) - self.upsample1 = upBlock(ngf, ngf // 2) self.upsample2 = upBlock(ngf // 2, ngf // 4) self.upsample3 = upBlock(ngf // 4, ngf // 8) @@ -250,7 +260,7 @@ def define_module(self): if cfg.TREE.BRANCH_NUM > 2: self.h_net3 = NEXT_STAGE_G(self.gf_dim // 2) self.img_net3 = GET_IMAGE_G(self.gf_dim // 4) - if cfg.TREE.BRANCH_NUM > 3: # Recommended structure (mainly limited by GPU memory), and not test yet + if cfg.TREE.BRANCH_NUM > 3: # Recommended structure (mainly limited by GPU memory), and not test yet self.h_net4 = NEXT_STAGE_G(self.gf_dim // 4, num_residual=1) self.img_net4 = GET_IMAGE_G(self.gf_dim // 8) if cfg.TREE.BRANCH_NUM > 4: @@ -338,15 +348,11 @@ def define_module(self): efg = self.ef_dim self.img_code_s16 = encode_image_by_16times(ndf) - self.logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) if cfg.GAN.B_CONDITION: self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8) - self.uncond_logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.uncond_logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) def forward(self, x_var, c_code=None): x_code = self.img_code_s16(x_var) @@ -365,8 +371,7 @@ def forward(self, x_var, c_code=None): if cfg.GAN.B_CONDITION: out_uncond = self.uncond_logits(x_code) return [output.view(-1), out_uncond.view(-1)] - else: - return [output.view(-1)] + return [output.view(-1)] # For 128 x 128 images @@ -384,15 +389,11 @@ def define_module(self): self.img_code_s32 = downBlock(ndf * 8, ndf * 16) self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8) - self.logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) if cfg.GAN.B_CONDITION: self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8) - self.uncond_logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.uncond_logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) def forward(self, x_var, c_code=None): x_code = self.img_code_s16(x_var) @@ -413,8 +414,7 @@ def forward(self, x_var, c_code=None): if cfg.GAN.B_CONDITION: out_uncond = self.uncond_logits(x_code) return [output.view(-1), out_uncond.view(-1)] - else: - return [output.view(-1)] + return [output.view(-1)] # For 256 x 256 images @@ -440,9 +440,7 @@ def define_module(self): if cfg.GAN.B_CONDITION: self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8) - self.uncond_logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.uncond_logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) def forward(self, x_var, c_code=None): x_code = self.img_code_s16(x_var) @@ -465,8 +463,7 @@ def forward(self, x_var, c_code=None): if cfg.GAN.B_CONDITION: out_uncond = self.uncond_logits(x_code) return [output.view(-1), out_uncond.view(-1)] - else: - return [output.view(-1)] + return [output.view(-1)] # For 512 x 512 images: Recommended structure, not test yet @@ -488,15 +485,11 @@ def define_module(self): self.img_code_s128_2 = Block3x3_leakRelu(ndf * 32, ndf * 16) self.img_code_s128_3 = Block3x3_leakRelu(ndf * 16, ndf * 8) - self.logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) if cfg.GAN.B_CONDITION: self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8) - self.uncond_logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.uncond_logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) def forward(self, x_var, c_code=None): x_code = self.img_code_s16(x_var) @@ -521,8 +514,7 @@ def forward(self, x_var, c_code=None): if cfg.GAN.B_CONDITION: out_uncond = self.uncond_logits(x_code) return [output.view(-1), out_uncond.view(-1)] - else: - return [output.view(-1)] + return [output.view(-1)] # For 1024 x 1024 images: Recommended structure, not test yet @@ -546,15 +538,11 @@ def define_module(self): self.img_code_s256_3 = Block3x3_leakRelu(ndf * 32, ndf * 16) self.img_code_s256_4 = Block3x3_leakRelu(ndf * 16, ndf * 8) - self.logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) if cfg.GAN.B_CONDITION: self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8) - self.uncond_logits = nn.Sequential( - nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), - nn.Sigmoid()) + self.uncond_logits = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) def forward(self, x_var, c_code=None): x_code = self.img_code_s16(x_var) @@ -581,5 +569,4 @@ def forward(self, x_var, c_code=None): if cfg.GAN.B_CONDITION: out_uncond = self.uncond_logits(x_code) return [output.view(-1), out_uncond.view(-1)] - else: - return [output.view(-1)] + return [output.view(-1)] diff --git a/code/trainer.py b/code/trainer.py index 497288e8..af096319 100755 --- a/code/trainer.py +++ b/code/trainer.py @@ -4,22 +4,23 @@ import torch.backends.cudnn as cudnn import torch import torch.nn as nn -from torch.autograd import Variable import torch.optim as optim import torchvision.utils as vutils import numpy as np import os import time + +from torch.autograd import Variable from PIL import Image, ImageFont, ImageDraw from copy import deepcopy from miscc.config import cfg from miscc.utils import mkdir_p +from model import G_NET, D_NET64, D_NET128, D_NET256, D_NET512, D_NET1024, INCEPTION_V3 + +from tensorboardX import FileWriter, summary -from tensorboard import summary -from tensorboard import FileWriter -from model import G_NET, D_NET64, D_NET128, D_NET256, D_NET512, D_NET1024, INCEPTION_V3 @@ -56,12 +57,12 @@ def KL_loss(mu, logvar): def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: - nn.init.orthogonal(m.weight.data, 1.0) + nn.init.orthogonal_(m.weight.data, 1.0) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) elif classname.find('Linear') != -1: - nn.init.orthogonal(m.weight.data, 1.0) + nn.init.orthogonal_(m.weight.data, 1.0) if m.bias is not None: m.bias.data.fill_(0.0) @@ -83,8 +84,7 @@ def compute_inception_score(predictions, num_splits=1): istart = i * predictions.shape[0] // num_splits iend = (i + 1) * predictions.shape[0] // num_splits part = predictions[istart:iend, :] - kl = part * \ - (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) kl = np.mean(np.sum(kl, 1)) scores.append(np.exp(kl)) return np.mean(scores), np.std(scores) @@ -161,58 +161,45 @@ def define_optimizers(netG, netsD): optimizersD = [] num_Ds = len(netsD) for i in range(num_Ds): - opt = optim.Adam(netsD[i].parameters(), - lr=cfg.TRAIN.DISCRIMINATOR_LR, - betas=(0.5, 0.999)) + opt = optim.Adam(netsD[i].parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999)) optimizersD.append(opt) # G_opt_paras = [] # for p in netG.parameters(): # if p.requires_grad: # G_opt_paras.append(p) - optimizerG = optim.Adam(netG.parameters(), - lr=cfg.TRAIN.GENERATOR_LR, - betas=(0.5, 0.999)) + optimizerG = optim.Adam(netG.parameters(), lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999)) return optimizerG, optimizersD def save_model(netG, avg_param_G, netsD, epoch, model_dir): load_params(netG, avg_param_G) - torch.save( - netG.state_dict(), - '%s/netG_%d.pth' % (model_dir, epoch)) + torch.save(netG.state_dict(), '%s/netG_%d.pth' % (model_dir, epoch)) for i in range(len(netsD)): netD = netsD[i] - torch.save( - netD.state_dict(), - '%s/netD%d.pth' % (model_dir, i)) + torch.save(netD.state_dict(), '%s/netD%d.pth' % (model_dir, i)) print('Save G/Ds models.') -def save_img_results(imgs_tcpu, fake_imgs, num_imgs, - count, image_dir, summary_writer): +def save_img_results(imgs_tcpu, fake_imgs, num_imgs, count, image_dir, summary_writer): num = cfg.TRAIN.VIS_COUNT # The range of real_img (i.e., self.imgs_tcpu[i][0:num]) # is changed to [0, 1] by function vutils.save_image real_img = imgs_tcpu[-1][0:num] - vutils.save_image( - real_img, '%s/real_samples.png' % (image_dir), - normalize=True) + vutils.save_image(real_img, '%s/real_samples.png' % (image_dir), normalize=True) real_img_set = vutils.make_grid(real_img).numpy() real_img_set = np.transpose(real_img_set, (1, 2, 0)) real_img_set = real_img_set * 255 real_img_set = real_img_set.astype(np.uint8) - sup_real_img = summary.image('real_img', real_img_set) + sup_real_img = summary.image('real_img', real_img_set, dataformats='HWC') summary_writer.add_summary(sup_real_img, count) for i in range(num_imgs): fake_img = fake_imgs[i][0:num] # The range of fake_img.data (i.e., self.fake_imgs[i][0:num]) # is still [-1. 1]... - vutils.save_image( - fake_img.data, '%s/count_%09d_fake_samples%d.png' % - (image_dir, count, i), normalize=True) + vutils.save_image(fake_img.data, '%s/count_%09d_fake_samples%d.png' % (image_dir, count, i), normalize=True) fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy() @@ -220,7 +207,7 @@ def save_img_results(imgs_tcpu, fake_imgs, num_imgs, fake_img_set = (fake_img_set + 1) * 255 / 2 fake_img_set = fake_img_set.astype(np.uint8) - sup_fake_img = summary.image('fake_img%d' % i, fake_img_set) + sup_fake_img = summary.image('fake_img%d' % i, fake_img_set, dataformats='HWC') summary_writer.add_summary(sup_fake_img, count) summary_writer.flush() @@ -237,6 +224,7 @@ def __init__(self, output_dir, data_loader, imsize): mkdir_p(self.log_dir) self.summary_writer = FileWriter(self.log_dir) + s_gpus = cfg.GPU_ID.split(',') self.gpus = [int(ix) for ix in s_gpus] self.num_gpus = len(self.gpus) @@ -287,7 +275,7 @@ def train_Dnet(self, idx, count): optD.step() # log if flag == 0: - summary_D = summary.scalar('D_loss%d' % idx, errD.data[0]) + summary_D = summary.scalar('D_loss%d' % idx, errD.item()) self.summary_writer.add_summary(summary_D, count) return errD @@ -306,37 +294,33 @@ def train_Gnet(self, count): # errG = self.stage_coeff[i] * errG errG_total = errG_total + errG if flag == 0: - summary_G = summary.scalar('G_loss%d' % i, errG.data[0]) + summary_G = summary.scalar('G_loss%d' % i, errG.item()) self.summary_writer.add_summary(summary_G, count) # Compute color preserve losses if cfg.TRAIN.COEFF.COLOR_LOSS > 0: if self.num_Ds > 1: mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1]) - mu2, covariance2 = \ - compute_mean_covariance(self.fake_imgs[-2].detach()) + mu2, covariance2 = compute_mean_covariance(self.fake_imgs[-2].detach()) like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2) - like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \ - nn.MSELoss()(covariance1, covariance2) + like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(covariance1, covariance2) errG_total = errG_total + like_mu2 + like_cov2 if self.num_Ds > 2: mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2]) - mu2, covariance2 = \ - compute_mean_covariance(self.fake_imgs[-3].detach()) + mu2, covariance2 = compute_mean_covariance(self.fake_imgs[-3].detach()) like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2) - like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \ - nn.MSELoss()(covariance1, covariance2) + like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(covariance1, covariance2) errG_total = errG_total + like_mu1 + like_cov1 if flag == 0: - sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0]) + sum_mu = summary.scalar('G_like_mu2', like_mu2.item()) self.summary_writer.add_summary(sum_mu, count) - sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0]) + sum_cov = summary.scalar('G_like_cov2', like_cov2.item()) self.summary_writer.add_summary(sum_cov, count) if self.num_Ds > 2: - sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0]) + sum_mu = summary.scalar('G_like_mu1', like_mu1.item()) self.summary_writer.add_summary(sum_mu, count) - sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0]) + sum_cov = summary.scalar('G_like_cov1', like_cov1.item()) self.summary_writer.add_summary(sum_cov, count) errG_total.backward() @@ -344,23 +328,18 @@ def train_Gnet(self, count): return errG_total def train(self): - self.netG, self.netsD, self.num_Ds,\ - self.inception_model, start_count = load_network(self.gpus) + self.netG, self.netsD, self.num_Ds, self.inception_model, start_count = load_network(self.gpus) avg_param_G = copy_G_params(self.netG) - self.optimizerG, self.optimizersD = \ - define_optimizers(self.netG, self.netsD) + self.optimizerG, self.optimizersD = define_optimizers(self.netG, self.netsD) self.criterion = nn.BCELoss() - self.real_labels = \ - Variable(torch.FloatTensor(self.batch_size).fill_(1)) - self.fake_labels = \ - Variable(torch.FloatTensor(self.batch_size).fill_(0)) + self.real_labels = Variable(torch.FloatTensor(self.batch_size).fill_(1)) + self.fake_labels = Variable(torch.FloatTensor(self.batch_size).fill_(0)) nz = cfg.GAN.Z_DIM noise = Variable(torch.FloatTensor(self.batch_size, nz)) - fixed_noise = \ - Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)) + fixed_noise = Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)) if cfg.CUDA: self.criterion.cuda() @@ -406,14 +385,14 @@ def train(self): predictions.append(pred.data.cpu().numpy()) if count % 100 == 0: - summary_D = summary.scalar('D_loss', errD_total.data[0]) - summary_G = summary.scalar('G_loss', errG_total.data[0]) + summary_D = summary.scalar('D_loss', errD_total.item()) + summary_G = summary.scalar('G_loss', errG_total.item()) self.summary_writer.add_summary(summary_D, count) self.summary_writer.add_summary(summary_G, count) if step == 0: print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f''' % (epoch, self.max_epoch, step, self.num_batches, - errD_total.data[0], errG_total.data[0])) + errD_total.item(), errG_total.item())) count = count + 1 if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: @@ -437,8 +416,7 @@ def train(self): m_incep = summary.scalar('Inception_mean', mean) self.summary_writer.add_summary(m_incep, count) # - mean_nlpp, std_nlpp = \ - negative_log_posterior_probability(predictions, 10) + mean_nlpp, std_nlpp = negative_log_posterior_probability(predictions, 10) m_nlpp = summary.scalar('NLPP_mean', mean_nlpp) self.summary_writer.add_summary(m_nlpp, count) # @@ -477,9 +455,7 @@ def evaluate(self, split_dir): netG = torch.nn.DataParallel(netG, device_ids=self.gpus) print(netG) # state_dict = torch.load(cfg.TRAIN.NET_G) - state_dict = \ - torch.load(cfg.TRAIN.NET_G, - map_location=lambda storage, loc: storage) + state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) netG.load_state_dict(state_dict) print('Load ', cfg.TRAIN.NET_G) @@ -507,7 +483,7 @@ def evaluate(self, split_dir): netG.eval() num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size) cnt = 0 - for step in xrange(num_batches): + for step in range(num_batches): noise.data.normal_(0, 1) fake_imgs, _, _ = netG(noise) if cfg.TEST.B_EXAMPLE: @@ -584,12 +560,9 @@ def train_Dnet(self, idx, count): errD_wrong = criterion(wrong_logits[0], fake_labels) errD_fake = criterion(fake_logits[0], fake_labels) if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0: - errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \ - criterion(real_logits[1], real_labels) - errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \ - criterion(wrong_logits[1], real_labels) - errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \ - criterion(fake_logits[1], fake_labels) + errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(real_logits[1], real_labels) + errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(wrong_logits[1], real_labels) + errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(fake_logits[1], fake_labels) # errD_real = errD_real + errD_real_uncond errD_wrong = errD_wrong + errD_wrong_uncond @@ -604,7 +577,7 @@ def train_Dnet(self, idx, count): optD.step() # log if flag == 0: - summary_D = summary.scalar('D_loss%d' % idx, errD.data[0]) + summary_D = summary.scalar('D_loss%d' % idx, errD.item()) self.summary_writer.add_summary(summary_D, count) return errD @@ -619,41 +592,38 @@ def train_Gnet(self, count): outputs = self.netsD[i](self.fake_imgs[i], mu) errG = criterion(outputs[0], real_labels) if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0: - errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\ - criterion(outputs[1], real_labels) + errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(outputs[1], real_labels) errG = errG + errG_patch errG_total = errG_total + errG if flag == 0: - summary_D = summary.scalar('G_loss%d' % i, errG.data[0]) + summary_D = summary.scalar('G_loss%d' % i, errG.item()) self.summary_writer.add_summary(summary_D, count) # Compute color consistency losses if cfg.TRAIN.COEFF.COLOR_LOSS > 0: if self.num_Ds > 1: mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1]) - mu2, covariance2 = \ - compute_mean_covariance(self.fake_imgs[-2].detach()) + mu2, covariance2 = compute_mean_covariance(self.fake_imgs[-2].detach()) like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2) - like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \ - nn.MSELoss()(covariance1, covariance2) + like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(covariance1, covariance2) errG_total = errG_total + like_mu2 + like_cov2 if flag == 0: - sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0]) - self.summary_writer.add_summary(sum_mu, count) - sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0]) - self.summary_writer.add_summary(sum_cov, count) + sum_mu = summary.scalar('G_like_mu2', like_mu2.item()) + self.summary_writer.add_summary(sum_mu, global_step=count) + + sum_cov = summary.scalar('G_like_cov2', like_cov2.item()) + self.summary_writer.add_summary(sum_cov, global_step=count) if self.num_Ds > 2: mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2]) - mu2, covariance2 = \ - compute_mean_covariance(self.fake_imgs[-3].detach()) + mu2, covariance2 = compute_mean_covariance(self.fake_imgs[-3].detach()) like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2) - like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \ - nn.MSELoss()(covariance1, covariance2) + like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(covariance1, covariance2) errG_total = errG_total + like_mu1 + like_cov1 if flag == 0: - sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0]) + sum_mu = summary.scalar('G_like_mu1', like_mu1.item()) self.summary_writer.add_summary(sum_mu, count) - sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0]) + + sum_cov = summary.scalar('G_like_cov1', like_cov1.item()) self.summary_writer.add_summary(sum_cov, count) kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL @@ -663,27 +633,22 @@ def train_Gnet(self, count): return kl_loss, errG_total def train(self): - self.netG, self.netsD, self.num_Ds,\ - self.inception_model, start_count = load_network(self.gpus) + self.netG, self.netsD, self.num_Ds, self.inception_model, start_count = load_network(self.gpus) avg_param_G = copy_G_params(self.netG) - self.optimizerG, self.optimizersD = \ - define_optimizers(self.netG, self.netsD) + self.optimizerG, self.optimizersD = define_optimizers(self.netG, self.netsD) self.criterion = nn.BCELoss() - self.real_labels = \ - Variable(torch.FloatTensor(self.batch_size).fill_(1)) - self.fake_labels = \ - Variable(torch.FloatTensor(self.batch_size).fill_(0)) + self.real_labels = Variable(torch.FloatTensor(self.batch_size).fill_(1)) + self.fake_labels = Variable(torch.FloatTensor(self.batch_size).fill_(0)) self.gradient_one = torch.FloatTensor([1.0]) self.gradient_half = torch.FloatTensor([0.5]) nz = cfg.GAN.Z_DIM noise = Variable(torch.FloatTensor(self.batch_size, nz)) - fixed_noise = \ - Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)) + fixed_noise = Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)) if cfg.CUDA: self.criterion.cuda() @@ -703,15 +668,13 @@ def train(self): ####################################################### # (0) Prepare training data ###################################################### - self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \ - self.txt_embedding = self.prepare_data(data) + self.imgs_tcpu, self.real_imgs, self.wrong_imgs, self.txt_embedding = self.prepare_data(data) ####################################################### # (1) Generate fake images ###################################################### noise.data.normal_(0, 1) - self.fake_imgs, self.mu, self.logvar = \ - self.netG(noise, self.txt_embedding) + self.fake_imgs, self.mu, self.logvar = self.netG(noise, self.txt_embedding) ####################################################### # (2) Update D network @@ -733,9 +696,9 @@ def train(self): predictions.append(pred.data.cpu().numpy()) if count % 100 == 0: - summary_D = summary.scalar('D_loss', errD_total.data[0]) - summary_G = summary.scalar('G_loss', errG_total.data[0]) - summary_KL = summary.scalar('KL_loss', kl_loss.data[0]) + summary_D = summary.scalar('D_loss', errD_total.item()) + summary_G = summary.scalar('G_loss', errG_total.item()) + summary_KL = summary.scalar('KL_loss', kl_loss.item()) self.summary_writer.add_summary(summary_D, count) self.summary_writer.add_summary(summary_G, count) self.summary_writer.add_summary(summary_KL, count) @@ -748,10 +711,9 @@ def train(self): backup_para = copy_G_params(self.netG) load_params(self.netG, avg_param_G) # - self.fake_imgs, _, _ = \ - self.netG(fixed_noise, self.txt_embedding) - save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds, - count, self.image_dir, self.summary_writer) + self.fake_imgs, _, _ = self.netG(fixed_noise, self.txt_embedding) + save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds, count, self.image_dir, + self.summary_writer) # load_params(self.netG, backup_para) @@ -763,8 +725,7 @@ def train(self): m_incep = summary.scalar('Inception_mean', mean) self.summary_writer.add_summary(m_incep, count) # - mean_nlpp, std_nlpp = \ - negative_log_posterior_probability(predictions, 10) + mean_nlpp, std_nlpp = negative_log_posterior_probability(predictions, 10) m_nlpp = summary.scalar('NLPP_mean', mean_nlpp) self.summary_writer.add_summary(m_nlpp, count) # @@ -775,19 +736,17 @@ def train(self): Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs ''' # D(real): %.4f D(wrong):%.4f D(fake) %.4f % (epoch, self.max_epoch, self.num_batches, - errD_total.data[0], errG_total.data[0], - kl_loss.data[0], end_t - start_t)) + errD_total.item(), errG_total.item(), + kl_loss.item(), end_t - start_t)) save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir) self.summary_writer.close() - def save_superimages(self, images_list, filenames, - save_dir, split_dir, imsize): + def save_superimages(self, images_list, filenames, save_dir, split_dir, imsize): batch_size = images_list[0].size(0) num_sentences = len(images_list) for i in range(batch_size): - s_tmp = '%s/super/%s/%s' %\ - (save_dir, split_dir, filenames[i]) + s_tmp = '%s/super/%s/%s' % (save_dir, split_dir, filenames[i]) folder = s_tmp[:s_tmp.rfind('/')] if not os.path.isdir(folder): print('Make a new folder: ', folder) @@ -805,11 +764,9 @@ def save_superimages(self, images_list, filenames, super_img = torch.cat(super_img, 0) vutils.save_image(super_img, savename, nrow=10, normalize=True) - def save_singleimages(self, images, filenames, - save_dir, split_dir, sentenceID, imsize): + def save_singleimages(self, images, filenames, save_dir, split_dir, sentenceID, imsize): for i in range(images.size(0)): - s_tmp = '%s/single_samples/%s/%s' %\ - (save_dir, split_dir, filenames[i]) + s_tmp = '%s/single_samples/%s/%s' % (save_dir, split_dir, filenames[i]) folder = s_tmp[:s_tmp.rfind('/')] if not os.path.isdir(folder): print('Make a new folder: ', folder) @@ -834,9 +791,7 @@ def evaluate(self, split_dir): netG = torch.nn.DataParallel(netG, device_ids=self.gpus) print(netG) # state_dict = torch.load(cfg.TRAIN.NET_G) - state_dict = \ - torch.load(cfg.TRAIN.NET_G, - map_location=lambda storage, loc: storage) + state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) netG.load_state_dict(state_dict) print('Load ', cfg.TRAIN.NET_G) @@ -877,8 +832,7 @@ def evaluate(self, split_dir): # fake_img_list.append(fake_imgs[1].data.cpu()) fake_img_list.append(fake_imgs[2].data.cpu()) else: - self.save_singleimages(fake_imgs[-1], filenames, - save_dir, split_dir, i, 256) + self.save_singleimages(fake_imgs[-1], filenames, save_dir, split_dir, i, 256) # self.save_singleimages(fake_imgs[-2], filenames, # save_dir, split_dir, i, 128) # self.save_singleimages(fake_imgs[-3], filenames, @@ -889,5 +843,4 @@ def evaluate(self, split_dir): # save_dir, split_dir, 64) # self.save_superimages(fake_img_list, filenames, # save_dir, split_dir, 128) - self.save_superimages(fake_img_list, filenames, - save_dir, split_dir, 256) \ No newline at end of file + self.save_superimages(fake_img_list, filenames, save_dir, split_dir, 256)