summaryrefslogtreecommitdiff
path: root/otsu2018.py
diff options
context:
space:
mode:
Diffstat (limited to 'otsu2018.py')
-rw-r--r--otsu2018.py73
1 files changed, 46 insertions, 27 deletions
diff --git a/otsu2018.py b/otsu2018.py
index f573b17..6d15580 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -1,6 +1,5 @@
import numpy as np
import matplotlib.pyplot as plt
-import tqdm
import time
from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS,
@@ -363,17 +362,20 @@ class Node:
"""
if self.best_partition is not None:
- print('%s: already optimised' % self)
return self.best_partition
error = self.reconstruction_error()
-
best_error = None
- bar = tqdm.trange(2 * len(self.colours), leave=False)
+
+ bar = None
+ if self.tree._progress_bar:
+ bar = self.tree._progress_bar(total=2 * len(self.colours),
+ leave=False)
for direction in [0, 1]:
for i in range(len(self.colours)):
- bar.update()
+ if bar:
+ bar.update()
origin = self.colours.xy[i, direction]
axis = PartitionAxis(origin, direction)
@@ -389,7 +391,8 @@ class Node:
if best_error is None or new_error < best_error:
self.best_partition = (new_error, axis, partition)
- bar.close()
+ if bar:
+ bar.close()
if self.best_partition is None:
raise treeError('no partitions are possible')
@@ -510,7 +513,11 @@ class Tree(Node):
E = self.illuminant.values * R
return self.k * np.dot(E, self.cmfs.values) * self.dw
- def optimise(self, repeats=8, min_cluster_size=None):
+ def optimise(self,
+ repeats=8,
+ min_cluster_size=None,
+ print_callback=print,
+ progress_bar=None):
"""
Optimise the tree by repeatedly performing optimal partitions of the
nodes, creating a tree that minimizes the total reconstruction
@@ -527,8 +534,26 @@ class Tree(Node):
automatically, based on the size of the dataset and desired number
of clusters. Must be at least 3 or principal component analysis
will not be possible.
+ print_callback : function, optional
+ Function to use for printing progress and diagnostic information.
+ progress_bar : class, optional
+ Class for creating progress bar objects. Must be compatible with
+ tqdm.
"""
+ t0 = time.time()
+
+ def _print(text):
+ if print_callback is None:
+ return
+
+ delta = time.time() - t0
+ stamp = '%3d:%02d ' % (delta // 60, np.floor(delta % 60))
+ for line in text.splitlines():
+ print_callback(stamp, line)
+
+ self._progress_bar = progress_bar
+
if min_cluster_size is not None:
self.min_cluster_size = min_cluster_size
else:
@@ -537,28 +562,22 @@ class Tree(Node):
if self.min_cluster_size <= 3:
self.min_cluster_size = 3
- t0 = time.time()
-
- def elapsed():
- delta = time.time() - t0
- return '%dm%.3fs' % (delta // 60, delta % 60)
-
initial_error = self.total_reconstruction_error()
- print('Initial error is %g.' % initial_error)
+ _print('Initial error is %g.' % initial_error)
for repeat in range(repeats):
- print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats))
+ _print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats))
best_total_error = None
total_error = self.total_reconstruction_error()
for i, leaf in enumerate(self.leaves):
- print('(%s) Optimising %s...' % (elapsed(), leaf))
+ _print('Optimising %s...' % leaf)
try:
error, axis, partition = leaf.find_best_partition()
except treeError as e:
- print('Failed: %s' % e)
+ _print('Failed: %s' % e)
continue
new_total_error = (total_error - leaf.reconstruction_error()
@@ -571,21 +590,21 @@ class Tree(Node):
best_partition = partition
if best_total_error is None:
- print('\nNo further improvements are possible.\n'
- 'Terminating at iteration %d.\n' % repeat)
+ _print('\nNo further improvements are possible.\n'
+ 'Terminating at iteration %d.\n' % repeat)
break
- print('\nSplit %s into %s and %s along %s.'
- % (best_leaf, *best_partition, best_axis))
- print('Error is reduced by %g and is now %g, '
- '%.1f%% of the initial error.'
- % (leaf.reconstruction_error()
- - error, best_total_error, 100 * best_total_error
- / initial_error))
+ _print('\nSplit %s into %s and %s along %s.'
+ % (best_leaf, *best_partition, best_axis))
+ _print('Error is reduced by %g and is now %g, '
+ '%.1f%% of the initial error.'
+ % (leaf.reconstruction_error()
+ - error, best_total_error, 100 * best_total_error
+ / initial_error))
best_leaf.split(best_partition, best_axis)
- print('Finished in %s.' % elapsed())
+ _print('Finished.')
def write_python_dataset(self, path):
"""