# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
'''
author: jakeret
'''
from __future__ import print_function, division, absolute_import, unicode_literals
import glob
import numpy as np
from PIL import Image
class BaseDataProvider(object):
    """
    Abstract base class for DataProvider implementation. Subclasses have to
    overwrite the `_next_data` method that loads the next data and label array.
    This implementation automatically clips the absolute values of the data
    with the given min/max and rescales them to the range [0, 1]. To change
    this behavior the `_process_data` method can be overwritten. To enable some
    post processing such as data augmentation the `_post_process` method can
    be overwritten.

    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping
    """

    channels = 1
    n_class = 2

    def __init__(self, a_min=None, a_max=None):
        self.a_min = a_min if a_min is not None else -np.inf
        # Test a_max itself: checking a_min here (as before) discarded an
        # explicit a_max whenever a_min was left as None.
        self.a_max = a_max if a_max is not None else np.inf

    def _next_data(self):
        """
        Return the next raw `(data, label)` pair. Must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement `_next_data`")

    def _load_data_and_label(self):
        """
        Fetch, normalize and post-process one sample.

        :returns: tuple `(data, labels)` reshaped to `(1, ny, nx, channels)`
            and `(1, ny, nx, n_class)`
        """
        data, label = self._next_data()

        train_data = self._process_data(data)
        labels = self._process_labels(label)

        train_data, labels = self._post_process(train_data, labels)

        nx = train_data.shape[1]
        ny = train_data.shape[0]

        return (train_data.reshape(1, ny, nx, self.channels),
                labels.reshape(1, ny, nx, self.n_class))

    def _process_labels(self, label):
        """
        Convert a label into a per-class array.

        In the binary case (`n_class == 2`) a 2D mask of shape (ny, nx) is
        expanded to shape (ny, nx, 2): channel 1 holds the mask and channel 0
        its logical inverse. Any other label (e.g. one that already carries a
        class axis) is returned unchanged.

        :param label: label array
        """
        if self.n_class == 2 and label.ndim == 2:
            ny, nx = label.shape
            labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)
            # Child classes should supply a boolean mask, but cast defensively
            # so that `~` is a logical (not bitwise) inversion. The builtin
            # `bool` is used because the `np.bool` alias was removed in
            # NumPy 1.24.
            label = label.astype(bool, copy=False)
            labels[..., 1] = label
            labels[..., 0] = ~label
            return labels

        return label

    def _process_data(self, data):
        """
        Clip the absolute values of `data` to [a_min, a_max] and rescale to [0, 1].

        A constant input (all values equal after clipping) becomes all zeros
        instead of triggering a division by zero.

        :param data: the data array
        """
        data = np.clip(np.fabs(data), self.a_min, self.a_max)
        data -= np.amin(data)

        data_max = np.amax(data)
        if data_max != 0:
            data /= data_max

        return data

    def _post_process(self, data, labels):
        """
        Post processing hook that can be used for data augmentation

        :param data: the data array
        :param labels: the label array
        """
        return data, labels

    def __call__(self, n):
        """
        Load `n` samples and stack them into a batch.

        :param n: number of samples in the batch (0 yields empty arrays)
        :returns: tuple `(X, Y)` with shapes `(n, ny, nx, channels)` and
            `(n, ny, nx, n_class)`
        """
        # The first sample is loaded up front to learn the image size.
        train_data, labels = self._load_data_and_label()
        ny = train_data.shape[1]
        nx = train_data.shape[2]

        X = np.zeros((n, ny, nx, self.channels))
        Y = np.zeros((n, ny, nx, self.n_class))
        for i in range(n):
            if i > 0:
                train_data, labels = self._load_data_and_label()
            X[i] = train_data
            Y[i] = labels

        return X, Y
class SimpleDataProvider(BaseDataProvider):
    """
    Data provider that draws random samples from in-memory numpy arrays.

    `data` is expected to be laid out as `[n, X, Y, channels]` and `label` as
    `[n, X, Y, classes]`, where `n` is the number of images and `X`, `Y` the
    image size. The channel and class counts are read from the last axis of
    the respective array.

    :param data: data numpy array. Shape=[n, X, Y, channels]
    :param label: label numpy array. Shape=[n, X, Y, classes]
    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping
    """

    def __init__(self, data, label, a_min=None, a_max=None):
        super(SimpleDataProvider, self).__init__(a_min, a_max)
        self.data, self.label = data, label

        # Sample count comes from the leading axis; depths from the trailing one.
        self.file_count = data.shape[0]
        self.n_class = label.shape[-1]
        self.channels = data.shape[-1]

    def _next_data(self):
        # Sampling is with replacement: every call draws an independent index.
        pick = np.random.choice(self.file_count)
        sample = self.data[pick]
        target = self.label[pick]
        return sample, target
class ImageDataProvider(BaseDataProvider):
    """
    Generic data provider for images, supports gray scale and colored images.
    Assumes that the data images and label images are stored in the same folder
    and that the labels have a different file suffix
    e.g. 'train/fish_1.tif' and 'train/fish_1_mask.tif'
    Number of pixels in x and y of the images and masks should be even.

    Usage:
    data_provider = ImageDataProvider("..fishes/train/*.tif")

    :param search_path: a glob search pattern to find all data and label images
    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping
    :param data_suffix: suffix pattern for the data images. Default '.tif'
    :param mask_suffix: suffix pattern for the label images. Default '_mask.tif'
    :param shuffle_data: if the order of the loaded file path should be randomized. Default 'True'
    """

    def __init__(self, search_path, a_min=None, a_max=None, data_suffix=".tif", mask_suffix='_mask.tif', shuffle_data=True):
        super(ImageDataProvider, self).__init__(a_min, a_max)
        self.data_suffix = data_suffix
        self.mask_suffix = mask_suffix
        self.file_idx = -1
        self.shuffle_data = shuffle_data

        self.data_files = self._find_data_files(search_path)

        if self.shuffle_data:
            np.random.shuffle(self.data_files)

        assert len(self.data_files) > 0, "No training files"
        print("Number of files used: %s" % len(self.data_files))

        # Inspect the first image/mask pair to infer channel and class counts:
        # a 2D image means one channel, a 2D mask means a binary problem.
        image_path = self.data_files[0]
        label_path = self._label_path(image_path)
        img = self._load_file(image_path)
        mask = self._load_file(label_path)
        self.channels = 1 if len(img.shape) == 2 else img.shape[-1]
        self.n_class = 2 if len(mask.shape) == 2 else mask.shape[-1]

        print("Number of channels: %s" % self.channels)
        print("Number of classes: %s" % self.n_class)

    def _find_data_files(self, search_path):
        """Return the glob matches that contain the data suffix but not the mask suffix."""
        all_files = glob.glob(search_path)
        return [name for name in all_files
                if self.data_suffix in name and self.mask_suffix not in name]

    def _label_path(self, image_path):
        """
        Derive the mask path belonging to `image_path`.

        Only the last occurrence of the data suffix is replaced, so a directory
        name that happens to contain the suffix is left untouched. A path
        without the suffix is returned unchanged.
        """
        head, sep, tail = image_path.rpartition(self.data_suffix)
        if not sep:
            return image_path
        return head + self.mask_suffix + tail

    def _load_file(self, path, dtype=np.float32):
        """Read an image file into a numpy array of the given dtype."""
        # The context manager closes the file handle once the pixels have
        # been copied into the array.
        with Image.open(path) as image:
            return np.array(image, dtype)

    def _cylce_file(self):
        """Advance to the next file, reshuffling (if enabled) after each full pass."""
        self.file_idx += 1
        if self.file_idx >= len(self.data_files):
            self.file_idx = 0
            if self.shuffle_data:
                np.random.shuffle(self.data_files)

    def _next_data(self):
        """Load the next image (float32) and its mask (bool)."""
        self._cylce_file()
        image_name = self.data_files[self.file_idx]
        label_name = self._label_path(image_name)

        img = self._load_file(image_name, np.float32)
        # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
        label = self._load_file(label_name, bool)

        return img, label