# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
'''
Created on Aug 10, 2016
author: jakeret
'''
from __future__ import print_function, division, absolute_import, unicode_literals
import os
import numpy as np
from PIL import Image
def plot_prediction(x_test, y_test, prediction, save=False):
    """
    Plots input, ground truth (channel 1) and min-max normalized prediction
    (channel 1) side by side, one row per sample.

    The input arrays are not modified.

    :param x_test: input tensor [batches, nx, ny, channels]; cropped to the
        spatial size of `prediction` before plotting
    :param y_test: ground truth tensor [batches, nx, ny, classes]; cropped to
        the spatial size of `prediction`, channel 1 is plotted
    :param prediction: prediction tensor [batches, nx, ny, classes]; channel 1
        is plotted
    :param save: (optional) file path the figure is written to; if falsy the
        figure is shown interactively instead
    """
    import matplotlib.pyplot as plt

    test_size = x_test.shape[0]
    fig, ax = plt.subplots(test_size, 3, figsize=(12, 12), sharey=True, sharex=True)

    x_test = crop_to_shape(x_test, prediction.shape)
    y_test = crop_to_shape(y_test, prediction.shape)

    # With a single row subplots returns a 1-D array of axes; force 2-D indexing.
    ax = np.atleast_2d(ax)
    for i in range(test_size):
        x_img = x_test[i]
        # imshow rejects (nx, ny, 1); plot single-channel input as a 2-D map.
        if x_img.ndim == 3 and x_img.shape[-1] == 1:
            x_img = x_img[..., 0]
        cax = ax[i, 0].imshow(x_img)
        plt.colorbar(cax, ax=ax[i, 0])

        cax = ax[i, 1].imshow(y_test[i, ..., 1])
        plt.colorbar(cax, ax=ax[i, 1])

        # Normalize a float copy so the caller's prediction array is untouched
        # (and integer predictions can be divided).
        pred = np.array(prediction[i, ..., 1], dtype=float)
        pred -= np.amin(pred)
        max_val = np.amax(pred)
        if max_val != 0:  # a constant map would otherwise become all-NaN
            pred /= max_val
        cax = ax[i, 2].imshow(pred)
        plt.colorbar(cax, ax=ax[i, 2])

        if i == 0:
            ax[i, 0].set_title("x")
            ax[i, 1].set_title("y")
            ax[i, 2].set_title("pred")

    fig.tight_layout()

    if save:
        fig.savefig(save)
    else:
        fig.show()
        plt.show()
def to_rgb(img):
    """
    Converts the given array into a RGB image.

    If the array has fewer than 3 channels it is tiled 3 times along the
    channel axis, so a single-channel array becomes a 3-channel one. NaNs are
    replaced by 0. Finally, the values are min-max rescaled to [0, 255]; an
    all-constant array maps to 0.

    The input array is never modified; a new floating point array is returned.

    :param img: the array to convert [nx, ny, channels]
    :returns img: the rgb image [nx, ny, 3]
    """
    img = np.atleast_3d(img)
    # atleast_3d returns the caller's own array for 3-D input, so copy it.
    # Promoting to a float dtype also makes the in-place division valid for
    # integer inputs while keeping float32/float64 inputs at their precision.
    img = img.astype(np.result_type(img.dtype, np.float32))

    channels = img.shape[2]
    if channels < 3:
        img = np.tile(img, 3)

    img[np.isnan(img)] = 0
    img -= np.amin(img)
    max_val = np.amax(img)
    if max_val != 0:
        img /= max_val
    img *= 255
    return img
def crop_to_shape(data, shape):
    """
    Crops the array to the given image shape by removing the border (expects a
    tensor of shape [batches, nx, ny, channels]).

    When the size difference along an axis is odd, the extra row/column is
    removed from the right/bottom side. Axes whose size already matches are
    left untouched.

    :param data: the array to crop
    :param shape: the target shape; only shape[1] and shape[2] are used
    :returns: the cropped array [batches, shape[1], shape[2], ...]
    :raises AssertionError: if the result does not match the target size
        (e.g. when `data` is smaller than `shape`)
    """
    nx, ny = data.shape[1], data.shape[2]
    diff_nx = nx - shape[1]
    diff_ny = ny - shape[2]

    offset_nx_left = diff_nx // 2
    offset_nx_right = diff_nx - offset_nx_left
    offset_ny_left = diff_ny // 2
    offset_ny_right = diff_ny - offset_ny_left

    # Use explicit end indices: a stop of "-0" would select nothing when diff == 0.
    cropped = data[:,
                   offset_nx_left:nx - offset_nx_right,
                   offset_ny_left:ny - offset_ny_right]

    assert cropped.shape[1] == shape[1], "cannot crop nx %d to %d" % (nx, shape[1])
    assert cropped.shape[2] == shape[2], "cannot crop ny %d to %d" % (ny, shape[2])
    return cropped
def expand_to_shape(data, shape, border=0):
    """
    Expands the array to the given image shape by padding it with a border
    (expects a tensor of shape [batches, nx, ny, channels]).

    When the size difference along an axis is odd, the extra row/column of
    padding goes to the right/bottom side. Axes whose size already matches get
    no padding.

    :param data: the array to expand
    :param shape: the target shape of the returned array
    :param border: (optional) fill value of the padded border, defaults to 0
    :returns: a new float32 array of the given shape with `data` centered in it
    :raises ValueError: if `data` is larger than `shape` along nx or ny
    """
    diff_nx = shape[1] - data.shape[1]
    diff_ny = shape[2] - data.shape[2]
    if diff_nx < 0 or diff_ny < 0:
        raise ValueError("cannot expand data of shape %s to smaller shape %s"
                         % (tuple(data.shape), tuple(shape)))

    offset_nx_left = diff_nx // 2
    offset_nx_right = diff_nx - offset_nx_left
    offset_ny_left = diff_ny // 2
    offset_ny_right = diff_ny - offset_ny_left

    expanded = np.full(shape, border, dtype=np.float32)
    # Use explicit end indices: a stop of "-0" would select nothing when diff == 0.
    expanded[:,
             offset_nx_left:shape[1] - offset_nx_right,
             offset_ny_left:shape[2] - offset_ny_right] = data
    return expanded
def combine_img_prediction(data, gt, pred):
    """
    Combines the data, ground truth and prediction into one rgb image.

    Each tensor is cropped to the spatial size of `pred` where needed. The
    batch dimension is stacked along the image rows, and the three panels
    (data, ground truth channel 1, prediction channel 1) are placed side by
    side.

    :param data: the data tensor [batches, nx, ny, channels]
    :param gt: the ground truth tensor; channel 1 is used
    :param pred: the prediction tensor; channel 1 is used
    :returns img: the concatenated rgb image
    """
    target_shape = pred.shape
    width = target_shape[2]
    n_channels = data.shape[3]

    # Keep the conversion order data -> gt -> pred, one panel at a time.
    rgb_data = to_rgb(crop_to_shape(data, target_shape).reshape(-1, width, n_channels))
    rgb_gt = to_rgb(crop_to_shape(gt[..., 1], target_shape).reshape(-1, width, 1))
    rgb_pred = to_rgb(pred[..., 1].reshape(-1, width, 1))

    return np.concatenate((rgb_data, rgb_gt, rgb_pred), axis=1)
def save_image(img, path):
    """
    Writes the image to disk as a JPEG (quality 90, 300 dpi). The JPEG format
    is used regardless of the extension of `path`.

    :param img: the rgb image to save; values are expected in [0, 255] and
        anything outside that range is clipped
    :param path: the target path
    """
    # Clip before casting: out-of-range values would otherwise wrap around in uint8.
    pixels = np.clip(np.asarray(img).round(), 0, 255).astype(np.uint8)
    # PIL cannot build an image from a trailing singleton channel axis.
    if pixels.ndim == 3 and pixels.shape[2] == 1:
        pixels = pixels[..., 0]
    Image.fromarray(pixels).save(path, 'JPEG', dpi=[300, 300], quality=90)
def create_training_path(output_path, prefix="run_"):
    """
    Returns the first not-yet-existing path of the form
    `output_path`/`prefix` + zero-padded 3-digit counter, starting at 000.

    Nothing is created on disk; the path is only computed.

    :param output_path: the root path
    :param prefix: (optional) defaults to `run_`
    :return: the generated path as string in form `output_path`/`prefix_` + `<number>`
    """
    counter = 0
    while True:
        candidate = os.path.join(output_path, "{:}{:03d}".format(prefix, counter))
        if not os.path.exists(candidate):
            return candidate
        counter += 1