diff options
author | Paweł Redman <pawel.redman@gmail.com> | 2020-07-25 17:22:36 +0200 |
---|---|---|
committer | Paweł Redman <pawel.redman@gmail.com> | 2020-07-25 17:22:36 +0200 |
commit | f88a852f659ea3223ce69ac147912b1d346cf7be (patch) | |
tree | bfd5150356bc0b68b31eb23b8ad26202bf6f085f | |
parent | 26c2bcbb77f9d2914be35e74e14333bc3e729c3e (diff) |
Factor out printing progress information.
This way using tqdm is optional and not a direct dependency of otsu2018.py.
-rw-r--r-- | clustering.py | 52 | ||||
-rw-r--r-- | otsu2018.py | 73 |
2 files changed, 98 insertions, 27 deletions
diff --git a/clustering.py b/clustering.py new file mode 100644 index 0000000..527830b --- /dev/null +++ b/clustering.py @@ -0,0 +1,52 @@ +import os +import matplotlib.pyplot as plt +import tqdm + +from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ + +from otsu2018 import load_Otsu2018_spectra, Tree + + +if __name__ == '__main__': + print('Loading spectral data...') + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) + shape = SpectralShape(380, 730, 10) + + print('Initializing the tree...') + tree = Tree(sds, shape) + + print('Clustering...') + before = tree.total_reconstruction_error() + tree.optimise(progress_bar=tqdm.tqdm) + after = tree.total_reconstruction_error() + + print('Error before: %g' % before) + print('Error after: %g' % after) + + print('Saving the dataset...') + os.makedirs('datasets', exist_ok=True) + tree.write_python_dataset('datasets/otsu2018.py') + + print('Plotting...') + tree.visualise() + + plt.figure() + + examples = COLOURCHECKER_SDS['ColorChecker N Ohta'].items() + for i, (name, sd) in enumerate(examples): + plt.subplot(2, 3, 1 + i) + plt.title(name) + + plt.plot(sd.wavelengths, sd.values, label='Original') + + XYZ = sd_to_XYZ(sd) / 100 + recovered_sd = tree.reconstruct(XYZ) + plt.plot(recovered_sd.wavelengths, recovered_sd.values, + label='Recovered') + + plt.legend() + + if i + 1 == 6: + break + + plt.show() diff --git a/otsu2018.py b/otsu2018.py index f573b17..6d15580 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -1,6 +1,5 @@ import numpy as np import matplotlib.pyplot as plt -import tqdm import time from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS, @@ -363,17 +362,20 @@ class Node: """ if self.best_partition is not None: - print('%s: already optimised' % self) return self.best_partition error = self.reconstruction_error() - best_error = None - bar = tqdm.trange(2 * len(self.colours), leave=False) + + bar = None + if self.tree._progress_bar: + bar = self.tree._progress_bar(total=2 * len(self.colours), + leave=False) for direction in [0, 1]: for i in range(len(self.colours)): - bar.update() + if bar: + bar.update() origin = self.colours.xy[i, direction] axis = PartitionAxis(origin, direction) @@ -389,7 +391,8 @@ class Node: if best_error is None or new_error < best_error: self.best_partition = (new_error, axis, partition) - bar.close() + if bar: + bar.close() if self.best_partition is None: raise treeError('no partitions are possible') @@ -510,7 +513,11 @@ class Tree(Node): E = self.illuminant.values * R return self.k * np.dot(E, self.cmfs.values) * self.dw - def optimise(self, repeats=8, min_cluster_size=None): + def optimise(self, + repeats=8, + min_cluster_size=None, + print_callback=print, + progress_bar=None): """ Optimise the tree by repeatedly performing optimal partitions of the nodes, creating a tree that minimizes the total reconstruction @@ -527,8 +534,26 @@ class Tree(Node): automatically, based on the size of the dataset and desired number of clusters. Must be at least 3 or principal component analysis will not be possible. + print_callback : function, optional + Function to use for printing progress and diagnostic information. + progress_bar : class, optional + Class for creating progress bar objects. Must be compatible with + tqdm. """ + t0 = time.time() + + def _print(text): + if print_callback is None: + return + + delta = time.time() - t0 + stamp = '%3d:%02d ' % (delta // 60, np.floor(delta % 60)) + for line in text.splitlines(): + print_callback(stamp, line) + + self._progress_bar = progress_bar + if min_cluster_size is not None: self.min_cluster_size = min_cluster_size else: @@ -537,28 +562,22 @@ class Tree(Node): if self.min_cluster_size <= 3: self.min_cluster_size = 3 - t0 = time.time() - - def elapsed(): - delta = time.time() - t0 - return '%dm%.3fs' % (delta // 60, delta % 60) - initial_error = self.total_reconstruction_error() - print('Initial error is %g.' % initial_error) + _print('Initial error is %g.' % initial_error) for repeat in range(repeats): - print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats)) + _print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats)) best_total_error = None total_error = self.total_reconstruction_error() for i, leaf in enumerate(self.leaves): - print('(%s) Optimising %s...' % (elapsed(), leaf)) + _print('Optimising %s...' % leaf) try: error, axis, partition = leaf.find_best_partition() except treeError as e: - print('Failed: %s' % e) + _print('Failed: %s' % e) continue new_total_error = (total_error - leaf.reconstruction_error() @@ -571,21 +590,21 @@ class Tree(Node): best_partition = partition if best_total_error is None: - print('\nNo further improvements are possible.\n' - 'Terminating at iteration %d.\n' % repeat) + _print('\nNo further improvements are possible.\n' + 'Terminating at iteration %d.\n' % repeat) break - print('\nSplit %s into %s and %s along %s.' - % (best_leaf, *best_partition, best_axis)) - print('Error is reduced by %g and is now %g, ' - '%.1f%% of the initial error.' - % (leaf.reconstruction_error() - - error, best_total_error, 100 * best_total_error - / initial_error)) + _print('\nSplit %s into %s and %s along %s.' + % (best_leaf, *best_partition, best_axis)) + _print('Error is reduced by %g and is now %g, ' + '%.1f%% of the initial error.' + % (leaf.reconstruction_error() + - error, best_total_error, 100 * best_total_error + / initial_error)) best_leaf.split(best_partition, best_axis) - print('Finished in %s.' % elapsed()) + _print('Finished.') def write_python_dataset(self, path): """ |