summaryrefslogtreecommitdiff
path: root/otsu2018.py
diff options
context:
space:
mode:
authorPaweł Redman <pawel.redman@gmail.com>2020-07-25 19:45:17 +0200
committerPaweł Redman <pawel.redman@gmail.com>2020-07-25 19:45:17 +0200
commitde339a8611548cfa3f9bd8bcdaa22afd2a6e062c (patch)
tree486e4e056e4f888eb18aae38c159b0fab32ac10a /otsu2018.py
parentf88a852f659ea3223ce69ac147912b1d346cf7be (diff)
A new class for saving and loading datasets.
Diffstat (limited to 'otsu2018.py')
-rw-r--r--otsu2018.py239
1 files changed, 197 insertions, 42 deletions
diff --git a/otsu2018.py b/otsu2018.py
index 6d15580..884cbcc 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -2,9 +2,10 @@ import numpy as np
import matplotlib.pyplot as plt
import time
-from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS,
- sd_ones, sd_to_XYZ, XYZ_to_xy)
+from colour import (SpectralShape, SpectralDistribution,
+ STANDARD_OBSERVER_CMFS, sd_ones, sd_to_XYZ, XYZ_to_xy)
from colour.plotting import plot_chromaticity_diagram_CIE1931
+from colour.utilities import as_float_array
def load_Otsu2018_spectra(path, every_nth=1):
@@ -129,7 +130,7 @@ class Colours:
return lesser, greater
-class treeError(Exception):
+class Otsu2018Error(Exception):
"""
Exception used for various errors originating from code in this file.
"""
@@ -172,7 +173,7 @@ class Node:
@property
def leaf(self):
"""
- Is this node a leaf? Tree leaves don't have any children and store
+ Is this node a leaf? Otsu2018Tree leaves don't have any children and store
instances of ``Colours``.
"""
@@ -343,7 +344,7 @@ class Node:
if (len(partition[0]) < self.tree.min_cluster_size
or len(partition[1]) < self.tree.min_cluster_size):
- raise treeError(
+ raise Otsu2018Error(
'partition created parts smaller than min_cluster_size')
lesser = Node(self.tree, partition[0])
@@ -382,7 +383,7 @@ class Node:
try:
new_error, partition = self.partition_error(axis)
- except treeError:
+ except Otsu2018Error:
continue
if new_error >= error:
@@ -395,7 +396,7 @@ class Node:
bar.close()
if self.best_partition is None:
- raise treeError('no partitions are possible')
+ raise Otsu2018Error('no partitions are possible')
return self.best_partition
@@ -448,7 +449,7 @@ class Node:
plt.plot(self.wl, recon.values, 'C%d--' % i)
-class Tree(Node):
+class Otsu2018Tree(Node):
"""
This is an extension of ``Node``. It's meant to represent the root of the
tree and contains information shared with all the nodes, such as cmfs
@@ -576,7 +577,7 @@ class Tree(Node):
try:
error, axis, partition = leaf.find_best_partition()
- except treeError as e:
+ except Otsu2018Error as e:
_print('Failed: %s' % e)
continue
@@ -606,7 +607,111 @@ class Tree(Node):
_print('Finished.')
- def write_python_dataset(self, path):
+ def _create_selector_array(self):
+ """
+ Create an array that describes how to select the appropriate cluster
+ for given *CIE xy* coordinates. See ``Otsu2018Dataset.select`` for
+ information about what the array looks like and how to use it.
+ """
+
+ rows = []
+ leaf_number = 0
+ symbol_table = {}
+
+ def add_rows(node):
+ nonlocal leaf_number
+
+ if node.leaf:
+ symbol_table[node] = leaf_number
+ leaf_number += 1
+ return
+
+ symbol_table[node] = -len(rows)
+ rows.append([node.partition_axis.direction,
+ node.partition_axis.origin,
+ node.children[0],
+ node.children[1]])
+
+ for child in node.children:
+ add_rows(child)
+
+ add_rows(self)
+
+ # Special case for trees with just the root
+ if len(rows) == 0:
+ return as_float_array([0., 0., 0., 0.])
+
+ for i, (_, _, symbol_1, symbol_2) in enumerate(rows):
+ rows[i][2] = symbol_table[symbol_1]
+ rows[i][3] = symbol_table[symbol_2]
+
+ return as_float_array(rows)
+
+ def to_dataset(self):
+ """
+ Create an ``Otsu2018Dataset`` based on information stored in this tree.
+ The object can then be saved to disk or used in reflectance recovery.
+
+ Returns
+ =======
+ Otsu2018Dataset
+ The dataset object.
+ """
+
+ basis_functions = [leaf.basis_functions for leaf in self.leaves]
+ means = [leaf.mean for leaf in self.leaves]
+ selector_array = self._create_selector_array()
+
+ return Otsu2018Dataset(self.shape,
+ basis_functions,
+ means,
+ selector_array)
+
+
+class Otsu2018Dataset:
+ """
+ Stores all the information needed for the *Otsu et al. (2018)* spectral
+ upsampling method. Datasets can be either generated and turned into
+ this form using ``Otsu2018Tree.to_dataset`` or loaded from disk.
+
+ Attributes
+ ==========
+ shape: SpectralShape
+ Shape of the spectral data.
+ basis_functions : ndarray(n, 3, m)
+ Three basis functions for every cluster.
+ means : ndarray(n, m)
+ Mean for every cluster.
+ selector_array : ndarray(k, 4)
+ Array describing how to select the appropriate cluster. See
+ ``Otsu2018Dataset.select`` for details.
+ """
+
+ def __init__(self,
+ shape=None,
+ basis_functions=None,
+ means=None,
+ selector_array=None):
+ self.shape = shape
+ self.basis_functions = basis_functions
+ self.means = means
+ self.selector_array = selector_array
+
+ def to_file(self, path):
+ """
+ Saves the dataset to an .npz file.
+ """
+
+ shape_array = as_float_array([self.shape.start, self.shape.end,
+ self.shape.interval])
+
+ np.savez(path,
+ shape=shape_array,
+ basis_functions=self.basis_functions,
+ means=self.means,
+ selector_array=self.selector_array)
+
+ def to_Python_file(self, path):
"""
Write the tree into a Python dataset compatible with Colour's
``colour.recovery.otsu2018`` code.
@@ -626,46 +731,96 @@ class Tree(Node):
fd.write('OTSU_2018_SPECTRAL_SHAPE = SpectralShape%s\n\n\n'
% self.shape)
- # Basis functions
-
- fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n')
- for i, leaf in enumerate(self.leaves):
- for line in (repr(leaf.basis_functions) + ',').splitlines():
+ def write_array(name, array):
+ fd.write('%s = [\n' % name)
+ for line in (repr(array) + ',').splitlines():
fd.write(' %s\n' % line)
- fd.write(']\n\n\n')
+ fd.write(']\n\n\n')
- # Means
+ write_array('OTSU_2018_BASIS_FUNCTIONS', self.basis_functions)
+ write_array('OTSU_2018_MEANS', self.means)
+ write_array('OTSU_2018_SELECTOR_ARRAY', self.selector_array)
- fd.write('OTSU_2018_MEANS = [\n')
- for leaf in self.leaves:
- for line in (repr(leaf.mean) + ',').splitlines():
- fd.write(' %s\n' % line)
- fd.write(']\n\n\n')
+ def from_file(self, path):
+ """
+ Loads a dataset from an .npz file.
+
+ Parameters
+ ==========
+ path : unicode
+ Path to file.
+
+ Raises
+ ======
+ ValueError, KeyError
+ Raised when loading the file succeeded but it did not contain the
+ expected data.
+ """
+
+ npz = np.load(path, allow_pickle=False)
+ if not isinstance(npz, np.lib.npyio.NpzFile):
+ raise ValueError('the loaded file is not an .npz file')
+
+ start, end, interval = npz['shape']
+ self.shape = SpectralShape(start, end, interval)
+ self.basis_functions = npz['basis_functions']
+ self.means = npz['means']
+ self.selector_array = npz['selector_array']
+
+ n, three, m = self.basis_functions.shape
+ if (three != 3 or self.means.shape != (n, m)
+ or self.selector_array.shape[1] != 4):
+ raise ValueError('array shapes are not correct, the file could be '
+ 'corrupted or in a wrong format')
- # Cluster selection function
+ def select(self, xy):
+ """
+ Returns the cluster index appropriate for the given *CIE xy*
+ coordinates.
- fd.write('def select_cluster_Otsu2018(xy):\n')
- fd.write(' x, y = xy\n\n')
+ Parameters
+ ==========
+ ndarray : (2,)
+ *CIE xy* chromaticity coordinates.
- counter = 0
+ Returns
+ =======
+ int
+ Cluster index.
+ """
- def write_if(node, indent):
- nonlocal counter
+ i = 0
+ while True:
+ row = self.selector_array[i, :]
+ direction, origin, lesser_index, greater_index = row
- if node.leaf:
- fd.write(' ' * indent)
- fd.write('return %d # %s\n' % (counter, node))
- counter += 1
- return
+ if xy[int(direction)] <= origin:
+ index = int(lesser_index)
+ else:
+ index = int(greater_index)
+
+ if index < 0:
+ i = -index
+ else:
+ return index
- fd.write(' ' * indent)
- fd.write('if %s <= %s:\n'
- % ('xy'[node.partition_axis.direction],
- repr(node.partition_axis.origin)))
- write_if(node.children[0], indent + 1)
+ def cluster(self, xy):
+ """
+ Returns the basis functions and dataset mean for the given *CIE xy*
+ coordinates.
+
+ Parameters
+ ==========
+ ndarray : (2,)
+ *CIE xy* chromaticity coordinates.
- fd.write(' ' * indent)
- fd.write('else:\n')
- write_if(node.children[1], indent + 1)
+ Returns
+ =======
+ basis_functions : ndarray (3, n)
+ Three basis functions.
+ mean : ndarray (n,)
+ Dataset mean.
+ """
- write_if(self, 1)
+ index = self.select(xy)
+ return self.basis_functions[index, :, :], self.means[index, :]