summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--clustering.py11
-rw-r--r--demo.py51
-rw-r--r--otsu2018.py239
3 files changed, 204 insertions, 97 deletions
diff --git a/clustering.py b/clustering.py
index 527830b..6df6643 100644
--- a/clustering.py
+++ b/clustering.py
@@ -4,16 +4,16 @@ import tqdm
from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ
-from otsu2018 import load_Otsu2018_spectra, Tree
+from otsu2018 import load_Otsu2018_spectra, Otsu2018Tree
if __name__ == '__main__':
print('Loading spectral data...')
- sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50)
+ sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=7)
shape = SpectralShape(380, 730, 10)
print('Initializing the tree...')
- tree = Tree(sds, shape)
+ tree = Otsu2018Tree(sds, shape)
print('Clustering...')
before = tree.total_reconstruction_error()
@@ -25,7 +25,9 @@ if __name__ == '__main__':
print('Saving the dataset...')
os.makedirs('datasets', exist_ok=True)
- tree.write_python_dataset('datasets/otsu2018.py')
+ data = tree.to_dataset()
+ data.to_file('datasets/otsu2018.npz')
+ data.to_Python_file('datasets/otsu2018.py')
print('Plotting...')
tree.visualise()
@@ -50,3 +52,4 @@ if __name__ == '__main__':
break
plt.show()
+
diff --git a/demo.py b/demo.py
deleted file mode 100644
index 3c4b933..0000000
--- a/demo.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import os
-import matplotlib.pyplot as plt
-
-from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ
-
-from otsu2018 import load_Otsu2018_spectra, Tree
-
-
-if __name__ == '__main__':
- print('Loading spectral data...')
- sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50)
- shape = SpectralShape(380, 730, 10)
-
- print('Initializing the tree...')
- tree = Tree(sds, shape)
-
- print('Tree...')
- before = tree.total_reconstruction_error()
- tree.optimise()
- after = tree.total_reconstruction_error()
-
- print('Error before: %g' % before)
- print('Error after: %g' % after)
-
- print('Saving the dataset...')
- os.makedirs('datasets', exist_ok=True)
- tree.write_python_dataset('datasets/otsu2018.py')
-
- print('Plotting...')
- tree.visualise()
-
- plt.figure()
-
- examples = COLOURCHECKER_SDS['ColorChecker N Ohta'].items()
- for i, (name, sd) in enumerate(examples):
- plt.subplot(2, 3, 1 + i)
- plt.title(name)
-
- plt.plot(sd.wavelengths, sd.values, label='Original')
-
- XYZ = sd_to_XYZ(sd) / 100
- recovered_sd = tree.reconstruct(XYZ)
- plt.plot(recovered_sd.wavelengths, recovered_sd.values,
- label='Recovered')
-
- plt.legend()
-
- if i + 1 == 6:
- break
-
- plt.show()
diff --git a/otsu2018.py b/otsu2018.py
index 6d15580..884cbcc 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -2,9 +2,10 @@ import numpy as np
import matplotlib.pyplot as plt
import time
-from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS,
- sd_ones, sd_to_XYZ, XYZ_to_xy)
+from colour import (SpectralShape, SpectralDistribution,
+ STANDARD_OBSERVER_CMFS, sd_ones, sd_to_XYZ, XYZ_to_xy)
from colour.plotting import plot_chromaticity_diagram_CIE1931
+from colour.utilities import as_float_array
def load_Otsu2018_spectra(path, every_nth=1):
@@ -129,7 +130,7 @@ class Colours:
return lesser, greater
-class treeError(Exception):
+class Otsu2018Error(Exception):
"""
Exception used for various errors originating from code in this file.
"""
@@ -172,7 +173,7 @@ class Node:
@property
def leaf(self):
"""
- Is this node a leaf? Tree leaves don't have any children and store
+ Is this node a leaf? Otsu2018Tree leaves don't have any children and store
instances of ``Colours``.
"""
@@ -343,7 +344,7 @@ class Node:
if (len(partition[0]) < self.tree.min_cluster_size
or len(partition[1]) < self.tree.min_cluster_size):
- raise treeError(
+ raise Otsu2018Error(
'partition created parts smaller than min_cluster_size')
lesser = Node(self.tree, partition[0])
@@ -382,7 +383,7 @@ class Node:
try:
new_error, partition = self.partition_error(axis)
- except treeError:
+ except Otsu2018Error:
continue
if new_error >= error:
@@ -395,7 +396,7 @@ class Node:
bar.close()
if self.best_partition is None:
- raise treeError('no partitions are possible')
+ raise Otsu2018Error('no partitions are possible')
return self.best_partition
@@ -448,7 +449,7 @@ class Node:
plt.plot(self.wl, recon.values, 'C%d--' % i)
-class Tree(Node):
+class Otsu2018Tree(Node):
"""
This is an extension of ``Node``. It's meant to represent the root of the
tree and contains information shared with all the nodes, such as cmfs
@@ -576,7 +577,7 @@ class Tree(Node):
try:
error, axis, partition = leaf.find_best_partition()
- except treeError as e:
+ except Otsu2018Error as e:
_print('Failed: %s' % e)
continue
@@ -606,7 +607,111 @@ class Tree(Node):
_print('Finished.')
- def write_python_dataset(self, path):
+ def _create_selector_array(self):
+ """
+ Create an array that describes how to select the appropriate cluster
+ for given *CIE xy* coordinates. See ``Otsu2018Dataset.select`` for
+ information about what the array looks like and how to use it.
+ """
+
+ rows = []
+ leaf_number = 0
+ symbol_table = {}
+
+ def add_rows(node):
+ nonlocal leaf_number
+
+ if node.leaf:
+ symbol_table[node] = leaf_number
+ leaf_number += 1
+ return
+
+ symbol_table[node] = -len(rows)
+ rows.append([node.partition_axis.direction,
+ node.partition_axis.origin,
+ node.children[0],
+ node.children[1]])
+
+ for child in node.children:
+ add_rows(child)
+
+ add_rows(self)
+
+ # Special case for trees with just the root
+ if len(rows) == 0:
+ return as_float_array([0., 0., 0., 0.])
+
+ for i, (_, _, symbol_1, symbol_2) in enumerate(rows):
+ rows[i][2] = symbol_table[symbol_1]
+ rows[i][3] = symbol_table[symbol_2]
+
+ return as_float_array(rows)
+
+ def to_dataset(self):
+ """
+ Create an ``Otsu2018Dataset`` based on information stored in this tree.
+ The object can then be saved to disk or used in reflectance recovery.
+
+ Returns
+ =======
+ Otsu2018Dataset
+ The dataset object.
+ """
+
+ basis_functions = [leaf.basis_functions for leaf in self.leaves]
+ means = [leaf.mean for leaf in self.leaves]
+ selector_array = self._create_selector_array()
+
+ return Otsu2018Dataset(self.shape,
+ basis_functions,
+ means,
+ selector_array)
+
+
+class Otsu2018Dataset:
+ """
+ Stores all the information needed for the *Otsu et al. (2018)* spectral
+ upsampling method. Datasets can be either generated and turned into
+ this form using ``Otsu2018Tree.to_dataset`` or loaded from disk.
+
+ Attributes
+ ==========
+ shape: SpectralShape
+ Shape of the spectral data.
+ basis_functions : ndarray(n, 3, m)
+ Three basis functions for every cluster.
+ means : ndarray(n, m)
+ Mean for every cluster.
+ selector_array : ndarray(k, 4)
+ Array describing how to select the appropriate cluster. See
+ ``Otsu2018Dataset.select`` for details.
+ """
+
+ def __init__(self,
+ shape=None,
+ basis_functions=None,
+ means=None,
+ selector_array=None):
+ self.shape = shape
+ self.basis_functions = basis_functions
+ self.means = means
+ self.selector_array = selector_array
+
+ def to_file(self, path):
+ """
+ Saves the dataset to an .npz file.
+ """
+
+ shape_array = as_float_array([self.shape.start, self.shape.end,
+ self.shape.interval])
+
+ np.savez(path,
+ shape=shape_array,
+ basis_functions=self.basis_functions,
+ means=self.means,
+ selector_array=self.selector_array)
+
+ def to_Python_file(self, path):
"""
Write the tree into a Python dataset compatible with Colour's
``colour.recovery.otsu2018`` code.
@@ -626,46 +731,96 @@ class Tree(Node):
fd.write('OTSU_2018_SPECTRAL_SHAPE = SpectralShape%s\n\n\n'
% self.shape)
- # Basis functions
-
- fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n')
- for i, leaf in enumerate(self.leaves):
- for line in (repr(leaf.basis_functions) + ',').splitlines():
+ def write_array(name, array):
+ fd.write('%s = [\n' % name)
+ for line in (repr(array) + ',').splitlines():
fd.write(' %s\n' % line)
- fd.write(']\n\n\n')
+ fd.write(']\n\n\n')
- # Means
+ write_array('OTSU_2018_BASIS_FUNCTIONS', self.basis_functions)
+ write_array('OTSU_2018_MEANS', self.means)
+ write_array('OTSU_2018_SELECTOR_ARRAY', self.selector_array)
- fd.write('OTSU_2018_MEANS = [\n')
- for leaf in self.leaves:
- for line in (repr(leaf.mean) + ',').splitlines():
- fd.write(' %s\n' % line)
- fd.write(']\n\n\n')
+ def from_file(self, path):
+ """
+ Loads a dataset from an .npz file.
+
+ Parameters
+ ==========
+ path : unicode
+ Path to file.
+
+ Raises
+ ======
+ ValueError, KeyError
+ Raised when loading the file succeeded but it did not contain the
+ expected data.
+ """
+
+ npz = np.load(path, allow_pickle=False)
+ if not isinstance(npz, np.lib.npyio.NpzFile):
+ raise ValueError('the loaded file is not an .npz file')
+
+ start, end, interval = npz['shape']
+ self.shape = SpectralShape(start, end, interval)
+ self.basis_functions = npz['basis_functions']
+ self.means = npz['means']
+ self.selector_array = npz['selector_array']
+
+ n, three, m = self.basis_functions.shape
+ if (three != 3 or self.means.shape != (n, m)
+ or self.selector_array.shape[1] != 4):
+ raise ValueError('array shapes are not correct, the file could be '
+ 'corrupted or in a wrong format')
- # Cluster selection function
+ def select(self, xy):
+ """
+ Returns the cluster index appropriate for the given *CIE xy*
+ coordinates.
- fd.write('def select_cluster_Otsu2018(xy):\n')
- fd.write(' x, y = xy\n\n')
+ Parameters
+ ==========
+ ndarray : (2,)
+ *CIE xy* chromaticity coordinates.
- counter = 0
+ Returns
+ =======
+ int
+ Cluster index.
+ """
- def write_if(node, indent):
- nonlocal counter
+ i = 0
+ while True:
+ row = self.selector_array[i, :]
+ direction, origin, lesser_index, greater_index = row
- if node.leaf:
- fd.write(' ' * indent)
- fd.write('return %d # %s\n' % (counter, node))
- counter += 1
- return
+ if xy[int(direction)] <= origin:
+ index = int(lesser_index)
+ else:
+ index = int(greater_index)
+
+ if index < 0:
+ i = -index
+ else:
+ return index
- fd.write(' ' * indent)
- fd.write('if %s <= %s:\n'
- % ('xy'[node.partition_axis.direction],
- repr(node.partition_axis.origin)))
- write_if(node.children[0], indent + 1)
+ def cluster(self, xy):
+ """
+ Returns the basis functions and dataset mean for the given *CIE xy*
+ coordinates.
+
+ Parameters
+ ==========
+ ndarray : (2,)
+ *CIE xy* chromaticity coordinates.
- fd.write(' ' * indent)
- fd.write('else:\n')
- write_if(node.children[1], indent + 1)
+ Returns
+ =======
+ basis_functions : ndarray (3, n)
+ Three basis functions.
+ mean : ndarray (n,)
+ Dataset mean.
+ """
- write_if(self, 1)
+ index = self.select(xy)
+ return self.basis_functions[index, :, :], self.means[index, :]