diff options
Diffstat (limited to 'otsu2018.py')
-rw-r--r-- | otsu2018.py | 85 |
1 files changed, 68 insertions, 17 deletions
diff --git a/otsu2018.py b/otsu2018.py index e188950..88ab261 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -168,7 +168,7 @@ class Node: Node._counter += 1 def __str__(self): - return 'Node #%d' % (self.number) + return 'Node #%d (%d)' % (self.number, len(self.colours)) @property def leaf(self): @@ -283,7 +283,7 @@ class Node: if self._cached_reconstruction_error: return self._cached_reconstruction_error - if not self.PCA_done: # FIXME + if not self.PCA_done: self.PCA() error = 0 @@ -333,9 +333,10 @@ class Node: """ partition = self.colours.partition(axis) - if len(partition[0]) < 3 or len(partition[1]) < 3: + if (len(partition[0]) < self.clustering.min_cluster_size + or len(partition[1]) < self.clustering.min_cluster_size): raise ClusteringError( - 'partition created parts that are too small for PCA') + 'partition created parts smaller than min_cluster_size') lesser = Node(self.clustering, partition[0]) lesser.PCA() @@ -356,24 +357,28 @@ class Node: print('%s: already optimised' % self) return self.best_partition + error = self.reconstruction_error() + best_error = None bar = tqdm.trange(2 * len(self.colours), leave=False) for direction in [0, 1]: for i in range(len(self.colours)): + bar.update() + origin = self.colours.xy[i, direction] axis = PartitionAxis(origin, direction) try: - error, partition = self.partition_error(axis) + new_error, partition = self.partition_error(axis) except ClusteringError: continue - if best_error is None or error < best_error: - self.best_partition = (error, axis, partition) + if new_error >= error: + continue - delta = error - self.reconstruction_error() - bar.update() + if best_error is None or new_error < best_error: + self.best_partition = (new_error, axis, partition) bar.close() @@ -407,6 +412,29 @@ class Node: self._plot_colours([0]) plt.legend() + def visualise_pca(self): + """ + Makes a plot showing the principal components of this node and how + well they reconstruct the source data. + """ + + plt.subplot(2, 1, 1) + plt.title(str(self) + ': principal components') + for i in range(3): + plt.plot(self.wl, self.basis_functions[i, :], label='PC%d' % i) + plt.legend() + + plt.subplot(2, 1, 2) + plt.title(str(self) + ': data') + for i in range(3): + plt.plot(self.wl, self.colours.reflectances[i, :], 'C%d:' % i) + + XYZ = self.colours.XYZ[i, :] + recon = self.reconstruct(XYZ) + plt.plot(self.wl, recon.values, 'C%d-' % i) + recon = self.reconstruct(XYZ) + plt.plot(self.wl, recon.values, 'C%d--' % i) + class Clustering: """ @@ -505,7 +533,7 @@ class Clustering: node = search(self.root) return node.reconstruct(XYZ) - def optimise(self, repeats): + def optimise(self, repeats=8, min_cluster_size=None): """ Optimise the tree by repeatedly performing optimal partitions of the nodes, creating a clustering that minimizes the total reconstruction @@ -513,18 +541,35 @@ class Clustering: Parameters ---------- - repeats : int + repeats : int, optional Maximum number of splits. If the dataset is too small, this number - might not be reached. + might not be reached. The default is to create 8 clusters, like in + the original paper. + min_cluster_size : int, optional + Smallest acceptable cluster size. By default it's chosen + automatically, based on the size of the dataset and desired number + of clusters. Must be at least 3 or principal component analysis + will not be possible. """ + if min_cluster_size is not None: + self.min_cluster_size = min_cluster_size + else: + self.min_cluster_size = len(self.root.colours) / repeats // 2 + + if self.min_cluster_size <= 3: + self.min_cluster_size = 3 + t0 = time.time() def elapsed(): delta = time.time() - t0 return '%dm%.3fs' % (delta // 60, delta % 60) + initial_error = self.root.total_reconstruction_error() + print('Initial error is %g.' % initial_error) + for repeat in range(repeats): - print('=== Iteration %d of %d ===' % (repeat + 1, repeats)) + print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats)) best_total_error = None total_error = self.root.total_reconstruction_error() @@ -534,8 +579,8 @@ class Clustering: try: error, axis, partition = leaf.find_best_partition() - except ClusteringError: - print('...no partitions are possible.') + except ClusteringError as e: + print('Failed: %s' % e) continue new_total_error = (total_error - leaf.reconstruction_error() @@ -548,11 +593,17 @@ class Clustering: best_partition = partition if best_total_error is None: - print('WARNING: only %d splits were possible.' % repeat) + print('\nNo further improvements are possible.\n' + 'Terminating at iteration %d.\n' % repeat) break - print('\nSplit %s into %s and %s along %s.\n' + print('\nSplit %s into %s and %s along %s.' % (best_leaf, *best_partition, best_axis)) + print('Error is reduced by %g and is now %g, ' + '%.1f%% of the initial error.' % (leaf.reconstruction_error() + - error, best_total_error, 100 * best_total_error + / initial_error)) + best_leaf.split(best_partition, best_axis) print('Finished in %s.' % elapsed()) |