Source code for tf_unet.util

# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with tf_unet.  If not, see <http://www.gnu.org/licenses/>.


'''
Created on Aug 10, 2016

author: jakeret
'''
from __future__ import print_function, division, absolute_import, unicode_literals

import os

import numpy as np
from PIL import Image

def plot_prediction(x_test, y_test, prediction, save=False):
    """
    Plots input, label and prediction side by side, one row per sample.

    The inputs and labels are first cropped to the prediction's shape via
    `crop_to_shape`. Channel 1 of the labels and of the prediction is shown.
    The prediction is rescaled to [0, 1] for display on a private copy, so
    the caller's `prediction` array is never modified.

    :param x_test: the input data, first axis is the sample index
    :param y_test: the labels, channel 1 of the last axis is plotted
    :param prediction: the predictions, channel 1 of the last axis is plotted
    :param save: (optional) target handed to `fig.savefig`; if falsy the
        figure is shown interactively instead
    """
    # Imported lazily so that the module can be used without matplotlib.
    import matplotlib.pyplot as plt

    test_size = x_test.shape[0]
    fig, ax = plt.subplots(test_size, 3, figsize=(12, 12), sharey=True, sharex=True)

    x_test = crop_to_shape(x_test, prediction.shape)
    y_test = crop_to_shape(y_test, prediction.shape)

    # With a single sample `subplots` returns a 1-d array of axes.
    ax = np.atleast_2d(ax)
    for i in range(test_size):
        cax = ax[i, 0].imshow(x_test[i])
        plt.colorbar(cax, ax=ax[i, 0])

        cax = ax[i, 1].imshow(y_test[i, ..., 1])
        plt.colorbar(cax, ax=ax[i, 1])

        # `astype` returns a float copy: the normalisation below must not
        # write through to the caller's array and must work for int input.
        pred = prediction[i, ..., 1].astype(np.float64)
        pred -= np.amin(pred)
        peak = np.amax(pred)
        if peak > 0:
            # A constant map has peak == 0; skip to avoid dividing by zero.
            pred /= peak
        cax = ax[i, 2].imshow(pred)
        plt.colorbar(cax, ax=ax[i, 2])

        if i == 0:
            ax[i, 0].set_title("x")
            ax[i, 1].set_title("y")
            ax[i, 2].set_title("pred")

    fig.tight_layout()

    if save:
        fig.savefig(save)
    else:
        fig.show()
        plt.show()
def to_rgb(img):
    """
    Converts the given array into a RGB image. If the number of channels is
    less than 3 the array is tiled along the last axis. NaNs are replaced by 0
    and the values are finally rescaled to [0, 255].

    The input is never modified; the result is always a new array. Floating
    point inputs keep their dtype, any other dtype is converted to float64.
    A constant image (no value range to stretch) yields all zeros.

    :param img: the array to convert [nx, ny, channels]

    :returns img: the rgb image [nx, ny, 3]
    """
    img = np.atleast_3d(img)
    # Work on a private floating point copy: the in-place arithmetic below
    # would otherwise write into the caller's array and fail on int dtypes.
    dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else np.float64
    img = np.array(img, dtype=dtype)

    channels = img.shape[2]
    if channels < 3:
        # NOTE(review): a 2-channel input ends up with 6 channels here, not 3
        # -- behaviour kept as is, confirm whether 2 channels are ever passed.
        img = np.tile(img, 3)

    img[np.isnan(img)] = 0
    img -= np.amin(img)
    peak = np.amax(img)
    if peak > 0:
        # peak == 0 means a constant image; dividing would produce NaNs.
        img /= peak
    img *= 255
    return img
def crop_to_shape(data, shape):
    """
    Crops the array to the given image shape by removing the border
    (expects a tensor of shape [batches, nx, ny, channels]).

    Only axes 1 and 2 are cropped. When the size difference along an axis is
    odd, the extra element is removed from the end of that axis. If the data
    already has the target size along an axis, that axis is left untouched.

    :param data: the array to crop
    :param shape: the target shape; only entries 1 and 2 are used

    :returns: the cropped view of `data`
    :raises ValueError: if `data` is smaller than the target along axis 1 or 2
    """
    diff_nx = data.shape[1] - shape[1]
    diff_ny = data.shape[2] - shape[2]
    if diff_nx < 0 or diff_ny < 0:
        raise ValueError(
            "Cannot crop data of shape {} to the smaller target shape {}".format(
                tuple(data.shape), tuple(shape)))

    offset_nx = diff_nx // 2
    offset_ny = diff_ny // 2

    # Explicit end indices instead of negative ones: a slice like `a:-0`
    # is empty, which broke the case where no cropping is needed.
    return data[:,
                offset_nx:offset_nx + shape[1],
                offset_ny:offset_ny + shape[2]]
def combine_img_prediction(data, gt, pred):
    """
    Combines the data, ground truth and the prediction into one rgb image.

    Data and ground truth are cropped to the shape of the prediction; for the
    ground truth and the prediction only channel 1 is used. Each of the three
    is flattened over the batch axis, converted with `to_rgb` and the results
    are placed next to each other along axis 1.

    :param data: the data tensor
    :param gt: the ground truth tensor
    :param pred: the prediction tensor

    :returns img: the concatenated rgb image
    """
    target_shape = pred.shape
    ny = target_shape[2]
    ch = data.shape[3]

    # Stack all batch entries on top of each other before colour conversion.
    data_panel = to_rgb(crop_to_shape(data, target_shape).reshape(-1, ny, ch))
    gt_panel = to_rgb(crop_to_shape(gt[..., 1], target_shape).reshape(-1, ny, 1))
    pred_panel = to_rgb(pred[..., 1].reshape(-1, ny, 1))

    return np.concatenate((data_panel, gt_panel, pred_panel), axis=1)
def save_image(img, path):
    """
    Writes the image to disk as a JPEG file (300 dpi, quality 90).

    The values are rounded and converted to unsigned 8 bit integers before
    the image is written.

    :param img: the rgb image to save
    :param path: the target path
    """
    pixels = img.round().astype(np.uint8)
    image = Image.fromarray(pixels)
    image.save(path, 'JPEG', dpi=[300,300], quality=90)
def create_training_path(output_path, prefix="run_"):
    """
    Finds the first unused, enumerated path under the given output_path.

    Candidates are `prefix` followed by a zero padded, three digit counter
    starting at 0. The first candidate that does not exist yet is returned;
    nothing is created on disk.

    :param output_path: the root path
    :param prefix: (optional) defaults to `run_`

    :return: the generated path as string in form `output_path`/`prefix` + `<number>`
    """
    counter = 0
    while True:
        candidate = os.path.join(output_path, "{:}{:03d}".format(prefix, counter))
        if not os.path.exists(candidate):
            return candidate
        counter += 1