diff options
author | Paweł Redman <pawel.redman@gmail.com> | 2020-07-25 15:53:29 +0200 |
---|---|---|
committer | Paweł Redman <pawel.redman@gmail.com> | 2020-07-25 16:42:01 +0200 |
commit | 26c2bcbb77f9d2914be35e74e14333bc3e729c3e (patch) | |
tree | 00b2bb3104bd8cf11dc185d9a17c26348610290f | |
parent | 9c0dfa59598968ab5be656aac5c4f6f79c50dd61 (diff) |
Rename Clustering to Tree and make it a subclass of Node.
-rw-r--r-- | demo.py | 25 | ||||
-rw-r--r-- | otsu2018.py | 146 |
2 files changed, 76 insertions, 95 deletions
@@ -1,34 +1,33 @@ import os import matplotlib.pyplot as plt -from colour import ( - SpectralShape, COLOURCHECKER_SDS, ILLUMINANT_SDS, sd_to_XYZ) +from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ -from otsu2018 import load_Otsu2018_spectra, Clustering +from otsu2018 import load_Otsu2018_spectra, Tree if __name__ == '__main__': print('Loading spectral data...') - sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=1) + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) shape = SpectralShape(380, 730, 10) - print('Initializing the clustering...') - clustering = Clustering(sds, shape) + print('Initializing the tree...') + tree = Tree(sds, shape) - print('Clustering...') - before = clustering.root.total_reconstruction_error() - clustering.optimise() - after = clustering.root.total_reconstruction_error() + print('Tree...') + before = tree.total_reconstruction_error() + tree.optimise() + after = tree.total_reconstruction_error() print('Error before: %g' % before) print('Error after: %g' % after) print('Saving the dataset...') os.makedirs('datasets', exist_ok=True) - clustering.write_python_dataset('datasets/otsu2018.py') + tree.write_python_dataset('datasets/otsu2018.py') print('Plotting...') - clustering.root.visualise() + tree.visualise() plt.figure() @@ -40,7 +39,7 @@ if __name__ == '__main__': plt.plot(sd.wavelengths, sd.values, label='Original') XYZ = sd_to_XYZ(sd) / 100 - recovered_sd = clustering.reconstruct(XYZ) + recovered_sd = tree.reconstruct(XYZ) plt.plot(recovered_sd.wavelengths, recovered_sd.values, label='Recovered') diff --git a/otsu2018.py b/otsu2018.py index b17c518..f573b17 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -64,22 +64,22 @@ class Colours: """ Represents multiple colours: their reflectances, XYZ tristimulus values and xy coordinates. The cmfs and the illuminant are taken from the parent - Clustering. + tree. This class also supports partitioning, or creating two smaller instances of Colours, split along a horizontal or a vertical axis on the xy plane. """ - def __init__(self, clustering, reflectances): + def __init__(self, tree, reflectances): """ Parameters ========== - clustering : Clustering - The parent clustering. This determines what cmfs and illuminant + tree : tree + The parent tree. This determines what cmfs and illuminant are used in colourimetric calculations. reflectances : ndarray (n,m) Reflectances of the ``n`` colours to be stored in this class. - The shape must match ``Clustering.shape`` with ``m`` points for + The shape must match ``tree.shape`` with ``m`` points for each colour. """ @@ -88,8 +88,8 @@ class Colours: self.xy = np.empty((reflectances.shape[0], 2)) for i in range(len(self)): - sd = SpectralDistribution(reflectances[i, :], clustering.wl) - XYZ = sd_to_XYZ(sd, illuminant=clustering.illuminant) / 100 + sd = SpectralDistribution(reflectances[i, :], tree.wl) + XYZ = sd_to_XYZ(sd, illuminant=tree.illuminant) / 100 self.XYZ[i, :] = XYZ self.xy[i, :] = XYZ_to_xy(XYZ) @@ -130,7 +130,7 @@ class Colours: return lesser, greater -class ClusteringError(Exception): +class treeError(Exception): """ Exception used for various errors originating from code in this file. """ @@ -139,23 +139,23 @@ class ClusteringError(Exception): class Node: """ - Represents a node in the clustering tree. + Represents a node in the tree tree. """ _counter = 1 - def __init__(self, clustering, colours): + def __init__(self, tree, colours): """ Parameters ========== - clustering : Clustering - The parent clustering. This determines what cmfs and illuminant + tree : tree + The parent tree. This determines what cmfs and illuminant are used in colourimetric calculations. colours : Colours The colours that belong in this node. """ - self.clustering = clustering + self.tree = tree self.colours = colours self.children = None @@ -233,16 +233,30 @@ class Node: M = np.empty((3, 3)) for i in range(3): R = self.basis_functions[i, :] - M[:, i] = self.clustering.fast_sd_to_XYZ(R) + M[:, i] = self.tree.fast_sd_to_XYZ(R) self.M_inverse = np.linalg.inv(M) - self.XYZ_mu = self.clustering.fast_sd_to_XYZ(self.mean) + self.XYZ_mu = self.tree.fast_sd_to_XYZ(self.mean) self.PCA_done = True + def _reconstruct_xy(self, XYZ, xy): + if not self.leaf: + if xy[self.partition_axis.direction] <= self.partition_axis.origin: + return self.children[0]._reconstruct_xy(XYZ, xy) + else: + return self.children[1]._reconstruct_xy(XYZ, xy) + + weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu) + reflectance = np.dot(weights, self.basis_functions) + self.mean + reflectance = np.clip(reflectance, 0, 1) + return SpectralDistribution(reflectance, self.tree.wl) + def reconstruct(self, XYZ): """ - Reconstructs a reflectance using data stored in this node. + Reconstructs the reflectance for the given *XYZ* tristimulus values. + If this is a leaf, data from this node will be used. Otherwise the + code will look for the appropriate subnode. Parameters ========== @@ -256,10 +270,8 @@ class Node: Recovered spectral distribution. """ - weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu) - reflectance = np.dot(weights, self.basis_functions) + self.mean - reflectance = np.clip(reflectance, 0, 1) - return SpectralDistribution(reflectance, self.clustering.wl) + xy = XYZ_to_xy(XYZ) + return self._reconstruct_xy(XYZ, xy) # # Optimisation @@ -330,15 +342,15 @@ class Node: """ partition = self.colours.partition(axis) - if (len(partition[0]) < self.clustering.min_cluster_size - or len(partition[1]) < self.clustering.min_cluster_size): - raise ClusteringError( + if (len(partition[0]) < self.tree.min_cluster_size + or len(partition[1]) < self.tree.min_cluster_size): + raise treeError( 'partition created parts smaller than min_cluster_size') - lesser = Node(self.clustering, partition[0]) + lesser = Node(self.tree, partition[0]) lesser.PCA() - greater = Node(self.clustering, partition[1]) + greater = Node(self.tree, partition[1]) greater.PCA() error = lesser.reconstruction_error() + greater.reconstruction_error() @@ -347,7 +359,7 @@ class Node: def find_best_partition(self): """ Finds the best partition of this node. See - ``Clustering.find_best_partition``. + ``tree.find_best_partition``. """ if self.best_partition is not None: @@ -368,7 +380,7 @@ class Node: try: new_error, partition = self.partition_error(axis) - except ClusteringError: + except treeError: continue if new_error >= error: @@ -380,7 +392,7 @@ class Node: bar.close() if self.best_partition is None: - raise ClusteringError('no partitions are possible') + raise treeError('no partitions are possible') return self.best_partition @@ -414,7 +426,7 @@ class Node: Makes a plot showing the principal components of this node and how well they reconstruct the source data. """ - + plt.subplot(2, 1, 1) plt.title(str(self) + ': principal components') for i in range(3): @@ -433,13 +445,14 @@ class Node: plt.plot(self.wl, recon.values, 'C%d--' % i) -class Clustering: +class Tree(Node): """ - Represents the process of clustering and optimisation. Instances store - shared data such as cmfs and the illuminant. + This is an extension of ``Node``. It's meant to represent the root of the + tree and contains information shared with all the nodes, such as cmfs + and the illuminant (if any is used). Operations involving the entire tree, such as optimisation and - reconstruction, are implemented here. + reconstruction, are also implemented here. """ def __init__( @@ -462,7 +475,6 @@ class Clustering: Illuminant spectral distribution. """ - self.sds = sds self.shape = shape self.wl = shape.range() self.dw = self.wl[1] - self.wl[0] @@ -474,13 +486,12 @@ class Clustering: self.k = 1 / (np.sum(self.cmfs.values[:, 1] * self.illuminant.values) * self.dw) - colours = Colours(self, sds) - self.root = Node(self, colours) + super().__init__(self, Colours(self, sds)) def fast_sd_to_XYZ(self, R): """ Compute the XYZ tristimulus values of a given reflectance. Faster for - humans, by using cmfs and the illuminant stored in the ''Clustering'', + humans, by using cmfs and the illuminant stored in the ''tree'', thus avoiding unnecessary repetition. Faster for computers, by using a very simple and direct method. @@ -488,7 +499,7 @@ class Clustering: ---------- R : ndarray Reflectance with shape matching the one used to construct this - ``Clustering``. + ``tree``. Returns ------- @@ -499,41 +510,10 @@ class Clustering: E = self.illuminant.values * R return self.k * np.dot(E, self.cmfs.values) * self.dw - def reconstruct(self, XYZ): - """ - Finds the appropriate node and reconstructs the reflectance for the - given XYZ tristimulus values. - - Parameters - ========== - XYZ : ndarray, (3,) - *CIE XYZ* tristimulus values to recover the spectral distribution - from. - - Returns - ------- - SpectralDistribution - Recovered spectral distribution. - """ - - xy = XYZ_to_xy(XYZ) - - def search(node): - if node.leaf: - return node - - if xy[node.partition_axis.direction] <= node.partition_axis.origin: - return search(node.children[0]) - else: - return search(node.children[1]) - - node = search(self.root) - return node.reconstruct(XYZ) - def optimise(self, repeats=8, min_cluster_size=None): """ Optimise the tree by repeatedly performing optimal partitions of the - nodes, creating a clustering that minimizes the total reconstruction + nodes, creating a tree that minimizes the total reconstruction error. Parameters @@ -552,31 +532,32 @@ class Clustering: if min_cluster_size is not None: self.min_cluster_size = min_cluster_size else: - self.min_cluster_size = len(self.root.colours) / repeats // 2 + self.min_cluster_size = len(self.colours) / repeats // 2 if self.min_cluster_size <= 3: self.min_cluster_size = 3 t0 = time.time() + def elapsed(): delta = time.time() - t0 return '%dm%.3fs' % (delta // 60, delta % 60) - initial_error = self.root.total_reconstruction_error() + initial_error = self.total_reconstruction_error() print('Initial error is %g.' % initial_error) for repeat in range(repeats): print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats)) best_total_error = None - total_error = self.root.total_reconstruction_error() + total_error = self.total_reconstruction_error() - for i, leaf in enumerate(self.root.leaves): + for i, leaf in enumerate(self.leaves): print('(%s) Optimising %s...' % (elapsed(), leaf)) try: error, axis, partition = leaf.find_best_partition() - except ClusteringError as e: + except treeError as e: print('Failed: %s' % e) continue @@ -597,9 +578,10 @@ class Clustering: print('\nSplit %s into %s and %s along %s.' % (best_leaf, *best_partition, best_axis)) print('Error is reduced by %g and is now %g, ' - '%.1f%% of the initial error.' % (leaf.reconstruction_error() - - error, best_total_error, 100 * best_total_error - / initial_error)) + '%.1f%% of the initial error.' + % (leaf.reconstruction_error() + - error, best_total_error, 100 * best_total_error + / initial_error)) best_leaf.split(best_partition, best_axis) @@ -607,7 +589,7 @@ class Clustering: def write_python_dataset(self, path): """ - Write the clustering into a Python dataset compatible with Colour's + Write the tree into a Python dataset compatible with Colour's ``colour.recovery.otsu2018`` code. Parameters @@ -628,7 +610,7 @@ class Clustering: # Basis functions fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n') - for i, leaf in enumerate(self.root.leaves): + for i, leaf in enumerate(self.leaves): for line in (repr(leaf.basis_functions) + ',').splitlines(): fd.write(' %s\n' % line) fd.write(']\n\n\n') @@ -636,7 +618,7 @@ class Clustering: # Means fd.write('OTSU_2018_MEANS = [\n') - for leaf in self.root.leaves: + for leaf in self.leaves: for line in (repr(leaf.mean) + ',').splitlines(): fd.write(' %s\n' % line) fd.write(']\n\n\n') @@ -667,4 +649,4 @@ class Clustering: fd.write('else:\n') write_if(node.children[1], indent + 1) - write_if(self.root, 1) + write_if(self, 1) |