# Source code for tf_unet.unet

# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with tf_unet.  If not, see <http://www.gnu.org/licenses/>.

"""
Created on Jul 28, 2016

author: jakeret
"""
from __future__ import print_function, division, absolute_import, unicode_literals

import logging
import os
import shutil
from collections import OrderedDict

import numpy as np
import tensorflow as tf

from tf_unet import util
from tf_unet.layers import (weight_variable, weight_variable_devonc, bias_variable,
                            conv2d, deconv2d, max_pool, crop_and_concat, pixel_wise_softmax,
                            cross_entropy)

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16, filter_size=3, pool_size=2,
                    summaries=True):
    """
    Creates a new convolutional unet for the given parametrization.

    :param x: input tensor, shape [?,nx,ny,channels]
    :param keep_prob: dropout probability tensor
    :param channels: number of channels in the input image
    :param n_class: number of output labels
    :param layers: number of layers in the net (must be >= 1)
    :param features_root: number of features in the first layer
    :param filter_size: size of the convolution filter
    :param pool_size: size of the max pooling operation
    :param summaries: Flag if summaries should be created

    :returns: tuple ``(output_map, variables, offset)`` where ``output_map`` is the
        ReLU-activated output of the final 1x1 convolution, ``variables`` is the list of
        all ``w1``/``w2`` weights followed by all ``b1``/``b2`` biases of the down and up
        layers (the output-map weight and bias are not part of it), and ``offset`` is
        ``int(in_size - size)`` for the reference input size of 1000 tracked below
    :raises ValueError: if ``layers`` is smaller than 1
    """
    # Without at least one down layer there is no node to build the output map on.
    if layers < 1:
        raise ValueError("layers must be >= 1, got {}".format(layers))

    logging.info(
        "Layers {layers}, features {features}, filter size {filter_size}x{filter_size}, pool size: {pool_size}x{pool_size}".format(
            layers=layers,
            features=features_root,
            filter_size=filter_size,
            pool_size=pool_size))

    # Placeholder for the input image
    with tf.name_scope("preprocessing"):
        nx = tf.shape(x)[1]
        ny = tf.shape(x)[2]
        x_image = tf.reshape(x, tf.stack([-1, nx, ny, channels]))
        in_node = x_image

    weights = []
    biases = []
    convs = []
    pools = OrderedDict()
    deconv = OrderedDict()
    dw_h_convs = OrderedDict()
    up_h_convs = OrderedDict()

    # Reference size used only to compute the returned offset; it is not the real input size.
    in_size = 1000
    size = in_size

    # down layers
    for layer in range(0, layers):
        with tf.name_scope("down_conv_{}".format(str(layer))):
            features = 2 ** layer * features_root
            stddev = np.sqrt(2 / (filter_size ** 2 * features))
            if layer == 0:
                w1 = weight_variable([filter_size, filter_size, channels, features], stddev, name="w1")
            else:
                w1 = weight_variable([filter_size, filter_size, features // 2, features], stddev, name="w1")

            w2 = weight_variable([filter_size, filter_size, features, features], stddev, name="w2")
            b1 = bias_variable([features], name="b1")
            b2 = bias_variable([features], name="b2")

            conv1 = conv2d(in_node, w1, b1, keep_prob)
            tmp_h_conv = tf.nn.relu(conv1)
            conv2 = conv2d(tmp_h_conv, w2, b2, keep_prob)
            dw_h_convs[layer] = tf.nn.relu(conv2)

            weights.append((w1, w2))
            biases.append((b1, b2))
            convs.append((conv1, conv2))

            size -= 2 * 2 * (filter_size // 2)  # valid conv
            if layer < layers - 1:
                pools[layer] = max_pool(dw_h_convs[layer], pool_size)
                in_node = pools[layer]
                size /= pool_size

    in_node = dw_h_convs[layers - 1]

    # up layers
    for layer in range(layers - 2, -1, -1):
        with tf.name_scope("up_conv_{}".format(str(layer))):
            features = 2 ** (layer + 1) * features_root
            stddev = np.sqrt(2 / (filter_size ** 2 * features))

            wd = weight_variable_devonc([pool_size, pool_size, features // 2, features], stddev, name="wd")
            bd = bias_variable([features // 2], name="bd")
            h_deconv = tf.nn.relu(deconv2d(in_node, wd, pool_size) + bd)
            h_deconv_concat = crop_and_concat(dw_h_convs[layer], h_deconv)
            deconv[layer] = h_deconv_concat

            w1 = weight_variable([filter_size, filter_size, features, features // 2], stddev, name="w1")
            w2 = weight_variable([filter_size, filter_size, features // 2, features // 2], stddev, name="w2")
            b1 = bias_variable([features // 2], name="b1")
            b2 = bias_variable([features // 2], name="b2")

            conv1 = conv2d(h_deconv_concat, w1, b1, keep_prob)
            h_conv = tf.nn.relu(conv1)
            conv2 = conv2d(h_conv, w2, b2, keep_prob)
            in_node = tf.nn.relu(conv2)
            up_h_convs[layer] = in_node

            weights.append((w1, w2))
            biases.append((b1, b2))
            convs.append((conv1, conv2))

            size *= pool_size
            size -= 2 * 2 * (filter_size // 2)  # valid conv

    # Output Map
    with tf.name_scope("output_map"):
        # stddev is whatever the last constructed layer above computed.
        weight = weight_variable([1, 1, features_root, n_class], stddev)
        bias = bias_variable([n_class], name="bias")
        # Keep probability of 1.0: no dropout on the output convolution.
        conv = conv2d(in_node, weight, bias, tf.constant(1.0))
        output_map = tf.nn.relu(conv)
        up_h_convs["out"] = output_map

    if summaries:
        with tf.name_scope("summaries"):
            for i, (c1, c2) in enumerate(convs):
                tf.summary.image('summary_conv_%02d_01' % i, get_image_summary(c1))
                tf.summary.image('summary_conv_%02d_02' % i, get_image_summary(c2))

            for k, pool in pools.items():
                tf.summary.image('summary_pool_%02d' % k, get_image_summary(pool))

            for k, concat in deconv.items():
                tf.summary.image('summary_deconv_concat_%02d' % k, get_image_summary(concat))

            for k, activation in dw_h_convs.items():
                tf.summary.histogram("dw_convolution_%02d" % k + '/activations', activation)

            for k, activation in up_h_convs.items():
                tf.summary.histogram("up_convolution_%s" % k + '/activations', activation)

    # All weights first (w1, w2 per layer), then all biases (b1, b2 per layer).
    variables = [w for pair in weights for w in pair]
    variables.extend(b for pair in biases for b in pair)

    return output_map, variables, int(in_size - size)
class Unet(object):
    """
    A unet implementation

    :param channels: number of channels in the input image
    :param n_class: number of output labels
    :param cost: (optional) name of the cost function. Default is 'cross_entropy'
    :param cost_kwargs: (optional) kwargs passed to the cost function. See Unet._get_cost for more options
    """

    def __init__(self, channels, n_class, cost="cross_entropy", cost_kwargs=None, **kwargs):
        tf.reset_default_graph()

        self.n_class = n_class
        self.summaries = kwargs.get("summaries", True)

        self.x = tf.placeholder("float", shape=[None, None, None, channels], name="x")
        self.y = tf.placeholder("float", shape=[None, None, None, n_class], name="y")
        self.keep_prob = tf.placeholder(tf.float32, name="dropout_probability")  # dropout (keep probability)

        logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs)

        # Work on a private copy so neither a shared default nor the caller's dict is mutated.
        cost_kwargs = dict(cost_kwargs) if cost_kwargs else {}
        self.cost = self._get_cost(logits, cost, cost_kwargs)

        self.gradients_node = tf.gradients(self.cost, self.variables)

        with tf.name_scope("cross_entropy"):
            self.cross_entropy = cross_entropy(tf.reshape(self.y, [-1, n_class]),
                                               tf.reshape(pixel_wise_softmax(logits), [-1, n_class]))

        with tf.name_scope("results"):
            self.predicter = pixel_wise_softmax(logits)
            self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))

    def _get_cost(self, logits, cost_name, cost_kwargs):
        """
        Constructs the cost function, either cross_entropy, weighted cross_entropy or dice_coefficient.
        Optional arguments are:
        class_weights: weights for the different classes in case of multi-class imbalance
        regularizer: power of the L2 regularizers added to the loss function

        :param logits: output tensor of the network
        :param cost_name: 'cross_entropy' or 'dice_coefficient'
        :param cost_kwargs: dict with the optional arguments listed above; it is only read
        :returns: scalar loss tensor
        :raises ValueError: if ``cost_name`` is not a known cost function
        """
        # Validate before any graph op is created so a typo fails fast with a clear message.
        if cost_name not in ("cross_entropy", "dice_coefficient"):
            raise ValueError("Unknown cost function: %s" % cost_name)

        class_weights = cost_kwargs.get("class_weights", None)
        regularizer = cost_kwargs.get("regularizer", None)

        with tf.name_scope("cost"):
            flat_logits = tf.reshape(logits, [-1, self.n_class])
            flat_labels = tf.reshape(self.y, [-1, self.n_class])

            if cost_name == "cross_entropy":
                if class_weights is not None:
                    class_weights = tf.constant(np.array(class_weights, dtype=np.float32))

                    # Per-pixel weight = weight of the pixel's (one-hot) class.
                    weight_map = tf.multiply(flat_labels, class_weights)
                    weight_map = tf.reduce_sum(weight_map, axis=1)

                    loss_map = tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                                          labels=flat_labels)
                    weighted_loss = tf.multiply(loss_map, weight_map)

                    loss = tf.reduce_mean(weighted_loss)
                else:
                    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=flat_logits,
                                                                                     labels=flat_labels))
            else:  # "dice_coefficient"
                eps = 1e-5  # keeps the denominator non-zero
                prediction = pixel_wise_softmax(logits)
                intersection = tf.reduce_sum(prediction * self.y)
                union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(self.y)
                loss = -(2 * intersection / (union))

            if regularizer is not None:
                regularizers = sum([tf.nn.l2_loss(variable) for variable in self.variables])
                loss += (regularizer * regularizers)

            return loss

    def predict(self, model_path, x_test):
        """
        Uses the model to create a prediction for the given data

        :param model_path: path to the model checkpoint to restore
        :param x_test: Data to predict on. Shape [n, nx, ny, channels]
        :returns prediction: The unet prediction Shape [n, px, py, labels] (px=nx-self.offset/2)
        """
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            # Initialize variables
            sess.run(init)

            # Restore model weights from previously saved model
            self.restore(sess, model_path)

            # The graph has a y placeholder; feed an uninitialised array of matching shape.
            y_dummy = np.empty((x_test.shape[0], x_test.shape[1], x_test.shape[2], self.n_class))
            prediction = sess.run(self.predicter,
                                  feed_dict={self.x: x_test, self.y: y_dummy, self.keep_prob: 1.})

        return prediction

    def save(self, sess, model_path):
        """
        Saves the current session to a checkpoint

        :param sess: current session
        :param model_path: path to file system location
        :returns: the path returned by the saver
        """
        saver = tf.train.Saver()
        save_path = saver.save(sess, model_path)
        return save_path

    def restore(self, sess, model_path):
        """
        Restores a session from a checkpoint

        :param sess: current session instance
        :param model_path: path to file system checkpoint location
        """
        saver = tf.train.Saver()
        saver.restore(sess, model_path)
        logging.info("Model restored from file: %s" % model_path)
class Trainer(object):
    """
    Trains a unet instance

    :param net: the unet instance to train
    :param batch_size: size of training batch
    :param verification_batch_size: size of verification batch
    :param norm_grads: (optional) true if normalized gradients should be added to the summaries
    :param optimizer: (optional) name of the optimizer to use (momentum or adam)
    :param opt_kwargs: (optional) kwargs passed to the learning rate (momentum opt) and to the optimizer
    """

    def __init__(self, net, batch_size=1, verification_batch_size=4, norm_grads=False, optimizer="momentum",
                 opt_kwargs=None):
        self.net = net
        self.batch_size = batch_size
        self.verification_batch_size = verification_batch_size
        self.norm_grads = norm_grads
        self.optimizer = optimizer
        # Private copy: avoids a shared mutable default and leaves the caller's dict untouched.
        self.opt_kwargs = dict(opt_kwargs) if opt_kwargs else {}

    def _get_optimizer(self, training_iters, global_step):
        """
        Builds the minimize op for ``self.net.cost`` and sets ``self.learning_rate_node``.

        :param training_iters: number of iterations per epoch, used as decay step count (momentum)
        :param global_step: global step variable incremented by the optimizer
        :returns: the minimize op
        :raises ValueError: if ``self.optimizer`` is neither 'momentum' nor 'adam'
        """
        # Pop from a copy so the stored options survive this call.
        opt_kwargs = dict(self.opt_kwargs)

        if self.optimizer == "momentum":
            learning_rate = opt_kwargs.pop("learning_rate", 0.2)
            decay_rate = opt_kwargs.pop("decay_rate", 0.95)
            momentum = opt_kwargs.pop("momentum", 0.2)

            self.learning_rate_node = tf.train.exponential_decay(learning_rate=learning_rate,
                                                                 global_step=global_step,
                                                                 decay_steps=training_iters,
                                                                 decay_rate=decay_rate,
                                                                 staircase=True)

            optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate_node, momentum=momentum,
                                                   **opt_kwargs).minimize(self.net.cost,
                                                                          global_step=global_step)
        elif self.optimizer == "adam":
            learning_rate = opt_kwargs.pop("learning_rate", 0.001)
            self.learning_rate_node = tf.Variable(learning_rate, name="learning_rate")

            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_node,
                                               **opt_kwargs).minimize(self.net.cost,
                                                                      global_step=global_step)
        else:
            raise ValueError("Unknown optimizer: %s" % self.optimizer)

        return optimizer

    def _initialize(self, training_iters, output_path, restore, prediction_path):
        """
        Creates summaries and the optimizer op and prepares the output directories.

        Unless ``restore`` is set, the prediction and output directories are removed first.
        Note that ``self.optimizer`` is replaced by the minimize op here.

        :returns: the global variables initializer op
        """
        global_step = tf.Variable(0, name="global_step")

        self.norm_gradients_node = tf.Variable(tf.constant(0.0, shape=[len(self.net.gradients_node)]),
                                               name="norm_gradients")

        if self.net.summaries and self.norm_grads:
            tf.summary.histogram('norm_grads', self.norm_gradients_node)

        tf.summary.scalar('loss', self.net.cost)
        tf.summary.scalar('cross_entropy', self.net.cross_entropy)
        tf.summary.scalar('accuracy', self.net.accuracy)

        self.optimizer = self._get_optimizer(training_iters, global_step)
        tf.summary.scalar('learning_rate', self.learning_rate_node)

        self.summary_op = tf.summary.merge_all()
        init = tf.global_variables_initializer()

        self.prediction_path = prediction_path
        abs_prediction_path = os.path.abspath(self.prediction_path)
        output_path = os.path.abspath(output_path)

        if not restore:
            logging.info("Removing '{:}'".format(abs_prediction_path))
            shutil.rmtree(abs_prediction_path, ignore_errors=True)
            logging.info("Removing '{:}'".format(output_path))
            shutil.rmtree(output_path, ignore_errors=True)

        if not os.path.exists(abs_prediction_path):
            logging.info("Allocating '{:}'".format(abs_prediction_path))
            os.makedirs(abs_prediction_path)

        if not os.path.exists(output_path):
            logging.info("Allocating '{:}'".format(output_path))
            os.makedirs(output_path)

        return init

    def train(self, data_provider, output_path, training_iters=10, epochs=100, dropout=0.75, display_step=1,
              restore=False, write_graph=False, prediction_path='prediction'):
        """
        Launches the training process

        :param data_provider: callable returning training and verification data
        :param output_path: path where to store checkpoints
        :param training_iters: number of training mini batch iteration
        :param epochs: number of epochs
        :param dropout: dropout probability
        :param display_step: number of steps till outputting stats
        :param restore: Flag if previous model should be restored
        :param write_graph: Flag if the computation graph should be written as protobuf file to the output path
        :param prediction_path: path where to save predictions on each epoch
        :returns: path of the last saved checkpoint
        """
        save_path = os.path.join(output_path, "model.ckpt")
        if epochs == 0:
            return save_path

        init = self._initialize(training_iters, output_path, restore, prediction_path)

        with tf.Session() as sess:
            if write_graph:
                tf.train.write_graph(sess.graph_def, output_path, "graph.pb", False)

            sess.run(init)

            if restore:
                ckpt = tf.train.get_checkpoint_state(output_path)
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(sess, ckpt.model_checkpoint_path)

            test_x, test_y = data_provider(self.verification_batch_size)
            # The prediction shape is used to crop the labels to the network output size.
            pred_shape = self.store_prediction(sess, test_x, test_y, "_init")

            summary_writer = tf.summary.FileWriter(output_path, graph=sess.graph)
            logging.info("Start optimization")

            avg_gradients = None
            for epoch in range(epochs):
                total_loss = 0
                for step in range((epoch * training_iters), ((epoch + 1) * training_iters)):
                    batch_x, batch_y = data_provider(self.batch_size)

                    # Run optimization op (backprop)
                    _, loss, lr, gradients = sess.run(
                        (self.optimizer, self.net.cost, self.learning_rate_node, self.net.gradients_node),
                        feed_dict={self.net.x: batch_x,
                                   self.net.y: util.crop_to_shape(batch_y, pred_shape),
                                   self.net.keep_prob: dropout})

                    if self.net.summaries and self.norm_grads:
                        avg_gradients = _update_avg_gradients(avg_gradients, gradients, step)
                        norm_gradients = [np.linalg.norm(gradient) for gradient in avg_gradients]
                        self.norm_gradients_node.assign(norm_gradients).eval()

                    if step % display_step == 0:
                        self.output_minibatch_stats(sess, summary_writer, step, batch_x,
                                                    util.crop_to_shape(batch_y, pred_shape))

                    total_loss += loss

                self.output_epoch_stats(epoch, total_loss, training_iters, lr)
                self.store_prediction(sess, test_x, test_y, "epoch_%s" % epoch)

                save_path = self.net.save(sess, save_path)
            logging.info("Optimization Finished!")

            return save_path

    def store_prediction(self, sess, batch_x, batch_y, name):
        """
        Runs a prediction on the given batch, logs error and loss, and saves a combined image.

        :param sess: current session
        :param batch_x: input data
        :param batch_y: labels; cropped to the prediction shape before the loss is evaluated
        :param name: file name (without extension) of the image written to ``self.prediction_path``
        :returns: shape of the prediction
        """
        prediction = sess.run(self.net.predicter, feed_dict={self.net.x: batch_x,
                                                             self.net.y: batch_y,
                                                             self.net.keep_prob: 1.})
        pred_shape = prediction.shape

        loss = sess.run(self.net.cost, feed_dict={self.net.x: batch_x,
                                                  self.net.y: util.crop_to_shape(batch_y, pred_shape),
                                                  self.net.keep_prob: 1.})

        logging.info("Verification error= {:.1f}%, loss= {:.4f}".format(
            error_rate(prediction, util.crop_to_shape(batch_y, prediction.shape)),
            loss))

        img = util.combine_img_prediction(batch_x, batch_y, prediction)
        util.save_image(img, "%s/%s.jpg" % (self.prediction_path, name))

        return pred_shape

    def output_epoch_stats(self, epoch, total_loss, training_iters, lr):
        """Logs the average loss and the learning rate of the given epoch."""
        logging.info(
            "Epoch {:}, Average loss: {:.4f}, learning rate: {:.4f}".format(epoch, (total_loss / training_iters), lr))

    def output_minibatch_stats(self, sess, summary_writer, step, batch_x, batch_y):
        """Evaluates the batch without dropout, writes the summaries and logs loss, accuracy and error."""
        # Calculate batch loss and accuracy
        summary_str, loss, acc, predictions = sess.run([self.summary_op,
                                                        self.net.cost,
                                                        self.net.accuracy,
                                                        self.net.predicter],
                                                       feed_dict={self.net.x: batch_x,
                                                                  self.net.y: batch_y,
                                                                  self.net.keep_prob: 1.})
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()
        logging.info(
            "Iter {:}, Minibatch Loss= {:.4f}, Training Accuracy= {:.4f}, Minibatch error= {:.1f}%".format(
                step, loss, acc, error_rate(predictions, batch_y)))
def _update_avg_gradients(avg_gradients, gradients, step):
    """
    Updates the running average of the gradients with the values of the current step.

    :param avg_gradients: list of running averages, or None on the first call
    :param gradients: list of gradient arrays of the current step
    :param step: zero-based step index; the new values are weighted with 1 / (step + 1)
    :returns: the list of updated averages (the passed list is updated in place when given)
    """
    if avg_gradients is None:
        avg_gradients = [np.zeros_like(gradient) for gradient in gradients]

    weight = 1.0 / (step + 1)
    for i, gradient in enumerate(gradients):
        avg_gradients[i] = (avg_gradients[i] * (1.0 - weight)) + (gradient * weight)

    return avg_gradients
def error_rate(predictions, labels):
    """
    Return the error rate based on dense predictions and 1-hot labels.

    :param predictions: 4d array; the class is taken as the argmax over axis 3
    :param labels: 4d array of 1-hot labels; the class is taken as the argmax over axis 3
    :returns: percentage (0-100) of positions whose predicted class differs from the label
    """
    correct = np.argmax(predictions, 3) == np.argmax(labels, 3)
    total = predictions.shape[0] * predictions.shape[1] * predictions.shape[2]

    return 100.0 - (100.0 * np.sum(correct) / total)
def get_image_summary(img, idx=0):
    """
    Make an image summary for 4d tensor image with index idx

    Takes the first element along axis 0 and channel ``idx`` along axis 3, shifts it so
    its minimum is 0, divides by the resulting maximum and multiplies by 255.

    :param img: 4d tensor
    :param idx: index along the last axis of the channel to extract
    :returns: tensor reshaped to [-1, img.shape[1], img.shape[2], 1]
    """
    V = tf.slice(img, (0, 0, 0, idx), (1, -1, -1, 1))
    V -= tf.reduce_min(V)
    # NOTE(review): a constant slice has a maximum of 0 here, so this divides by zero.
    V /= tf.reduce_max(V)
    V *= 255

    img_w = tf.shape(img)[1]
    img_h = tf.shape(img)[2]
    V = tf.reshape(V, tf.stack((img_w, img_h, 1)))
    V = tf.transpose(V, (2, 0, 1))
    V = tf.reshape(V, tf.stack((-1, img_w, img_h, 1)))
    return V