summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaweł Redman <pawel.redman@gmail.com>2020-07-25 17:22:36 +0200
committerPaweł Redman <pawel.redman@gmail.com>2020-07-25 17:22:36 +0200
commitf88a852f659ea3223ce69ac147912b1d346cf7be (patch)
treebfd5150356bc0b68b31eb23b8ad26202bf6f085f
parent26c2bcbb77f9d2914be35e74e14333bc3e729c3e (diff)
Factor out printing progress information.
This way using tqdm is optional and not a direct dependency of otsu2018.py.
-rw-r--r--clustering.py52
-rw-r--r--otsu2018.py73
2 files changed, 98 insertions, 27 deletions
diff --git a/clustering.py b/clustering.py
new file mode 100644
index 0000000..527830b
--- /dev/null
+++ b/clustering.py
@@ -0,0 +1,52 @@
+import os
+import matplotlib.pyplot as plt
+import tqdm
+
+from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ
+
+from otsu2018 import load_Otsu2018_spectra, Tree
+
+
+if __name__ == '__main__':
+ print('Loading spectral data...')
+ sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50)
+ shape = SpectralShape(380, 730, 10)
+
+ print('Initializing the tree...')
+ tree = Tree(sds, shape)
+
+ print('Clustering...')
+ before = tree.total_reconstruction_error()
+ tree.optimise(progress_bar=tqdm.tqdm)
+ after = tree.total_reconstruction_error()
+
+ print('Error before: %g' % before)
+ print('Error after: %g' % after)
+
+ print('Saving the dataset...')
+ os.makedirs('datasets', exist_ok=True)
+ tree.write_python_dataset('datasets/otsu2018.py')
+
+ print('Plotting...')
+ tree.visualise()
+
+ plt.figure()
+
+ examples = COLOURCHECKER_SDS['ColorChecker N Ohta'].items()
+ for i, (name, sd) in enumerate(examples):
+ plt.subplot(2, 3, 1 + i)
+ plt.title(name)
+
+ plt.plot(sd.wavelengths, sd.values, label='Original')
+
+ XYZ = sd_to_XYZ(sd) / 100
+ recovered_sd = tree.reconstruct(XYZ)
+ plt.plot(recovered_sd.wavelengths, recovered_sd.values,
+ label='Recovered')
+
+ plt.legend()
+
+ if i + 1 == 6:
+ break
+
+ plt.show()
diff --git a/otsu2018.py b/otsu2018.py
index f573b17..6d15580 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -1,6 +1,5 @@
import numpy as np
import matplotlib.pyplot as plt
-import tqdm
import time
from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS,
@@ -363,17 +362,20 @@ class Node:
"""
if self.best_partition is not None:
- print('%s: already optimised' % self)
return self.best_partition
error = self.reconstruction_error()
-
best_error = None
- bar = tqdm.trange(2 * len(self.colours), leave=False)
+
+ bar = None
+ if self.tree._progress_bar:
+ bar = self.tree._progress_bar(total=2 * len(self.colours),
+ leave=False)
for direction in [0, 1]:
for i in range(len(self.colours)):
- bar.update()
+ if bar:
+ bar.update()
origin = self.colours.xy[i, direction]
axis = PartitionAxis(origin, direction)
@@ -389,7 +391,8 @@ class Node:
if best_error is None or new_error < best_error:
self.best_partition = (new_error, axis, partition)
- bar.close()
+ if bar:
+ bar.close()
if self.best_partition is None:
raise treeError('no partitions are possible')
@@ -510,7 +513,11 @@ class Tree(Node):
E = self.illuminant.values * R
return self.k * np.dot(E, self.cmfs.values) * self.dw
- def optimise(self, repeats=8, min_cluster_size=None):
+ def optimise(self,
+ repeats=8,
+ min_cluster_size=None,
+ print_callback=print,
+ progress_bar=None):
"""
Optimise the tree by repeatedly performing optimal partitions of the
nodes, creating a tree that minimizes the total reconstruction
@@ -527,8 +534,26 @@ class Tree(Node):
automatically, based on the size of the dataset and desired number
of clusters. Must be at least 3 or principal component analysis
will not be possible.
+ print_callback : function, optional
+ Function to use for printing progress and diagnostic information.
+ progress_bar : class, optional
+ Class for creating progress bar objects. Must be compatible with
+ tqdm.
"""
+ t0 = time.time()
+
+ def _print(text):
+ if print_callback is None:
+ return
+
+ delta = time.time() - t0
+ stamp = '%3d:%02d ' % (delta // 60, np.floor(delta % 60))
+ for line in text.splitlines():
+ print_callback(stamp, line)
+
+ self._progress_bar = progress_bar
+
if min_cluster_size is not None:
self.min_cluster_size = min_cluster_size
else:
@@ -537,28 +562,22 @@ class Tree(Node):
if self.min_cluster_size <= 3:
self.min_cluster_size = 3
- t0 = time.time()
-
- def elapsed():
- delta = time.time() - t0
- return '%dm%.3fs' % (delta // 60, delta % 60)
-
initial_error = self.total_reconstruction_error()
- print('Initial error is %g.' % initial_error)
+ _print('Initial error is %g.' % initial_error)
for repeat in range(repeats):
- print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats))
+ _print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats))
best_total_error = None
total_error = self.total_reconstruction_error()
for i, leaf in enumerate(self.leaves):
- print('(%s) Optimising %s...' % (elapsed(), leaf))
+ _print('Optimising %s...' % leaf)
try:
error, axis, partition = leaf.find_best_partition()
except treeError as e:
- print('Failed: %s' % e)
+ _print('Failed: %s' % e)
continue
new_total_error = (total_error - leaf.reconstruction_error()
@@ -571,21 +590,21 @@ class Tree(Node):
best_partition = partition
if best_total_error is None:
- print('\nNo further improvements are possible.\n'
- 'Terminating at iteration %d.\n' % repeat)
+ _print('\nNo further improvements are possible.\n'
+ 'Terminating at iteration %d.\n' % repeat)
break
- print('\nSplit %s into %s and %s along %s.'
- % (best_leaf, *best_partition, best_axis))
- print('Error is reduced by %g and is now %g, '
- '%.1f%% of the initial error.'
- % (leaf.reconstruction_error()
- - error, best_total_error, 100 * best_total_error
- / initial_error))
+ _print('\nSplit %s into %s and %s along %s.'
+ % (best_leaf, *best_partition, best_axis))
+ _print('Error is reduced by %g and is now %g, '
+ '%.1f%% of the initial error.'
+ % (leaf.reconstruction_error()
+ - error, best_total_error, 100 * best_total_error
+ / initial_error))
best_leaf.split(best_partition, best_axis)
- print('Finished in %s.' % elapsed())
+ _print('Finished.')
def write_python_dataset(self, path):
"""