This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit f868514

Improve code for simply applying super-resolution.
1 parent 30534c6 commit f868514

2 files changed: 50 additions and 54 deletions.

README.rst

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
Neural Enhance
22
==============
33

4-
**Example #1** — China Town: `view comparison <http://5.9.70.47:4141/w/3b3c8054-9d00-11e6-9558-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/gnxcXH>`_ CC-BY-SA @cyalex.
4+
.. image:: docs/OldStation_example.gif
55

6-
.. image:: docs/Chinatown_example.gif
6+
**Example #1** — Old Station: `view comparison <http://5.9.70.47:4141/w/0f5177f4-9ce6-11e6-992c-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/oYhbBv>`_ CC-BY-SA @siv-athens.
7+
8+
----
79

8-
`As seen on TV! <https://www.youtube.com/watch?v=LhF_56SxrGk>`_ What if you could increase the resolution of your photos using technology from CSI laboratories? Thanks to deep learning and ``#NeuralEnhance``, it's now possible to train a neural network to zoom in to your images at 2x or even 4x. You'll get even better results by increasing the number of neurons or using specialized training images (e.g. faces).
10+
`As seen on TV! <https://www.youtube.com/watch?v=LhF_56SxrGk>`_ What if you could increase the resolution of your photos using technology from CSI laboratories? Thanks to deep learning and ``#NeuralEnhance``, it's now possible to train a neural network to zoom in to your images at 2x or even 4x. You'll get even better results by increasing the number of neurons or training with a dataset similar to your low resolution image.
911

1012
The catch? The neural network is hallucinating details based on its training from example images. It's not reconstructing your photo exactly as it would have been if it was HD. That's only possible in Holywood — but using deep learning as "Creative AI" works and its just as cool! Here's how you can get started...
1113

@@ -58,10 +60,10 @@ The default is to use ``--device=cpu``, if you have NVIDIA card setup with CUDA
5860
--smoothness-weight=5e4 --adversary-weight=2e2 \
5961
--generator-start=1 --discriminator-start=0 --adversarial-start=1
6062
61-
**Example #2** — Bank Lobby: `view comparison <http://5.9.70.47:4141/w/38d10880-9ce6-11e6-becb-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/6a8cwm>`_ CC-BY-SA @benarent.
62-
6363
.. image:: docs/BankLobby_example.gif
6464

65+
**Example #2** — Bank Lobby: `view comparison <http://5.9.70.47:4141/w/38d10880-9ce6-11e6-becb-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/6a8cwm>`_ CC-BY-SA @benarent.
66+
6567
2. Installation & Setup
6668
=======================
6769

@@ -100,7 +102,7 @@ After this, you should have ``pillow``, ``theano`` and ``lasagne`` installed in
100102
3. Background & Research
101103
========================
102104

103-
This code uses a combination of techniques from the following papers, as well as some minor improvements yet to be documented:
105+
This code uses a combination of techniques from the following papers, as well as some minor improvements yet to be documented (watch this repository for updates):
104106

105107
1. `Perceptual Losses for Real-Time Style Transfer and Super-Resolution <http://arxiv.org/abs/1603.08155>`_
106108
2. `Real-Time Super-Resolution Using Efficient Sub-Pixel Convolution <https://arxiv.org/abs/1609.05158>`_
@@ -142,10 +144,9 @@ It seems your terminal is misconfigured and not compatible with the way Python t
142144

143145
**FIX:** ``export LC_ALL=en_US.UTF-8``
144146

147+
.. image:: docs/Chinatown_example.gif
145148

146-
.. image:: docs/OldStation_example.gif
147-
148-
**Example #3** — Old Station: `view comparison <http://5.9.70.47:4141/w/0f5177f4-9ce6-11e6-992c-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/oYhbBv>`_ CC-BY-SA @siv-athens.
149+
**Example #3** — China Town: `view comparison <http://5.9.70.47:4141/w/3b3c8054-9d00-11e6-9558-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/gnxcXH>`_ CC-BY-SA @cyalex.
149150

150151
----
151152

enhance.py

Lines changed: 40 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
1515
#
1616

17+
__version__ = '0.1'
18+
1719
import os
1820
import sys
1921
import bz2
@@ -34,7 +36,7 @@
3436
add_arg = parser.add_argument
3537
add_arg('files', nargs='*', default=[])
3638
add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
37-
add_arg('--model', default='ne%ix.pkl.bz2', type=str, help='Name of the neural network to load/save.')
39+
add_arg('--model', default='medium', type=str, help='Name of the neural network to load/save.')
3840
add_arg('--train', default=False, action='store_true', help='Learn new or fine-tune a neural network.')
3941
add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
4042
add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
@@ -46,11 +48,11 @@
4648
add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
4749
add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
4850
add_arg('--generator-filters', default=[64], nargs='+', type=int, help='Number of convolution units in network.')
49-
add_arg('--generator-blocks', default=12, type=int, help='Number of residual blocks per iteration.')
50-
add_arg('--generator-iters', default=1, type=int, help='Number of iterations in total.')
51+
add_arg('--generator-blocks', default=4, type=int, help='Number of residual blocks per iteration.')
5152
add_arg('--generator-residual', default=2, type=int, help='Number of layers in a residual block.')
5253
add_arg('--perceptual-layer', default='conv2_2', type=str, help='Which VGG layer to use as loss component.')
5354
add_arg('--perceptual-weight', default=1e0, type=float, help='Weight for VGG-layer perceptual loss.')
55+
add_arg('--discriminator-size', default=32, type=int, help='Multiplier for number of filters in D.')
5456
add_arg('--smoothness-weight', default=2e5, type=float, help='Weight of the total-variation loss.')
5557
add_arg('--adversary-weight', default=1e2, type=float, help='Weight of adversarial loss compoment.')
5658
add_arg('--generator-start', default=0, type=int, help='Epoch count to start training generator.')
@@ -102,7 +104,7 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))
102104
# Numeric Computing (GPU)
103105
import theano
104106
import theano.tensor as T
105-
import theano.tensor.nnet.neighbours
107+
T.nnet.softminus = lambda x: x - T.nnet.softplus(x)
106108

107109
# Support ansi colors in Windows too.
108110
if sys.platform == 'win32':
@@ -233,23 +235,8 @@ def last_layer(self):
233235
return list(self.network.values())[-1]
234236

235237
def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
236-
orig = '1.'+''.join(name.split('.')[1:])
237-
if orig+'x' in self.network:
238-
print('reused', orig, 'for', name)
239-
l = self.network[orig +'x']
240-
extra = {'W': l.W, 'b': l.b}
241-
else:
242-
extra = {}
243-
244-
conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None, **extra)
245-
246-
alpha = lasagne.init.Constant(alpha)
247-
if orig +'>' in self.network:
248-
l = self.network[orig +'>']
249-
alpha = l.alpha
250-
251-
prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=alpha)
252-
238+
conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None)
239+
prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
253240
self.network[name+'x'] = conv
254241
self.network[name+'>'] = prelu
255242
return prelu
@@ -267,20 +254,15 @@ def setup_generator(self, input, config):
267254
self.make_layer('iter.0-B', self.last_layer(), units, filter_size=(5,5), pad=(2,2))
268255
self.network['iter.0'] = self.last_layer()
269256

270-
for i in range(0, args.generator_iters):
271-
base = self.last_layer()
272-
for j in range(0, args.generator_blocks):
273-
self.make_block('%i.iter-%i'%(i+1, j), self.last_layer(), units)
274-
print('iter.%i-%i'%(i+1, j))
275-
# self.network['iter.%i'%(i+1)] = DropPathLayer([base, self.last_layer()])
257+
for i in range(0, args.generator_blocks):
258+
self.make_block('iter.%i'%(i+1), self.last_layer(), units)
276259

277260
for i in range(0, args.scales):
278261
u = next(units_iter)
279262
self.make_layer('scale%i.3'%i, self.last_layer(), u*4)
280263
self.network['scale%i.2'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2)
281264
self.make_layer('scale%i.1'%i, self.last_layer(), u)
282265

283-
284266
self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
285267
nonlinearity=lasagne.nonlinearities.tanh)
286268

@@ -314,15 +296,17 @@ def setup_perceptual(self, input):
314296
self.network['conv5_4'] = ConvLayer(self.network['conv5_3'], 512, 3, pad=1)
315297

316298
def setup_discriminator(self):
317-
self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
318-
self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
319-
self.make_layer('disc2', batch_norm(self.network['conv2_2']), 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
320-
self.make_layer('disc3', batch_norm(self.network['conv3_2']), 192, filter_size=(3,3), stride=(1,1), pad=(1,1))
299+
c = args.discriminator_size
300+
self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 1*c, filter_size=(5,5), stride=(2,2), pad=(2,2))
301+
self.make_layer('disc1.2', self.last_layer(), 1*c, filter_size=(5,5), stride=(2,2), pad=(2,2))
302+
self.make_layer('disc2', batch_norm(self.network['conv2_2']), 2*c, filter_size=(5,5), stride=(2,2), pad=(2,2))
303+
self.make_layer('disc3', batch_norm(self.network['conv3_2']), 3*c, filter_size=(3,3), stride=(1,1), pad=(1,1))
321304
hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>'], self.network['disc3>']])
322-
self.make_layer('disc4', hypercolumn, 192, filter_size=(3,3), stride=(1,1))
323-
self.make_layer('disc5', self.last_layer(), 96, filter_size=(3,3), stride=(1,1))
305+
self.make_layer('disc4', hypercolumn, 4*c, filter_size=(1,1), stride=(1,1), pad=(0,0))
306+
self.make_layer('disc5', self.last_layer(), 3*c, filter_size=(3,3), stride=(2,2))
307+
self.make_layer('disc6', self.last_layer(), 2*c, filter_size=(1,1), stride=(1,1), pad=(0,0))
324308
self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1),
325-
nonlinearity=lasagne.nonlinearities.sigmoid))
309+
nonlinearity=lasagne.nonlinearities.linear))
326310

327311

328312
#------------------------------------------------------------------------------------------------------------------
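
To make the effect of the new ``--discriminator-size`` multiplier concrete, the sketch below lists the filter counts that the rewritten ``setup_discriminator`` requests for a given ``c``, next to the hard-coded counts it replaces. The helper name ``discriminator_filters`` and the ``OLD_FILTERS`` table are hypothetical and not part of ``enhance.py``; the multipliers are read off the added lines in the hunk above.

```python
# Hypothetical helper, not part of enhance.py: per-layer filter counts of the
# discriminator before and after this commit.

# Hard-coded counts from the removed lines.
OLD_FILTERS = {'disc1.1': 64, 'disc1.2': 64, 'disc2': 128,
               'disc3': 192, 'disc4': 192, 'disc5': 96}

# Multipliers of `c = args.discriminator_size` from the added lines.
MULTIPLIERS = {'disc1.1': 1, 'disc1.2': 1, 'disc2': 2, 'disc3': 3,
               'disc4': 4, 'disc5': 3, 'disc6': 2}


def discriminator_filters(c=32):
    """Return the number of filters per layer for multiplier `c`."""
    return {name: m * c for name, m in MULTIPLIERS.items()}


if __name__ == '__main__':
    for name, units in discriminator_filters().items():
        # '-' marks a layer that did not exist before the commit (disc6).
        print(name, OLD_FILTERS.get(name, '-'), '->', units)
```

With the default ``c=32`` the early layers are half as wide as before (32 instead of 64 filters), while passing ``--discriminator-size=64`` would reproduce the old widths for ``disc1.1`` through ``disc3`` only, since ``disc4`` onward use a different layout.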
@@ -351,20 +335,23 @@ def save_generator(self):
351335
def cast(p): return p.get_value().astype(np.float16)
352336
params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()}
353337
config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters']}
354-
filename = args.model % 2**args.scales
338+
filename = 'ne%ix-%s-%s.pkl.bz2' % (2**args.scales, args.model, __version__)
355339
pickle.dump((config, params), bz2.open(filename, 'wb'))
356340
print(' - Saved model as `{}` after training.'.format(filename))
357341

358342
def load_model(self):
359-
filename = args.model % 2**args.scales
360-
if not os.path.exists(filename): return {}, {}
343+
filename = 'ne%ix-%s-%s.pkl.bz2' % (2**args.scales, args.model, __version__)
344+
if not os.path.exists(filename):
345+
if args.train: return {}, {}
346+
error("Model file with pre-trained convolution layers not found. Download it here...",
347+
"https://github.com/alexjc/neural-enhance/releases/download/v%s/%s"%(__version__, filename))
361348
print(' - Loaded file `{}` with trained model.'.format(filename))
362349
return pickle.load(bz2.open(filename, 'rb'))
363350

364351
def load_generator(self, params):
365352
if len(params) == 0: return
366353
for k, l in self.list_generator_layers():
367-
assert k in params, "Couldn't find layer `%s` in loaded model.'"
354+
assert k in params, "Couldn't find layer `%s` in loaded model.'" % k
368355
assert len(l.get_params()) == len(params[k]), "Mismatch in types of layers."
369356
for p, v in zip(l.get_params(), params[k]):
370357
assert v.shape == p.get_value().shape, "Mismatch in number of parameters."
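
As a quick illustration of the new naming scheme, the sketch below reproduces the two format expressions from this commit: the model filename built in ``save_generator`` and ``load_model``, and the output image name from the final loop of ``enhance.py``. The wrapper functions ``model_filename`` and ``output_filename`` and the input path ``photo.jpg`` are hypothetical names used only for this example.

```python
import os

__version__ = '0.1'  # same value as the constant added in this commit


def model_filename(scales, model, version=__version__):
    # Same format string as save_generator() and load_model() after the commit.
    # `scales` counts 2x upsampling steps, so the zoom factor is 2**scales.
    return 'ne%ix-%s-%s.pkl.bz2' % (2**scales, model, version)


def output_filename(input_path, scales, model):
    # Same expression as the loop over args.files at the end of enhance.py.
    return os.path.splitext(input_path)[0] + '_ne%ix-%s.png' % (2**scales, model)


if __name__ == '__main__':
    print(model_filename(2, 'medium'))                # ne4x-medium-0.1.pkl.bz2
    print(output_filename('photo.jpg', 2, 'medium'))  # photo_ne4x-medium.png
```

Because ``--model`` is now a short label rather than a filename template, the zoom factor and ``__version__`` are always baked into the filename, which is also what the download URL in the error message relies on.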
@@ -381,10 +368,10 @@ def loss_total_variation(self, x):
381368
return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
382369

383370
def loss_adversarial(self, d):
384-
return 1.0 - T.log(1E-6 + d[args.batch_size:]).mean()
371+
return T.mean(1.0 - T.nnet.softplus(d[args.batch_size:]))
385372

386373
def loss_discriminator(self, d):
387-
return T.mean(T.log(1E-6 + d[args.batch_size:]) + T.log(1E-6 + 1.0 - d[:args.batch_size]))
374+
return T.mean(T.nnet.softminus(d[args.batch_size:]) - T.nnet.softplus(d[:args.batch_size]))
388375

389376
def compile(self):
390377
# Helper function for rendering test images during training, or standalone non-training mode.
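
The loss changes above go together with the switch of the ``disc`` output from ``sigmoid`` to ``linear``: the discriminator now emits logits, and the log terms are expressed through ``softplus`` and the newly patched ``softminus``. The sketch below is a plain-Python check of the identities this seems to rely on, namely ``log(sigmoid(x)) == softminus(x)`` and ``log(1 - sigmoid(x)) == -softplus(x)``. It is not Theano code; the ``softplus`` here is a standard numerically stable formulation written for this example, and ``sigmoid`` is a helper added only for the comparison.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softminus(x):
    # Same definition as the lambda assigned to T.nnet.softminus in this commit.
    return x - softplus(x)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


if __name__ == '__main__':
    for logit in (-3.0, -0.5, 0.0, 0.5, 3.0):
        p = sigmoid(logit)
        # Each pair of columns should agree up to floating-point error.
        print(logit, math.log(p), softminus(logit),
              math.log(1.0 - p), -softplus(logit))
```

Under these identities the new ``loss_discriminator`` matches the old expression applied to ``sigmoid(d)``, apart from the dropped ``1E-6`` epsilon. The new ``loss_adversarial`` uses ``softplus`` where the old one took ``log(d)``, so it is not a one-to-one translation of the previous generator term.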
@@ -424,8 +411,12 @@ def compile(self):
424411
class NeuralEnhancer(object):
425412

426413
def __init__(self):
427-
print('{}Training {} epochs on random image sections with batch size {}.{}'\
428-
.format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.BLUE))
414+
if args.train:
415+
print('{}Training {} epochs on random image sections with batch size {}.{}'\
416+
.format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.BLUE))
417+
else:
418+
print('{}Enhancing {} image(s) specified on the command-line.{}'\
419+
.format(ansi.BLUE_B, len(args.files), ansi.BLUE))
429420

430421
self.thread = DataLoader()
431422
self.model = Model()
@@ -483,7 +474,7 @@ def train(self):
483474
print(' - generator {}'.format(' '.join(gen_info)))
484475

485476
real, fake = stats[:args.batch_size], stats[args.batch_size:]
486-
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0]))
477+
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < -0.5)[0]))
487478
if epoch == args.adversarial_start-1:
488479
print(' - adversary mode: generator engaging discriminator.')
489480
self.model.adversary_weight.set_value(args.adversary_weight)
@@ -511,5 +502,9 @@ def process(self, image):
511502
enhancer.train()
512503

513504
for filename in args.files:
505+
print(filename)
514506
out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
515-
out.save(os.path.splitext(filename)[0]+'_enhanced.png')
507+
out.save(os.path.splitext(filename)[0]+'_ne%ix-%s.png'%(2**args.scales, args.model))
508+
509+
if args.files:
510+
print(ansi.ENDC)
