summaryrefslogtreecommitdiff
path: root/otsu2018.py
diff options
context:
space:
mode:
Diffstat (limited to 'otsu2018.py')
-rw-r--r--otsu2018.py85
1 files changed, 68 insertions, 17 deletions
diff --git a/otsu2018.py b/otsu2018.py
index e188950..88ab261 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -168,7 +168,7 @@ class Node:
Node._counter += 1
def __str__(self):
- return 'Node #%d' % (self.number)
+ return 'Node #%d (%d)' % (self.number, len(self.colours))
@property
def leaf(self):
@@ -283,7 +283,7 @@ class Node:
if self._cached_reconstruction_error:
return self._cached_reconstruction_error
- if not self.PCA_done: # FIXME
+ if not self.PCA_done:
self.PCA()
error = 0
@@ -333,9 +333,10 @@ class Node:
"""
partition = self.colours.partition(axis)
- if len(partition[0]) < 3 or len(partition[1]) < 3:
+ if (len(partition[0]) < self.clustering.min_cluster_size
+ or len(partition[1]) < self.clustering.min_cluster_size):
raise ClusteringError(
- 'partition created parts that are too small for PCA')
+ 'partition created parts smaller than min_cluster_size')
lesser = Node(self.clustering, partition[0])
lesser.PCA()
@@ -356,24 +357,28 @@ class Node:
print('%s: already optimised' % self)
return self.best_partition
+ error = self.reconstruction_error()
+
best_error = None
bar = tqdm.trange(2 * len(self.colours), leave=False)
for direction in [0, 1]:
for i in range(len(self.colours)):
+ bar.update()
+
origin = self.colours.xy[i, direction]
axis = PartitionAxis(origin, direction)
try:
- error, partition = self.partition_error(axis)
+ new_error, partition = self.partition_error(axis)
except ClusteringError:
continue
- if best_error is None or error < best_error:
- self.best_partition = (error, axis, partition)
+ if new_error >= error:
+ continue
- delta = error - self.reconstruction_error()
- bar.update()
+ if best_error is None or new_error < best_error:
+ self.best_partition = (new_error, axis, partition)
bar.close()
@@ -407,6 +412,29 @@ class Node:
self._plot_colours([0])
plt.legend()
+ def visualise_pca(self):
+ """
+ Makes a plot showing the principal components of this node and how
+ well they reconstruct the source data.
+ """
+
+ plt.subplot(2, 1, 1)
+ plt.title(str(self) + ': principal components')
+ for i in range(3):
+ plt.plot(self.wl, self.basis_functions[i, :], label='PC%d' % i)
+ plt.legend()
+
+ plt.subplot(2, 1, 2)
+ plt.title(str(self) + ': data')
+ for i in range(3):
+ plt.plot(self.wl, self.colours.reflectances[i, :], 'C%d:' % i)
+
+ XYZ = self.colours.XYZ[i, :]
+ recon = self.reconstruct(XYZ)
+ plt.plot(self.wl, recon.values, 'C%d-' % i)
+ recon = self.reconstruct(XYZ)
+ plt.plot(self.wl, recon.values, 'C%d--' % i)
+
class Clustering:
"""
@@ -505,7 +533,7 @@ class Clustering:
node = search(self.root)
return node.reconstruct(XYZ)
- def optimise(self, repeats):
+ def optimise(self, repeats=8, min_cluster_size=None):
"""
Optimise the tree by repeatedly performing optimal partitions of the
nodes, creating a clustering that minimizes the total reconstruction
@@ -513,18 +541,35 @@ class Clustering:
Parameters
----------
- repeats : int
+ repeats : int, optional
Maximum number of splits. If the dataset is too small, this number
- might not be reached.
+ might not be reached. The default is to create 8 clusters, like in
+ the original paper.
+ min_cluster_size : int, optional
+ Smallest acceptable cluster size. By default it's chosen
+ automatically, based on the size of the dataset and desired number
+ of clusters. Must be at least 3 or principal component analysis
+ will not be possible.
"""
+ if min_cluster_size is not None:
+ self.min_cluster_size = min_cluster_size
+ else:
+ self.min_cluster_size = len(self.root.colours) / repeats // 2
+
+ if self.min_cluster_size <= 3:
+ self.min_cluster_size = 3
+
t0 = time.time()
def elapsed():
delta = time.time() - t0
return '%dm%.3fs' % (delta // 60, delta % 60)
+ initial_error = self.root.total_reconstruction_error()
+ print('Initial error is %g.' % initial_error)
+
for repeat in range(repeats):
- print('=== Iteration %d of %d ===' % (repeat + 1, repeats))
+ print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats))
best_total_error = None
total_error = self.root.total_reconstruction_error()
@@ -534,8 +579,8 @@ class Clustering:
try:
error, axis, partition = leaf.find_best_partition()
- except ClusteringError:
- print('...no partitions are possible.')
+ except ClusteringError as e:
+ print('Failed: %s' % e)
continue
new_total_error = (total_error - leaf.reconstruction_error()
@@ -548,11 +593,17 @@ class Clustering:
best_partition = partition
if best_total_error is None:
- print('WARNING: only %d splits were possible.' % repeat)
+ print('\nNo further improvements are possible.\n'
+ 'Terminating at iteration %d.\n' % repeat)
break
- print('\nSplit %s into %s and %s along %s.\n'
+ print('\nSplit %s into %s and %s along %s.'
% (best_leaf, *best_partition, best_axis))
+ print('Error is reduced by %g and is now %g, '
+ '%.1f%% of the initial error.' % (leaf.reconstruction_error()
+ - error, best_total_error, 100 * best_total_error
+ / initial_error))
+
best_leaf.split(best_partition, best_axis)
print('Finished in %s.' % elapsed())