From f88a852f659ea3223ce69ac147912b1d346cf7be Mon Sep 17 00:00:00 2001 From: Paweł Redman Date: Sat, 25 Jul 2020 17:22:36 +0200 Subject: Factor out printing progress information. This way using tqdm is optional and not a direct dependency of otsu2018.py. --- otsu2018.py | 73 ++++++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 46 insertions(+), 27 deletions(-) (limited to 'otsu2018.py') diff --git a/otsu2018.py b/otsu2018.py index f573b17..6d15580 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -1,6 +1,5 @@ import numpy as np import matplotlib.pyplot as plt -import tqdm import time from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS, @@ -363,17 +362,20 @@ class Node: """ if self.best_partition is not None: - print('%s: already optimised' % self) return self.best_partition error = self.reconstruction_error() - best_error = None - bar = tqdm.trange(2 * len(self.colours), leave=False) + + bar = None + if self.tree._progress_bar: + bar = self.tree._progress_bar(total=2 * len(self.colours), + leave=False) for direction in [0, 1]: for i in range(len(self.colours)): - bar.update() + if bar: + bar.update() origin = self.colours.xy[i, direction] axis = PartitionAxis(origin, direction) @@ -389,7 +391,8 @@ class Node: if best_error is None or new_error < best_error: self.best_partition = (new_error, axis, partition) - bar.close() + if bar: + bar.close() if self.best_partition is None: raise treeError('no partitions are possible') @@ -510,7 +513,11 @@ class Tree(Node): E = self.illuminant.values * R return self.k * np.dot(E, self.cmfs.values) * self.dw - def optimise(self, repeats=8, min_cluster_size=None): + def optimise(self, + repeats=8, + min_cluster_size=None, + print_callback=print, + progress_bar=None): """ Optimise the tree by repeatedly performing optimal partitions of the nodes, creating a tree that minimizes the total reconstruction @@ -527,8 +534,26 @@ class Tree(Node): automatically, based on the size of the dataset and desired number of clusters. Must be at least 3 or principal component analysis will not be possible. + print_callback : function, optional + Function to use for printing progress and diagnostic information. + progress_bar : class, optional + Class for creating progress bar objects. Must be compatible with + tqdm. """ + t0 = time.time() + + def _print(text): + if print_callback is None: + return + + delta = time.time() - t0 + stamp = '%3d:%02d ' % (delta // 60, np.floor(delta % 60)) + for line in text.splitlines(): + print_callback(stamp, line) + + self._progress_bar = progress_bar + if min_cluster_size is not None: self.min_cluster_size = min_cluster_size else: @@ -537,28 +562,22 @@ class Tree(Node): if self.min_cluster_size <= 3: self.min_cluster_size = 3 - t0 = time.time() - - def elapsed(): - delta = time.time() - t0 - return '%dm%.3fs' % (delta // 60, delta % 60) - initial_error = self.total_reconstruction_error() - print('Initial error is %g.' % initial_error) + _print('Initial error is %g.' % initial_error) for repeat in range(repeats): - print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats)) + _print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats)) best_total_error = None total_error = self.total_reconstruction_error() for i, leaf in enumerate(self.leaves): - print('(%s) Optimising %s...' % (elapsed(), leaf)) + _print('Optimising %s...' % leaf) try: error, axis, partition = leaf.find_best_partition() except treeError as e: - print('Failed: %s' % e) + _print('Failed: %s' % e) continue new_total_error = (total_error - leaf.reconstruction_error() @@ -571,21 +590,21 @@ class Tree(Node): best_partition = partition if best_total_error is None: - print('\nNo further improvements are possible.\n' - 'Terminating at iteration %d.\n' % repeat) + _print('\nNo further improvements are possible.\n' + 'Terminating at iteration %d.\n' % repeat) break - print('\nSplit %s into %s and %s along %s.' - % (best_leaf, *best_partition, best_axis)) - print('Error is reduced by %g and is now %g, ' - '%.1f%% of the initial error.' - % (leaf.reconstruction_error() - - error, best_total_error, 100 * best_total_error - / initial_error)) + _print('\nSplit %s into %s and %s along %s.' + % (best_leaf, *best_partition, best_axis)) + _print('Error is reduced by %g and is now %g, ' + '%.1f%% of the initial error.' + % (leaf.reconstruction_error() + - error, best_total_error, 100 * best_total_error + / initial_error)) best_leaf.split(best_partition, best_axis) - print('Finished in %s.' % elapsed()) + _print('Finished.') def write_python_dataset(self, path): """ -- cgit