#!/usr/bin/env python
# coding: utf-8

# # Neural Style Transfer with tf.keras
#
# <table class="tfo-notebook-buttons" align="left">
# <td>
# <a target="_blank" href="https://colab.research.google.com/github/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
# </td>
# <td>
# <a target="_blank" href="https://github.com/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
# </td>
# </table>

# ## Overview
#
# In this tutorial, we will learn how to use deep learning to compose images in the style of another image (ever wish you could paint like Picasso or Van Gogh?). This is known as **neural style transfer**! The technique is outlined in [Leon A. Gatys' paper, A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576), which is a great read, and you should definitely check it out.
#
# But what is neural style transfer?
#
# Neural style transfer is an optimization technique that takes three images -- a **content** image, a **style reference** image (such as an artwork by a famous painter), and the **input** image you want to style -- and blends them together such that the input image is transformed to look like the content image, but “painted” in the style of the style image.
#
#
# For example, let’s take an image of this turtle and Katsushika Hokusai's *The Great Wave off Kanagawa*:
#
# <img src="https://github.com/tensorflow/models/blob/master/research/nst_blogpost/Green_Sea_Turtle_grazing_seagrass.jpg?raw=1" alt="Drawing" style="width: 200px;"/>
# <img src="https://github.com/tensorflow/models/blob/master/research/nst_blogpost/The_Great_Wave_off_Kanagawa.jpg?raw=1" alt="Drawing" style="width: 200px;"/>
#
# [Image of Green Sea Turtle](https://commons.wikimedia.org/wiki/File:Green_Sea_Turtle_grazing_seagrass.jpg)
# -By P.Lindgren [CC BY-SA 3.0 (https://creativecommons.org/licenses/by-sa/3.0)], from Wikimedia Commons
#
#
# Now what would it look like if Hokusai decided to paint this turtle exclusively in this style? Something like this?
#
# <img src="https://github.com/tensorflow/models/blob/master/research/nst_blogpost/wave_turtle.png?raw=1" alt="Drawing" style="width: 500px;"/>
#
# Is this magic or just deep learning? Fortunately, this doesn’t involve any witchcraft: style transfer is a fun and interesting technique that showcases the capabilities and internal representations of neural networks.
#
# The principle of neural style transfer is to define two distance functions: one that describes how different the content of two images is, $L_{content}$, and one that describes the difference between two images in terms of their style, $L_{style}$. Then, given three images -- a desired style image, a desired content image, and the input image (initialized with the content image) -- we try to transform the input image to minimize its content distance from the content image and its style distance from the style image.
# In summary, we’ll take the base input image, a content image that we want to match, and the style image that we want to match. We’ll transform the base input image by minimizing the content and style distances (losses) with backpropagation, creating an image that matches the content of the content image and the style of the style image.
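#
# Following the paper, these two distances are combined into a single objective with weighting coefficients $\alpha$ and $\beta$ (the `content_weight` and `style_weight` that appear later in this notebook):
# $$L_{total}(p, a, x) = \alpha L_{content}(p, x) + \beta L_{style}(a, x)$$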
#
# ### Specific concepts that will be covered:
# In the process, we will build practical experience and develop intuition around the following concepts:
#
# * **Eager Execution** - use TensorFlow's imperative programming environment that evaluates operations immediately
#   * [Learn more about eager execution](https://www.tensorflow.org/programmers_guide/eager)
#   * [See it in action](https://www.tensorflow.org/get_started/eager)
# * **Using the [Functional API](https://keras.io/getting-started/functional-api-guide/) to define a model** - we'll build a subset of our model that will give us access to the necessary intermediate activations
# * **Leveraging feature maps of a pretrained model** - learn how to use pretrained models and their feature maps
# * **Creating custom training loops** - we'll examine how to set up an optimizer to minimize a given loss with respect to input parameters
#
# ### We will follow the general steps to perform style transfer:
#
# 1. Visualize data
# 2. Basic preprocessing/preparing our data
# 3. Set up loss functions
# 4. Create model
# 5. Optimize for loss function
#
# **Audience:** This post is geared towards intermediate users who are comfortable with basic machine learning concepts. To get the most out of this post, you should:
# * Read [Gatys' paper](https://arxiv.org/abs/1508.06576) - we'll explain along the way, but the paper will provide a more thorough understanding of the task
# * [Understand reducing loss with gradient descent](https://developers.google.com/machine-learning/crash-course/reducing-loss/gradient-descent)
#
# **Estimated time**: 30 min
#

# ## Setup
#
# ### Download Images

# In[ ]:


import os
img_dir = '/tmp/nst'
if not os.path.exists(img_dir):
  os.makedirs(img_dir)
!wget --quiet -P /tmp/nst/ https://upload.wikimedia.org/wikipedia/commons/d/d7/Green_Sea_Turtle_grazing_seagrass.jpg
!wget --quiet -P /tmp/nst/ https://upload.wikimedia.org/wikipedia/commons/0/0a/The_Great_Wave_off_Kanagawa.jpg
!wget --quiet -P /tmp/nst/ https://upload.wikimedia.org/wikipedia/commons/b/b4/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg
!wget --quiet -P /tmp/nst/ https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg
!wget --quiet -P /tmp/nst/ https://upload.wikimedia.org/wikipedia/commons/6/68/Pillars_of_creation_2014_HST_WFC3-UVIS_full-res_denoised.jpg
!wget --quiet -P /tmp/nst/ https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1024px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg

# ### Import and configure modules

# In[ ]:


import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (10,10)
mpl.rcParams['axes.grid'] = False

import numpy as np
from PIL import Image
import time
import functools

# In[ ]:


import tensorflow as tf

from tensorflow.python.keras.preprocessing import image as kp_image
from tensorflow.python.keras import models
from tensorflow.python.keras import losses
from tensorflow.python.keras import layers
from tensorflow.python.keras import backend as K

# We’ll begin by enabling [eager execution](https://www.tensorflow.org/guide/eager). Eager execution allows us to work through this technique in the clearest and most readable way.

# In[ ]:


tf.enable_eager_execution()
print("Eager execution: {}".format(tf.executing_eagerly()))

# In[ ]:


# Set up some global values here
content_path = '/tmp/nst/Green_Sea_Turtle_grazing_seagrass.jpg'
style_path = '/tmp/nst/The_Great_Wave_off_Kanagawa.jpg'

# ## Visualize the input

# In[ ]:


def load_img(path_to_img):
  max_dim = 512
  img = Image.open(path_to_img)
  long = max(img.size)
  scale = max_dim/long
  img = img.resize((round(img.size[0]*scale), round(img.size[1]*scale)), Image.ANTIALIAS)

  img = kp_image.img_to_array(img)

  # We need to broadcast the image array such that it has a batch dimension
  img = np.expand_dims(img, axis=0)
  return img
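
# As a quick sanity check (a small sketch, assuming the downloads above succeeded), the loader returns a batched float array whose longest side is 512 pixels:

# In[ ]:


print(load_img(content_path).shape)  # e.g. (1, 384, 512, 3): (batch, height, width, channels)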

# In[ ]:


def imshow(img, title=None):
  # Remove the batch dimension
  out = np.squeeze(img, axis=0)
  # Normalize for display
  out = out.astype('uint8')
  plt.imshow(out)
  if title is not None:
    plt.title(title)

# These are the input content and style images. We hope to "create" an image with the content of our content image, but with the style of the style image.

# In[ ]:


plt.figure(figsize=(10,10))

content = load_img(content_path).astype('uint8')
style = load_img(style_path).astype('uint8')

plt.subplot(1, 2, 1)
imshow(content, 'Content Image')

plt.subplot(1, 2, 2)
imshow(style, 'Style Image')
plt.show()

# ## Prepare the data
# Let's create methods that will allow us to load and preprocess our images easily. We perform the same preprocessing expected by the VGG training process. VGG networks are trained on images with each channel normalized by `mean = [103.939, 116.779, 123.68]` and with channels in BGR order.

# In[ ]:


def load_and_process_img(path_to_img):
  img = load_img(path_to_img)
  img = tf.keras.applications.vgg19.preprocess_input(img)
  return img
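
# For reference, the VGG preprocessing above amounts to flipping RGB to BGR and subtracting the ImageNet channel means; here is a minimal sketch of the equivalent manual steps (for illustration only -- we use the library call above):

# In[ ]:


def manual_vgg_preprocess(img):
  # img: float array of shape (1, h, w, 3), RGB, values in [0, 255]
  img = img[..., ::-1]  # RGB -> BGR
  return img - np.array([103.939, 116.779, 123.68])  # subtract per-channel (BGR) means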

# In order to view the outputs of our optimization, we are required to perform the inverse preprocessing step. Furthermore, since our optimized image may take its values anywhere between $- \infty$ and $\infty$, we must clip to keep our values within the 0-255 range.

# In[ ]:


def deprocess_img(processed_img):
  x = processed_img.copy()
  if len(x.shape) == 4:
    x = np.squeeze(x, 0)
  assert len(x.shape) == 3, ("Input to deprocess image must be an image of "
                             "dimension [1, height, width, channel] or [height, width, channel]")
  if len(x.shape) != 3:
    raise ValueError("Invalid input to deprocessing image")

  # perform the inverse of the preprocessing step
  x[:, :, 0] += 103.939
  x[:, :, 1] += 116.779
  x[:, :, 2] += 123.68
  x = x[:, :, ::-1]

  x = np.clip(x, 0, 255).astype('uint8')
  return x
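
# A quick check that preprocessing and deprocessing are inverses of one another (up to uint8 rounding) -- a small sketch:

# In[ ]:


original = load_img(content_path).astype('uint8')
roundtrip = deprocess_img(load_and_process_img(content_path))
print(np.abs(original[0].astype(int) - roundtrip.astype(int)).max())  # expect 0 or 1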

# ### Define content and style representations
# In order to get both the content and style representations of our image, we will look at some intermediate layers within our model. As we go deeper into the model, these intermediate layers represent higher and higher order features. In this case, we are using the network architecture VGG19, a pretrained image classification network. These intermediate layers are necessary to define the representation of content and style from our images. For an input image, we will try to match the corresponding style and content target representations at these intermediate layers.
#
# #### Why intermediate layers?
#
# You may be wondering why these intermediate outputs within our pretrained image classification network allow us to define style and content representations. At a high level, this phenomenon can be explained by the fact that in order for a network to perform image classification (which our network has been trained to do), it must understand the image. This involves taking the raw image as input pixels and building an internal representation through transformations that turn the raw image pixels into a complex understanding of the features present within the image. This is also partly why convolutional neural networks are able to generalize well: they’re able to capture the invariances and defining features within classes (e.g., cats vs. dogs) that are agnostic to background noise and other nuisances. Thus, somewhere between where the raw image is fed in and the classification label is output, the model serves as a complex feature extractor; hence by accessing intermediate layers, we’re able to describe the content and style of input images.
#
#
# Specifically, we’ll pull out these intermediate layers from our network:
#

# In[ ]:


# Content layer where we will pull our feature maps
content_layers = ['block5_conv2']

# Style layers we are interested in
style_layers = ['block1_conv1',
                'block2_conv1',
                'block3_conv1',
                'block4_conv1',
                'block5_conv1'
               ]

num_content_layers = len(content_layers)
num_style_layers = len(style_layers)
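
# If you'd like to experiment with other layers, every candidate layer name in VGG19 can be listed in a couple of lines (a small sketch; this downloads the VGG19 weights on first use):

# In[ ]:


vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet')
print([layer.name for layer in vgg.layers])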

# ## Build the Model
# In this case, we load [VGG19](https://keras.io/applications/#vgg19) and feed in our input tensor to the model. This will allow us to extract the feature maps (and subsequently the content and style representations) of the content, style, and generated images.
#
# We use VGG19, as suggested in the paper. In addition, since VGG19 is a relatively simple model (compared with ResNet, Inception, etc.), the feature maps actually work better for style transfer.

# In order to access the intermediate layers corresponding to our style and content feature maps, we get the corresponding outputs and, using the Keras [**Functional API**](https://keras.io/getting-started/functional-api-guide/), define our model with the desired output activations.
#
# With the Functional API, defining a model simply involves defining the input and output:
#
# `model = Model(inputs, outputs)`

# In[ ]:


def get_model():
  """Creates our model with access to intermediate layers.

  This function will load the VGG19 model and access the intermediate layers.
  These layers will then be used to create a new model that will take an input
  image and return the outputs from these intermediate layers of the VGG model.

  Returns:
    returns a keras model that takes image inputs and outputs the style and
      content intermediate layers.
  """
  # Load our model. We load pretrained VGG, trained on imagenet data
  vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet')
  vgg.trainable = False
  # Get output layers corresponding to style and content layers
  style_outputs = [vgg.get_layer(name).output for name in style_layers]
  content_outputs = [vgg.get_layer(name).output for name in content_layers]
  model_outputs = style_outputs + content_outputs
  # Build model
  return models.Model(vgg.input, model_outputs)

# In the above code snippet, we load our pretrained image classification network. Then we grab the layers of interest, as defined earlier. Finally, we define a Model by setting the model’s inputs to an image and the outputs to the outputs of the style and content layers. In other words, we created a model that takes an input image and outputs the content and style intermediate layers!
#

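# Since we're executing eagerly, the model is callable right away. A small sketch to inspect what it returns (the name `sanity_model` is ours, and the exact shapes depend on the input image):

# In[ ]:


sanity_model = get_model()
for name, out in zip(style_layers + content_layers, sanity_model(load_and_process_img(content_path))):
  print(name, out.shape)
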
# ## Define and create our loss functions (content and style distances)

# ### Content Loss

# Our content loss definition is actually quite simple. We’ll pass the network both the desired content image and our base input image. This will return the intermediate layer outputs (from the layers defined above) from our model. Then we simply take the Euclidean distance between the two intermediate representations of those images.
#
# More formally, content loss is a function that describes the distance between the content of our output image $x$ and our content image $p$. Let $C_{nn}$ be a pre-trained deep convolutional neural network. Again, in this case we use [VGG19](https://keras.io/applications/#vgg19). Let $x$ be any image; then $C_{nn}(x)$ is the network fed by $x$. Let $F^l_{ij}(x) \in C_{nn}(x)$ and $P^l_{ij}(p) \in C_{nn}(p)$ describe the respective intermediate feature representation of the network with inputs $x$ and $p$ at layer $l$. Then we describe the content distance (loss) formally as: $$L^l_{content}(p, x) = \sum_{i, j} (F^l_{ij}(x) - P^l_{ij}(p))^2$$
#
# We perform backpropagation in the usual way such that we minimize this content loss. We thus change the initial image until it generates a similar response in a certain layer (defined in `content_layers`) as the original content image.
#
# This can be implemented quite simply. Again, the loss takes as input the feature maps at a layer $l$ in a network fed by $x$, our input image, and $p$, our content image, and returns the content distance.
#
#

# ### Computing content loss
# We will actually add our content losses at each desired layer. This way, each iteration, when we feed our input image through the model (which in eager is simply `model(input_image)`!), all the content losses through the model will be properly computed, and because we are executing eagerly, all the gradients will be computed as well.

# In[ ]:


def get_content_loss(base_content, target):
  return tf.reduce_mean(tf.square(base_content - target))

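# For intuition: this is just a mean squared difference. A toy check with made-up values:

# In[ ]:


print(get_content_loss(tf.constant([1., 2.]), tf.constant([0., 0.])))  # (1 + 4) / 2 = 2.5
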
# ## Style Loss

# Computing style loss is a bit more involved, but follows the same principle; this time we feed our network the base input image and the style image. However, instead of comparing the raw intermediate outputs of the base input image and the style image, we compare the Gram matrices of the two outputs.
#
# Mathematically, we describe the style loss of the base input image, $x$, and the style image, $a$, as the distance between the style representations (the Gram matrices) of these images. We describe the style representation of an image as the correlation between different filter responses, given by the Gram matrix $G^l$, where $G^l_{ij}$ is the inner product between the vectorized feature maps $i$ and $j$ in layer $l$. We can see that $G^l_{ij}$, generated over the feature maps of a given image, represents the correlation between feature maps $i$ and $j$.
#
# To generate a style for our base input image, we perform gradient descent from the content image to transform it into an image that matches the style representation of the style image. We do so by minimizing the mean squared distance between the feature correlation maps of the style image and the input image. The contribution of each layer to the total style loss is described by
# $$E_l = \frac{1}{4N_l^2M_l^2} \sum_{i,j}(G^l_{ij} - A^l_{ij})^2$$
#
# where $G^l_{ij}$ and $A^l_{ij}$ are the respective style representations in layer $l$ of $x$ and $a$. $N_l$ describes the number of feature maps, each of size $M_l = height * width$. Thus, the total style loss across the layers is
# $$L_{style}(a, x) = \sum_{l \in L} w_l E_l$$
# where we weight the contribution of each layer's loss by some factor $w_l$. In our case, we weight each layer equally ($w_l = \frac{1}{|L|}$).

# ### Computing style loss
# Again, we implement our loss as a distance metric.

# In[ ]:


def gram_matrix(input_tensor):
  # Flatten the spatial dimensions, keeping the channel dimension
  channels = int(input_tensor.shape[-1])
  a = tf.reshape(input_tensor, [-1, channels])
  n = tf.shape(a)[0]
  gram = tf.matmul(a, a, transpose_a=True)
  return gram / tf.cast(n, tf.float32)

def get_style_loss(base_style, gram_target):
  """Expects two images of dimension h, w, c"""
  # height, width, num filters of each layer
  # We scale the loss at a given layer by the size of the feature map and the number of filters
  height, width, channels = base_style.get_shape().as_list()
  gram_style = gram_matrix(base_style)

  return tf.reduce_mean(tf.square(gram_style - gram_target))  # / (4. * (channels ** 2) * (width * height) ** 2)
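
# As a quick sanity check on `gram_matrix` (a toy example): a feature map of all ones has identical, perfectly correlated channels, so every entry of its normalized Gram matrix is 1.

# In[ ]:


print(gram_matrix(tf.ones([4, 4, 3])))  # a 3x3 matrix of ones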

# ## Apply style transfer to our images
#

# ### Run Gradient Descent
# If you aren't familiar with gradient descent/backpropagation or need a refresher, you should definitely check out this [awesome resource](https://developers.google.com/machine-learning/crash-course/reducing-loss/gradient-descent).
#
# In this case, we use the [Adam](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam)* optimizer in order to minimize our loss. We iteratively update our output image such that it minimizes our loss: we don't update the weights associated with our network, but instead we train our input image to minimize loss. In order to do this, we must know how we calculate our loss and gradients.
#
# \* Note that L-BFGS (which is recommended if you are familiar with that algorithm) isn’t used in this tutorial, because a primary motivation behind this tutorial was to illustrate best practices with eager execution, and, by using Adam, we can demonstrate the autograd/gradient tape functionality with custom training loops.
#

# We’ll define a little helper function that will load our content and style images and feed them forward through our network, which will then output the content and style feature representations from our model.

# In[ ]:


def get_feature_representations(model, content_path, style_path):
  """Helper function to compute our content and style feature representations.

  This function will simply load and preprocess both the content and style
  images from their path. Then it will feed them through the network to obtain
  the outputs of the intermediate layers.

  Arguments:
    model: The model that we are using.
    content_path: The path to the content image.
    style_path: The path to the style image.

  Returns:
    returns the style features and the content features.
  """
  # Load our images in
  content_image = load_and_process_img(content_path)
  style_image = load_and_process_img(style_path)

  # batch compute content and style features
  style_outputs = model(style_image)
  content_outputs = model(content_image)

  # Get the style and content feature representations from our model
  style_features = [style_layer[0] for style_layer in style_outputs[:num_style_layers]]
  content_features = [content_layer[0] for content_layer in content_outputs[num_style_layers:]]
  return style_features, content_features

# ### Computing the loss and gradients
# Here we use [**tf.GradientTape**](https://www.tensorflow.org/programmers_guide/eager#computing_gradients) to compute the gradients. It lets us take advantage of automatic differentiation by tracing operations: the tape records the operations during the forward pass, and is then able to compute the gradient of our loss function with respect to our input image for the backward pass.

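# A minimal illustration of the tape API on its own (a sketch with a toy variable, unrelated to the images):

# In[ ]:


w = tf.Variable(3.0)
with tf.GradientTape() as tape:
  y = w * w
print(tape.gradient(y, w))  # dy/dw = 2w = 6.0
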
# In[ ]:


def compute_loss(model, loss_weights, init_image, gram_style_features, content_features):
  """This function will compute the total loss.

  Arguments:
    model: The model that will give us access to the intermediate layers
    loss_weights: The weights of each contribution of each loss function
      (style weight and content weight)
    init_image: Our initial base image. This image is what we are updating with
      our optimization process. We apply the gradients wrt the loss we are
      calculating to this image.
    gram_style_features: Precomputed gram matrices corresponding to the
      defined style layers of interest.
    content_features: Precomputed outputs from defined content layers of
      interest.

  Returns:
    returns the total loss, style loss, and content loss
  """
  style_weight, content_weight = loss_weights

  # Feed our init image through our model. This will give us the content and
  # style representations at our desired layers. Since we're using eager
  # our model is callable just like any other function!
  model_outputs = model(init_image)

  style_output_features = model_outputs[:num_style_layers]
  content_output_features = model_outputs[num_style_layers:]

  style_score = 0
  content_score = 0

  # Accumulate style losses from all layers
  # Here, we equally weight each contribution of each loss layer
  weight_per_style_layer = 1.0 / float(num_style_layers)
  for target_style, comb_style in zip(gram_style_features, style_output_features):
    style_score += weight_per_style_layer * get_style_loss(comb_style[0], target_style)

  # Accumulate content losses from all layers
  weight_per_content_layer = 1.0 / float(num_content_layers)
  for target_content, comb_content in zip(content_features, content_output_features):
    content_score += weight_per_content_layer * get_content_loss(comb_content[0], target_content)

  style_score *= style_weight
  content_score *= content_weight

  # Get total loss
  loss = style_score + content_score
  return loss, style_score, content_score

# Then computing the gradients is easy:

# In[ ]:


def compute_grads(cfg):
  with tf.GradientTape() as tape:
    all_loss = compute_loss(**cfg)
  # Compute gradients wrt input image
  total_loss = all_loss[0]
  return tape.gradient(total_loss, cfg['init_image']), all_loss

# ### Optimization loop

# In[ ]:


import IPython.display

def run_style_transfer(content_path,
                       style_path,
                       num_iterations=1000,
                       content_weight=1e3,
                       style_weight=1e-2):
  # We don't need to (or want to) train any layers of our model, so we set their
  # trainable to false.
  model = get_model()
  for layer in model.layers:
    layer.trainable = False

  # Get the style and content feature representations (from our specified intermediate layers)
  style_features, content_features = get_feature_representations(model, content_path, style_path)
  gram_style_features = [gram_matrix(style_feature) for style_feature in style_features]

  # Set initial image
  init_image = load_and_process_img(content_path)
  init_image = tf.Variable(init_image, dtype=tf.float32)
  # Create our optimizer
  opt = tf.train.AdamOptimizer(learning_rate=5, beta1=0.99, epsilon=1e-1)

  # For displaying intermediate images
  iter_count = 1

  # Store our best result
  best_loss, best_img = float('inf'), None

  # Create a nice config
  loss_weights = (style_weight, content_weight)
  cfg = {
      'model': model,
      'loss_weights': loss_weights,
      'init_image': init_image,
      'gram_style_features': gram_style_features,
      'content_features': content_features
  }

  # For displaying
  num_rows = 2
  num_cols = 5
  display_interval = num_iterations // (num_rows * num_cols)
  start_time = time.time()
  global_start = time.time()

  norm_means = np.array([103.939, 116.779, 123.68])
  min_vals = -norm_means
  max_vals = 255 - norm_means

  imgs = []
  for i in range(num_iterations):
    grads, all_loss = compute_grads(cfg)
    loss, style_score, content_score = all_loss
    opt.apply_gradients([(grads, init_image)])
    clipped = tf.clip_by_value(init_image, min_vals, max_vals)
    init_image.assign(clipped)
    end_time = time.time()

    if loss < best_loss:
      # Update best loss and best image from total loss.
      best_loss = loss
      best_img = deprocess_img(init_image.numpy())

    if i % display_interval == 0:
      start_time = time.time()

      # Use the .numpy() method to get the concrete numpy array
      plot_img = init_image.numpy()
      plot_img = deprocess_img(plot_img)
      imgs.append(plot_img)
      IPython.display.clear_output(wait=True)
      IPython.display.display_png(Image.fromarray(plot_img))
      print('Iteration: {}'.format(i))
      print('Total loss: {:.4e}, '
            'style loss: {:.4e}, '
            'content loss: {:.4e}, '
            'time: {:.4f}s'.format(loss, style_score, content_score, time.time() - start_time))
  print('Total time: {:.4f}s'.format(time.time() - global_start))
  IPython.display.clear_output(wait=True)
  plt.figure(figsize=(14,4))
  for i, img in enumerate(imgs):
    plt.subplot(num_rows, num_cols, i+1)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

  return best_img, best_loss

# In[ ]:


best, best_loss = run_style_transfer(content_path,
                                     style_path, num_iterations=1000)

# In[ ]:


Image.fromarray(best)

# To download the image from Colab, uncomment the following code:

# In[ ]:


#from google.colab import files
#files.download('wave_turtle.png')

# ## Visualize outputs
# We "deprocess" the output image in order to remove the processing that was applied to it.

# In[ ]:


def show_results(best_img, content_path, style_path, show_large_final=True):
  plt.figure(figsize=(10, 5))
  content = load_img(content_path)
  style = load_img(style_path)

  plt.subplot(1, 2, 1)
  imshow(content, 'Content Image')

  plt.subplot(1, 2, 2)
  imshow(style, 'Style Image')

  if show_large_final:
    plt.figure(figsize=(10, 10))

    plt.imshow(best_img)
    plt.title('Output Image')
    plt.show()

# In[ ]:


show_results(best, content_path, style_path)

# ## Try it on other images
# Image of Tuebingen
#
# Photo By: Andreas Praefcke [GFDL (http://www.gnu.org/copyleft/fdl.html) or CC BY 3.0 (https://creativecommons.org/licenses/by/3.0)], from Wikimedia Commons

# ### Starry night + Tuebingen

# In[ ]:


best_starry_night, best_loss = run_style_transfer('/tmp/nst/Tuebingen_Neckarfront.jpg',
                                                  '/tmp/nst/1024px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg')

# In[ ]:


show_results(best_starry_night, '/tmp/nst/Tuebingen_Neckarfront.jpg',
             '/tmp/nst/1024px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg')

# ### Pillars of Creation + Tuebingen

# In[ ]:


best_poc_tubingen, best_loss = run_style_transfer('/tmp/nst/Tuebingen_Neckarfront.jpg',
                                                  '/tmp/nst/Pillars_of_creation_2014_HST_WFC3-UVIS_full-res_denoised.jpg')

# In[ ]:


show_results(best_poc_tubingen,
             '/tmp/nst/Tuebingen_Neckarfront.jpg',
             '/tmp/nst/Pillars_of_creation_2014_HST_WFC3-UVIS_full-res_denoised.jpg')

# ### Kandinsky Composition 7 + Tuebingen

# In[ ]:


best_kandinsky_tubingen, best_loss = run_style_transfer('/tmp/nst/Tuebingen_Neckarfront.jpg',
                                                        '/tmp/nst/Vassily_Kandinsky,_1913_-_Composition_7.jpg')

# In[ ]:


show_results(best_kandinsky_tubingen,
             '/tmp/nst/Tuebingen_Neckarfront.jpg',
             '/tmp/nst/Vassily_Kandinsky,_1913_-_Composition_7.jpg')

# ### Pillars of Creation + Sea Turtle

# In[ ]:


best_poc_turtle, best_loss = run_style_transfer('/tmp/nst/Green_Sea_Turtle_grazing_seagrass.jpg',
                                                '/tmp/nst/Pillars_of_creation_2014_HST_WFC3-UVIS_full-res_denoised.jpg')

# In[ ]:


show_results(best_poc_turtle,
             '/tmp/nst/Green_Sea_Turtle_grazing_seagrass.jpg',
             '/tmp/nst/Pillars_of_creation_2014_HST_WFC3-UVIS_full-res_denoised.jpg')

# ## Key Takeaways
#
# ### What we covered:
#
# * We built several different loss functions and used backpropagation to transform our input image in order to minimize these losses
#   * In order to do this, we had to load in a **pretrained model** and use its learned feature maps to describe the content and style representations of our images
#   * Our main loss functions were primarily computing the distance in terms of these different representations
# * We implemented this with a custom model and **eager execution**
#   * We built our custom model with the Functional API
#   * Eager execution allows us to dynamically work with tensors, using a natural Python control flow
#   * We manipulated tensors directly, which makes debugging and working with tensors easier
# * We iteratively updated our image by applying our optimizer's update rules using **tf.GradientTape**. The optimizer minimized the given losses with respect to our input image.

#
# **[Image of Tuebingen](https://commons.wikimedia.org/wiki/File:Tuebingen_Neckarfront.jpg)**
# Photo By: Andreas Praefcke [GFDL (http://www.gnu.org/copyleft/fdl.html) or CC BY 3.0 (https://creativecommons.org/licenses/by/3.0)], from Wikimedia Commons
#
# **[Image of Green Sea Turtle](https://commons.wikimedia.org/wiki/File:Green_Sea_Turtle_grazing_seagrass.jpg)**
# By P.Lindgren [CC BY-SA 3.0 (https://creativecommons.org/licenses/by-sa/3.0)], from Wikimedia Commons
#
#

# In[ ]: