diff options
-rw-r--r-- | clustering.py | 11 | ||||
-rw-r--r-- | demo.py | 51 | ||||
-rw-r--r-- | otsu2018.py | 239 |
3 files changed, 204 insertions, 97 deletions
diff --git a/clustering.py b/clustering.py index 527830b..6df6643 100644 --- a/clustering.py +++ b/clustering.py @@ -4,16 +4,16 @@ import tqdm from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ -from otsu2018 import load_Otsu2018_spectra, Tree +from otsu2018 import load_Otsu2018_spectra, Otsu2018Tree if __name__ == '__main__': print('Loading spectral data...') - sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=7) shape = SpectralShape(380, 730, 10) print('Initializing the tree...') - tree = Tree(sds, shape) + tree = Otsu2018Tree(sds, shape) print('Clustering...') before = tree.total_reconstruction_error() @@ -25,7 +25,9 @@ if __name__ == '__main__': print('Saving the dataset...') os.makedirs('datasets', exist_ok=True) - tree.write_python_dataset('datasets/otsu2018.py') + data = tree.to_dataset() + data.to_file('datasets/otsu2018.npz') + data.to_Python_file('datasets/otsu2018.py') print('Plotting...') tree.visualise() @@ -50,3 +52,4 @@ if __name__ == '__main__': break plt.show() + diff --git a/demo.py b/demo.py deleted file mode 100644 index 3c4b933..0000000 --- a/demo.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -import matplotlib.pyplot as plt - -from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ - -from otsu2018 import load_Otsu2018_spectra, Tree - - -if __name__ == '__main__': - print('Loading spectral data...') - sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) - shape = SpectralShape(380, 730, 10) - - print('Initializing the tree...') - tree = Tree(sds, shape) - - print('Tree...') - before = tree.total_reconstruction_error() - tree.optimise() - after = tree.total_reconstruction_error() - - print('Error before: %g' % before) - print('Error after: %g' % after) - - print('Saving the dataset...') - os.makedirs('datasets', exist_ok=True) - tree.write_python_dataset('datasets/otsu2018.py') - - print('Plotting...') - tree.visualise() - - plt.figure() - - examples = COLOURCHECKER_SDS['ColorChecker N Ohta'].items() - for i, (name, sd) in enumerate(examples): - plt.subplot(2, 3, 1 + i) - plt.title(name) - - plt.plot(sd.wavelengths, sd.values, label='Original') - - XYZ = sd_to_XYZ(sd) / 100 - recovered_sd = tree.reconstruct(XYZ) - plt.plot(recovered_sd.wavelengths, recovered_sd.values, - label='Recovered') - - plt.legend() - - if i + 1 == 6: - break - - plt.show() diff --git a/otsu2018.py b/otsu2018.py index 6d15580..884cbcc 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -2,9 +2,10 @@ import numpy as np import matplotlib.pyplot as plt import time -from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS, - sd_ones, sd_to_XYZ, XYZ_to_xy) +from colour import (SpectralShape, SpectralDistribution, + STANDARD_OBSERVER_CMFS, sd_ones, sd_to_XYZ, XYZ_to_xy) from colour.plotting import plot_chromaticity_diagram_CIE1931 +from colour.utilities import as_float_array def load_Otsu2018_spectra(path, every_nth=1): @@ -129,7 +130,7 @@ class Colours: return lesser, greater -class treeError(Exception): +class Otsu2018Error(Exception): """ Exception used for various errors originating from code in this file. """ @@ -172,7 +173,7 @@ class Node: @property def leaf(self): """ - Is this node a leaf? Tree leaves don't have any children and store + Is this node a leaf? Otsu2018Tree leaves don't have any children and store instances of ``Colours``. """ @@ -343,7 +344,7 @@ class Node: if (len(partition[0]) < self.tree.min_cluster_size or len(partition[1]) < self.tree.min_cluster_size): - raise treeError( + raise Otsu2018Error( 'partition created parts smaller than min_cluster_size') lesser = Node(self.tree, partition[0]) @@ -382,7 +383,7 @@ class Node: try: new_error, partition = self.partition_error(axis) - except treeError: + except Otsu2018Error: continue if new_error >= error: @@ -395,7 +396,7 @@ class Node: bar.close() if self.best_partition is None: - raise treeError('no partitions are possible') + raise Otsu2018Error('no partitions are possible') return self.best_partition @@ -448,7 +449,7 @@ class Node: plt.plot(self.wl, recon.values, 'C%d--' % i) -class Tree(Node): +class Otsu2018Tree(Node): """ This is an extension of ``Node``. It's meant to represent the root of the tree and contains information shared with all the nodes, such as cmfs @@ -576,7 +577,7 @@ class Tree(Node): try: error, axis, partition = leaf.find_best_partition() - except treeError as e: + except Otsu2018Error as e: _print('Failed: %s' % e) continue @@ -606,7 +607,111 @@ class Tree(Node): _print('Finished.') - def write_python_dataset(self, path): + def _create_selector_array(self): + """ + Create an array that describes how to select the appropriate cluster + for given *CIE xy* coordinates. See ``Otsu2018Dataset.select`` for + information about what the array looks like and how to use it. + """ + + rows = [] + leaf_number = 0 + symbol_table = {} + + def add_rows(node): + nonlocal leaf_number + + if node.leaf: + symbol_table[node] = leaf_number + leaf_number += 1 + return + + symbol_table[node] = -len(rows) + rows.append([node.partition_axis.direction, + node.partition_axis.origin, + node.children[0], + node.children[1]]) + + for child in node.children: + add_rows(child) + + add_rows(self) + + # Special case for trees with just the root + if len(rows) == 0: + return as_float_array([0., 0., 0., 0.]) + + for i, (_, _, symbol_1, symbol_2) in enumerate(rows): + rows[i][2] = symbol_table[symbol_1] + rows[i][3] = symbol_table[symbol_2] + + return as_float_array(rows) + + def to_dataset(self): + """ + Create an ``Otsu2018Dataset`` based on information stored in this tree. + The object can then be saved to disk or used in reflectance recovery. + + Returns + ======= + Otsu2018Dataset + The dataset object. + """ + + basis_functions = [leaf.basis_functions for leaf in self.leaves] + means = [leaf.mean for leaf in self.leaves] + selector_array = self._create_selector_array() + + return Otsu2018Dataset(self.shape, + basis_functions, + means, + selector_array) + + +class Otsu2018Dataset: + """ + Stores all the information needed for the *Otsu et al. (2018)* spectral + upsampling method. Datasets can be either generated and turned into + this form using ``Otsu2018Tree.to_dataset`` or loaded from disk. + + Attributes + ========== + shape: SpectralShape + Shape of the spectral data. + basis_functions : ndarray(n, 3, m) + Three basis functions for every cluster. + means : ndarray(n, m) + Mean for every cluster. + selector_array : ndarray(k, 4) + Array describing how to select the appropriate cluster. See + ``Otsu2018Dataset.select`` for details. + """ + + def __init__(self, + shape=None, + basis_functions=None, + means=None, + selector_array=None): + self.shape = shape + self.basis_functions = basis_functions + self.means = means + self.selector_array = selector_array + + def to_file(self, path): + """ + Saves the dataset to an .npz file. + """ + + shape_array = as_float_array([self.shape.start, self.shape.end, + self.shape.interval]) + + np.savez(path, + shape=shape_array, + basis_functions=self.basis_functions, + means=self.means, + selector_array=self.selector_array) + + def to_Python_file(self, path): """ Write the tree into a Python dataset compatible with Colour's ``colour.recovery.otsu2018`` code. @@ -626,46 +731,96 @@ class Tree(Node): fd.write('OTSU_2018_SPECTRAL_SHAPE = SpectralShape%s\n\n\n' % self.shape) - # Basis functions - - fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n') - for i, leaf in enumerate(self.leaves): - for line in (repr(leaf.basis_functions) + ',').splitlines(): + def write_array(name, array): + fd.write('%s = [\n' % name) + for line in (repr(array) + ',').splitlines(): fd.write(' %s\n' % line) - fd.write(']\n\n\n') + fd.write(']\n\n\n') - # Means + write_array('OTSU_2018_BASIS_FUNCTIONS', self.basis_functions) + write_array('OTSU_2018_MEANS', self.means) + write_array('OTSU_2018_SELECTOR_ARRAY', self.selector_array) - fd.write('OTSU_2018_MEANS = [\n') - for leaf in self.leaves: - for line in (repr(leaf.mean) + ',').splitlines(): - fd.write(' %s\n' % line) - fd.write(']\n\n\n') + def from_file(self, path): + """ + Loads a dataset from an .npz file. + + Parameters + ========== + path : unicode + Path to file. + + Raises + ====== + ValueError, KeyError + Raised when loading the file succeeded but it did not contain the + expected data. + """ + + npz = np.load(path, allow_pickle=False) + if not isinstance(npz, np.lib.npyio.NpzFile): + raise ValueError('the loaded file is not an .npz file') + + start, end, interval = npz['shape'] + self.shape = SpectralShape(start, end, interval) + self.basis_functions = npz['basis_functions'] + self.means = npz['means'] + self.selector_array = npz['selector_array'] + + n, three, m = self.basis_functions.shape + if (three != 3 or self.means.shape != (n, m) + or self.selector_array.shape[1] != 4): + raise ValueError('array shapes are not correct, the file could be ' + 'corrupted or in a wrong format') - # Cluster selection function + def select(self, xy): + """ + Returns the cluster index appropriate for the given *CIE xy* + coordinates. - fd.write('def select_cluster_Otsu2018(xy):\n') - fd.write(' x, y = xy\n\n') + Parameters + ========== + ndarray : (2,) + *CIE xy* chromaticity coordinates. - counter = 0 + Returns + ======= + int + Cluster index. + """ - def write_if(node, indent): - nonlocal counter + i = 0 + while True: + row = self.selector_array[i, :] + direction, origin, lesser_index, greater_index = row - if node.leaf: - fd.write(' ' * indent) - fd.write('return %d # %s\n' % (counter, node)) - counter += 1 - return + if xy[int(direction)] <= origin: + index = int(lesser_index) + else: + index = int(greater_index) + + if index < 0: + i = -index + else: + return index - fd.write(' ' * indent) - fd.write('if %s <= %s:\n' - % ('xy'[node.partition_axis.direction], - repr(node.partition_axis.origin))) - write_if(node.children[0], indent + 1) + def cluster(self, xy): + """ + Returns the basis functions and dataset mean for the given *CIE xy* + coordinates. + + Parameters + ========== + ndarray : (2,) + *CIE xy* chromaticity coordinates. - fd.write(' ' * indent) - fd.write('else:\n') - write_if(node.children[1], indent + 1) + Returns + ======= + basis_functions : ndarray (3, n) + Three basis functions. + mean : ndarray (n,) + Dataset mean. + """ - write_if(self, 1) + index = self.select(xy) + return self.basis_functions[index, :, :], self.means[index, :] |