diff options
Diffstat (limited to 'otsu2018.py')
-rw-r--r-- | otsu2018.py | 600 |
1 files changed, 600 insertions, 0 deletions
diff --git a/otsu2018.py b/otsu2018.py new file mode 100644 index 0000000..4f84cd3 --- /dev/null +++ b/otsu2018.py @@ -0,0 +1,600 @@ +import numpy as np +import matplotlib.pyplot as plt +import sklearn.decomposition + +from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS, + ILLUMINANT_SDS, sd_to_XYZ, XYZ_to_xy) +from colour.plotting import plot_chromaticity_diagram_CIE1931 + + +def load_Otsu2018_spectra(path, every_nth=1): + """ + Loads a set of measured reflectances from Otsu et al.'s csv file. + + TODO: This function can't determine the spectral shape. + + Parameters + ---------- + path : str + File path. + every_nth : int + Load only every n-th spectrum. The original files are huge, so this can + be useful for testing. + """ + data = np.genfromtxt(path, delimiter=',', skip_header=1) + + # The first column is the id and is redundant + data = data[:, 1:] + + spectra = [] + for i in range(data.shape[0]): + if i % every_nth != 0: + continue + + values = data[i, :] + spectra.append(values) + + return np.array(spectra) + + +class Colours: + """ + Represents multiple colours: their reflectances, XYZ tristimulus values + and xy coordinates. The cmfs and the illuminant are taken from the parent + Clustering. + + This class also supports partitioning, or creating two smaller instances + of Colours, split along a horizontal or a vertical axis on the xy plane. + """ + + def __init__(self, clustering, reflectances): + """ + Parameters + ========== + clustering : Clustering + The parent clustering. This determines what cmfs and illuminant + are used in colourimetric calculations. + reflectances : ndarray (n,m) + Reflectances of the ``n`` colours to be stored in this class. + The shape must match ``Clustering.shape`` with ``m`` points for + each colour. + """ + + self.reflectances = reflectances + self.XYZ = np.empty((reflectances.shape[0], 3)) + self.xy = np.empty((reflectances.shape[0], 2)) + + for i in range(len(self)): + sd = SpectralDistribution(reflectances[i, :], clustering.wl) + XYZ = sd_to_XYZ(sd, illuminant=clustering.illuminant) / 100 + self.XYZ[i, :] = XYZ + self.xy[i, :] = XYZ_to_xy(XYZ) + + def __len__(self): + """ + Counts the number of colours in this object. + """ + return self.reflectances.shape[0] + + def partition(self, x_or_y, axis): + """ + Parameters + ========== + x_or_y : int + Whether to split according to X or Y coordinates. + axis : float + The coordinate that defines where the split happens. + + Returns + ======= + lesser : Colours + The left or lower part. + greater : Colours + The right or upper part. + """ + mask = self.xy[:, x_or_y] <= axis + + lesser = object.__new__(Colours) + greater = object.__new__(Colours) + + lesser.reflectances = self.reflectances[mask, :] + greater.reflectances = self.reflectances[np.logical_not(mask), :] + + lesser.XYZ = self.XYZ[mask, :] + greater.XYZ = self.XYZ[np.logical_not(mask), :] + + lesser.xy = self.xy[mask, :] + greater.xy = self.xy[np.logical_not(mask), :] + + return lesser, greater + + +class ClusteringError(Exception): + """ + Exception used for various errors originating from code in this file. + """ + pass + + +class Node: + """ + Represents a node in the clustering tree. + """ + + _counter = 1 + + def __init__(self, clustering, colours): + """ + Parameters + ========== + clustering : Clustering + The parent clustering. This determines what cmfs and illuminant + are used in colourimetric calculations. + colours : Colours + The colours that belong in this node. + """ + + self.clustering = clustering + self.colours = colours + self.children = None + + self._cached_reconstruction_error = None + self.PCA_done = False + + # This is just for __str__ and plots + self.number = Node._counter + Node._counter += 1 + + def __str__(self): + return 'Node #%d' % (self.number) + + @property + def leaf(self): + """ + Is this node a leaf? Tree leaves don't have any children and store + instances of ``Colours``. + """ + + return self.children is None + + def split(self, children, split_x_or_y, split_i): + """ + Turns a leaf into a node with the given children. + + Parameters + ========== + children : tuple + Two instances of ``Node`` in a tuple. + split_x_or_y : int + Split's ``x_or_y``. + split_axis : float + Split's ``axis``. + """ + + if not self.leaf: + raise RuntimeError( + 'Node.split called for a node that is not a leaf') + + self.children = children + self.split_x_or_y = split_x_or_y + self.split_axis = self.colours.xy[split_i, split_x_or_y] + self.colours = None + self._cached_reconstruction_error = None + + def _leaves_generator(self): + if self.leaf: + yield self + else: + for child in self.children: + yield from child.leaves + + @property + def leaves(self): + """ + Returns a generator of all leaves connected to this node. + """ + return self._leaves_generator() + + # + # PCA and reconstruction + # + + def PCA(self): + """ + Performs the principal component analysis on colours in this node. + """ + + if not self.leaf: + raise RuntimeError('Node.PCA called for a node that is not a leaf') + + pca = sklearn.decomposition.PCA(3) + pca.fit(self.colours.reflectances) + self.basis_functions = pca.components_ + self.mean = pca.mean_ + + # TODO: better names + M = np.empty((3, 3)) + for i in range(3): + R = self.basis_functions[i, :] + M[:, i] = self.clustering.fast_sd_to_XYZ(R) + + self.M_inverse = np.linalg.inv(M) + self.XYZ_mu = self.clustering.fast_sd_to_XYZ(self.mean) + + self.PCA_done = True + + def reconstruct(self, XYZ): + """ + Reconstructs a reflectance using data stored in this node. + + Parameters + ========== + XYZ : ndarray, (3,) + *CIE XYZ* tristimulus values to recover the spectral distribution + from. + + Returns + ------- + SpectralDistribution + Recovered spectral distribution. + """ + + weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu) + reflectance = np.dot(weights, self.basis_functions) + self.mean + return SpectralDistribution(reflectance, self.clustering.wl) + + def reconstruction_error(self): + """ + For every colour in this node, its spectrum is reconstructed (using + PCA data in this node) and compared with its true, measured spectrum. + The errors are then summed up and returned. + + Returns + ======= + error : float + The sum reconstruction errors for this node. + """ + + if self._cached_reconstruction_error: + return self._cached_reconstruction_error + + if not self.PCA_done: # FIXME + self.PCA() + + error = 0 + for i in range(len(self.colours)): + sd = self.colours.reflectances[i, :] + XYZ = self.colours.XYZ[i, :] + recovered_sd = self.reconstruct(XYZ) + error += np.sum((sd - recovered_sd.values) ** 2) + + self._cached_reconstruction_error = error + return error + + def total_reconstruction_error(self): + """ + Computes the reconstruction error for an entire subtree, starting at + this node. + + Returns + ======= + error : float + The total reconstruction error of the subtree. + """ + + if self.leaf: + return self.reconstruction_error() + else: + return sum([child.total_reconstruction_error() + for child in self.children]) + + def partition(self, x_or_y, i): + """ + Splits this node into two and returns them. This operation does not + affect the node it's used on. ``Node.split`` has to be called (with + data returned from this method) to actually alter the tree. + + Parameters + ========== + x_or_y : int + Whether to split according to X or Y coordinates. + i : int + The index of the colour whose coordinates determine where the + split happens. Cannot be ``len(Node.colours)`` or greater. + + Returns + ======= + lesser : Node + The left or lower part. + greater : Node + The right or upper part. + """ + + axis = self.colours.xy[i, x_or_y] + partition = self.colours.partition(x_or_y, axis) + + if len(partition[0]) <= 5 or len(partition[1]) <= 5: + raise ClusteringError('partition created parts that are too small') + + lesser = Node(self.clustering, partition[0]) + greater = Node(self.clustering, partition[1]) + return lesser, greater + + def split_quality(self, x_or_y, i): + """ + + Parameters + ========== + x_or_y : int + Whether to split according to X or Y coordinates. + i : int + Index of the colour whose coordinates determine where the + split happens. Cannot be ``len(Node.colours)`` or greater. + + Returns + ======= + error : float + Sum of reconstruction errors of the two nodes created from + splitting. + lesser, greater : tuple + Subnodes created from splitting. + """ + lesser, greater = self.partition(x_or_y, i) + + lesser.PCA() + greater.PCA() + + error = lesser.reconstruction_error() + greater.reconstruction_error() + return error, (lesser, greater) + + # + # Plotting + # + + def _plot_colours(self, number): + if not self.leaf: + for child in self.children: + child._plot_colours(number) + return + + symbols = ['+', '^', '*', '>', 'o', 'v', 'x', '<'] + plt.plot(*self.colours.xy.T, + "k" + symbols[number[0] % len(symbols)], + label=str(self)) + number[0] += 1 + + def visualise(self): + """ + Plots the subtree on a xy chromaticity diagram. Does not call + ``plt.show``. + """ + plot_chromaticity_diagram_CIE1931(standalone=False) + self._plot_colours([0]) + plt.legend() + + +class Clustering: + """ + Represents the process of clustering and optimisation. Instances store + shared data such as cmfs and the illuminant. + + Operations involving the entire tree, such as optimisation and + reconstruction, are implemented here. + """ + + def __init__( + self, + sds, + shape, + cmfs=STANDARD_OBSERVER_CMFS['CIE 1931 2 Degree Standard Observer'], + illuminant=ILLUMINANT_SDS['D65']): + """ + Parameters + ---------- + sds : ndarray (n,m) + Reflectances of the ``n`` reference colours to be used for + optimisation. + shape : SpectralShape + Spectral shape of ``sds``. + cmfs : XYZ_ColourMatchingFunctions, optional + Standard observer colour matching functions. + illuminant : SpectralDistribution, optional + Illuminant spectral distribution. + """ + + self.sds = sds + self.shape = shape + self.wl = shape.range() + self.dw = self.wl[1] - self.wl[0] + self.cmfs = cmfs.copy().align(shape) + self.illuminant = illuminant.copy().align(shape) + self.xy_w = XYZ_to_xy(sd_to_XYZ(illuminant, cmfs=cmfs)) + + # The normalising constant used in sd_to_XYZ. + self.k = 1 / (np.sum(self.cmfs.values[:, 1] + * self.illuminant.values) * self.dw) + + colours = Colours(self, sds) + self.root = Node(self, colours) + + def fast_sd_to_XYZ(self, R): + """ + Compute the XYZ tristimulus values of a given reflectance. Faster for + humans, by using cmfs and the illuminant stored in the ''Clustering'', + thus avoiding unnecessary repetition. Faster for computers, by using + a very simple and direct method. + + Parameters + ---------- + R : ndarray + Reflectance with shape matching the one used to construct this + ``Clustering``. + + Returns + ------- + ndarray (3,) + XYZ tristimulus values, normalised to 1. + """ + + E = self.illuminant.values * R + return self.k * np.dot(E, self.cmfs.values) * self.dw + + def reconstruct(self, XYZ): + """ + Finds the appropriate node and reconstructs the reflectance for the + given XYZ tristimulus values. + + Parameters + ========== + XYZ : ndarray, (3,) + *CIE XYZ* tristimulus values to recover the spectral distribution + from. + + Returns + ------- + SpectralDistribution + Recovered spectral distribution. + """ + + xy = XYZ_to_xy(XYZ) + + def search(node): + if node.leaf: + return node + + if xy[node.split_x_or_y] >= node.split_axis: + return search(node.children[0]) + else: + return search(node.children[1]) + + node = search(self.root) + return node.reconstruct(XYZ) + + def find_best_split(self): + """ + Check every possible split in the entire tree to find the one that will + reduce the error the most. + + Returns + ------- + best_split : (Node, int, int) + Tuple representing the best split found. It contains the ``Node`` + that should be split, the split direction (``x_or_y``) and + the colour index (``i``). + best_partition : (Node, Node) + Subnodes to be used as children for the leaf. + + Raises + ------ + ClusteringError + If the tree has already been split too finely, further splits will + not be possible and this exception will be raised. + """ + + best_new_error = None + total_error = self.root.total_reconstruction_error() + + for leaf in self.root.leaves: + total_error_minus_leaf = total_error - leaf.reconstruction_error() + + for x_or_y in [0, 1]: + for i in range(len(leaf.colours)): + try: + split_error, partition = leaf.split_quality(x_or_y, i) + except ClusteringError: + continue + + new_error = total_error_minus_leaf + split_error + + if best_new_error is None or new_error < best_new_error: + best_new_error = new_error + best_split = (leaf, x_or_y, i) + best_partition = partition + + print('%10s %s %4d %g' + % (leaf, ['x', 'y'][x_or_y], i, new_error)) + + if best_new_error is None: + raise ClusteringError('no more splits were possible') + + return best_split, best_partition + + def do_best_splits(self, repeats): + """ + Find the best split and perform it, and repeat the operation the + specified amount of times. + + Parameters + ---------- + repeats : int + Number of splits. + """ + for repeat in range(repeats): + try: + (leaf, x_or_y, i), partition = self.find_best_split() + except ClusteringError: + print('WARNING: only %d splits were possible' % repeat) + break + + print('==== Splitting %s, x_or_y=%d, i=%d ====' + % (leaf, x_or_y, i)) + leaf.split(partition, x_or_y, i) + + def write_python_dataset(self, path): + """ + Write the clustering into a Python dataset compatible with Colour's + ``colour.recovery.otsu2018`` code. + + Parameters + ---------- + path : string + File path. + """ + + with open(path, 'w') as fd: + fd.write('# Autogenerated, do not modify\n\n') + + fd.write('from numpy import array\n') + fd.write('from colour import SpectralShape\n\n\n') + + fd.write('OTSU_2018_SPECTRAL_SHAPE = SpectralShape%s\n\n\n' + % self.shape) + + # Basis functions + + fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n') + for i, leaf in enumerate(self.root.leaves): + leaf._i = i # For use when writing the selection function + for line in (repr(leaf.basis_functions) + ',').splitlines(): + fd.write(' %s\n' % line) + fd.write(']\n\n\n') + + # Means + + fd.write('OTSU_2018_MEANS = [\n') + for leaf in self.root.leaves: + for line in (repr(leaf.mean) + ',').splitlines(): + fd.write(' %s\n' % line) + fd.write(']\n\n\n') + + # Cluster selection function + + fd.write('def select_cluster_Otsu2018(xy):\n') + fd.write(' x, y = xy\n\n') + + def write_if(node, indent): + if node.leaf: + fd.write(' ' * indent) + fd.write('return %d # %s\n' % (node._i, node)) + return + + fd.write(' ' * indent) + fd.write('if %s <= %s:\n' % (['x', 'y'][node.split_x_or_y], + repr(node.split_axis))) + write_if(node.children[0], indent + 1) + + fd.write(' ' * indent) + fd.write('else:\n') + write_if(node.children[1], indent + 1) + + write_if(self.root, 1) |