diff options
-rw-r--r-- | demo.py | 4 | ||||
-rw-r--r-- | otsu2018.py | 241 |
2 files changed, 122 insertions, 123 deletions
@@ -8,7 +8,7 @@ from otsu2018 import load_Otsu2018_spectra, Clustering if __name__ == '__main__': print('Loading spectral data...') - sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=10) + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) shape = SpectralShape(380, 730, 10) print('Initializing the clustering...') @@ -16,7 +16,7 @@ if __name__ == '__main__': print('Clustering...') before = clustering.root.total_reconstruction_error() - clustering.do_best_splits(8) + clustering.optimise(2) after = clustering.root.total_reconstruction_error() print('Error before: %g' % before) diff --git a/otsu2018.py b/otsu2018.py index 713da1d..7d3c1eb 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -36,6 +36,28 @@ def load_Otsu2018_spectra(path, every_nth=1): return np.array(spectra) +class PartitionAxis: + """ + Represents a horizontal or vertical line, partitioning the 2D space in + two half-planes. + + Attributes + ---------- + origin : float + The x coordinate of a vertical line or the y coordinate of a horizontal + line. + direction : int + '0' if vertical, '1' if horizontal. + """ + + def __init__(self, origin, direction): + self.origin = origin + self.direction = direction + + def __str__(self): + return '%s = %g' % ('yx'[self.direction], self.origin) + + class Colours: """ Represents multiple colours: their reflectances, XYZ tristimulus values @@ -75,14 +97,12 @@ class Colours: """ return self.reflectances.shape[0] - def partition(self, x_or_y, axis): + def partition(self, axis): """ Parameters ========== - x_or_y : int - Whether to split according to X or Y coordinates. - axis : float - The coordinate that defines where the split happens. + axis : PartitionAxis + Defines the partition axis. Returns ======= @@ -91,7 +111,7 @@ class Colours: greater : Colours The right or upper part. """ - mask = self.xy[:, x_or_y] <= axis + mask = self.xy[:, axis.direction] <= axis.origin lesser = object.__new__(Colours) greater = object.__new__(Colours) @@ -139,6 +159,7 @@ class Node: self._cached_reconstruction_error = None self.PCA_done = False + self.best_partition = None # This is just for __str__ and plots self.number = Node._counter @@ -156,7 +177,7 @@ class Node: return self.children is None - def split(self, children, split_x_or_y, split_i): + def split(self, children, partition_axis): """ Turns a leaf into a node with the given children. @@ -164,21 +185,15 @@ class Node: ========== children : tuple Two instances of ``Node`` in a tuple. - split_x_or_y : int - Split's ``x_or_y``. - split_axis : float - Split's ``axis``. + partition_axis : PartitionAxis + Defines the partition axis. """ - if not self.leaf: - raise RuntimeError( - 'Node.split called for a node that is not a leaf') - self.children = children - self.split_x_or_y = split_x_or_y - self.split_axis = self.colours.xy[split_i, split_x_or_y] + self.partition_axis = partition_axis self.colours = None self._cached_reconstruction_error = None + self.best_partition = None def _leaves_generator(self): if self.leaf: @@ -212,7 +227,7 @@ class Node: cov = np.cov(data.T) / data.shape[0] v, w = np.linalg.eig(cov) idx = v.argsort()[::-1] - w = w[:,idx] + w = w[:, idx] self.basis_functions = np.real(w[:, :3].T) # TODO: better names @@ -247,6 +262,10 @@ class Node: reflectance = np.clip(reflectance, 0, 1) return SpectralDistribution(reflectance, self.clustering.wl) + # + # Optimisation + # + def reconstruction_error(self): """ For every colour in this node, its spectrum is reconstructed (using @@ -292,48 +311,15 @@ class Node: return sum([child.total_reconstruction_error() for child in self.children]) - def partition(self, x_or_y, i): - """ - Splits this node into two and returns them. This operation does not - affect the node it's used on. ``Node.split`` has to be called (with - data returned from this method) to actually alter the tree. - - Parameters - ========== - x_or_y : int - Whether to split according to X or Y coordinates. - i : int - The index of the colour whose coordinates determine where the - split happens. Cannot be ``len(Node.colours)`` or greater. - - Returns - ======= - lesser : Node - The left or lower part. - greater : Node - The right or upper part. - """ - - axis = self.colours.xy[i, x_or_y] - partition = self.colours.partition(x_or_y, axis) - - if len(partition[0]) < 3 or len(partition[1]) < 3: - raise ClusteringError('partition created parts that are too small') - - lesser = Node(self.clustering, partition[0]) - greater = Node(self.clustering, partition[1]) - return lesser, greater - - def split_quality(self, x_or_y, i): + def partition_error(self, axis): """ + Compute the sum of reconstruction errors of the two nodes created by + a given partition of this node. Parameters ========== - x_or_y : int - Whether to split according to X or Y coordinates. - i : int - Index of the colour whose coordinates determine where the - split happens. Cannot be ``len(Node.colours)`` or greater. + axis : PartitionAxis + Defines the partition axis. Returns ======= @@ -343,14 +329,53 @@ class Node: lesser, greater : tuple Subnodes created from splitting. """ - lesser, greater = self.partition(x_or_y, i) + partition = self.colours.partition(axis) + + if len(partition[0]) < 3 or len(partition[1]) < 3: + raise ClusteringError( + 'partition created parts that are too small for PCA') + lesser = Node(self.clustering, partition[0]) lesser.PCA() + + greater = Node(self.clustering, partition[1]) greater.PCA() error = lesser.reconstruction_error() + greater.reconstruction_error() return error, (lesser, greater) + def find_best_partition(self): + """ + Finds the best partition of this node. See + ``Clustering.find_best_partition``. + """ + + if self.best_partition is not None: + return self.best_partition + + best_error = None + + for direction in [0, 1]: + for i in range(len(self.colours)): + origin = self.colours.xy[i, direction] + axis = PartitionAxis(origin, direction) + + try: + error, partition = self.partition_error(axis) + except ClusteringError: + continue + + if best_error is None or error < best_error: + self.best_partition = (error, axis, partition) + + delta = error - self.reconstruction_error() + print('%10s %3d %10s %g' % (self, i, axis, delta)) + + if self.best_partition is None: + raise ClusteringError('no partitions are possible') + + return self.best_partition + # # Plotting # @@ -466,7 +491,7 @@ class Clustering: if node.leaf: return node - if xy[node.split_x_or_y] <= node.split_a1xis: + if xy[node.partition_axis.direction] <= node.partition_axis.origin: return search(node.children[0]) else: return search(node.children[1]) @@ -474,75 +499,44 @@ class Clustering: node = search(self.root) return node.reconstruct(XYZ) - def find_best_split(self): + def optimise(self, repeats): """ - Check every possible split in the entire tree to find the one that will - reduce the error the most. - - Returns - ------- - best_split : (Node, int, int) - Tuple representing the best split found. It contains the ``Node`` - that should be split, the split direction (``x_or_y``) and - the colour index (``i``). - best_partition : (Node, Node) - Subnodes to be used as children for the leaf. - - Raises - ------ - ClusteringError - If the tree has already been split too finely, further splits will - not be possible and this exception will be raised. - """ - - best_new_error = None - total_error = self.root.total_reconstruction_error() - - for leaf in self.root.leaves: - total_error_minus_leaf = total_error - leaf.reconstruction_error() - - for x_or_y in [0, 1]: - for i in range(len(leaf.colours)): - try: - split_error, partition = leaf.split_quality(x_or_y, i) - except ClusteringError: - continue - - new_error = total_error_minus_leaf + split_error - - if best_new_error is None or new_error < best_new_error: - best_new_error = new_error - best_split = (leaf, x_or_y, i) - best_partition = partition - - print('%10s %s %4d %g' - % (leaf, ['x', 'y'][x_or_y], i, new_error)) - - if best_new_error is None: - raise ClusteringError('no more splits were possible') - - return best_split, best_partition - - def do_best_splits(self, repeats): - """ - Find the best split and perform it, and repeat the operation the - specified amount of times. + Optimise the tree by repeatedly performing optimal partitions of the + nodes, creating a clustering that minimizes the total reconstruction + error. Parameters ---------- repeats : int - Number of splits. + Maximum number of splits. If the dataset is too small, this number + might not be reached. """ for repeat in range(repeats): - try: - (leaf, x_or_y, i), partition = self.find_best_split() - except ClusteringError: + best_total_error = None + total_error = self.root.total_reconstruction_error() + + for leaf in self.root.leaves: + try: + error, axis, partition = leaf.find_best_partition() + except ClusteringError: + print('%s has no partitions' % leaf) + continue + + new_total_error = (total_error - leaf.reconstruction_error() + + error) + if (best_total_error is None + or new_total_error < best_total_error): + best_total_error = new_total_error + best_axis = axis + best_leaf = leaf + best_partition = partition + + if best_total_error is None: print('WARNING: only %d splits were possible' % repeat) break - print('==== Splitting %s, x_or_y=%d, i=%d ====' - % (leaf, x_or_y, i)) - leaf.split(partition, x_or_y, i) + print('==== Splitting %s along %s ====' % (best_leaf, best_axis)) + best_leaf.split(best_partition, best_axis) def write_python_dataset(self, path): """ @@ -568,7 +562,6 @@ class Clustering: fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n') for i, leaf in enumerate(self.root.leaves): - leaf._i = i # For use when writing the selection function for line in (repr(leaf.basis_functions) + ',').splitlines(): fd.write(' %s\n' % line) fd.write(']\n\n\n') @@ -586,15 +579,21 @@ class Clustering: fd.write('def select_cluster_Otsu2018(xy):\n') fd.write(' x, y = xy\n\n') + counter = 0 + def write_if(node, indent): + nonlocal counter + if node.leaf: fd.write(' ' * indent) - fd.write('return %d # %s\n' % (node._i, node)) + fd.write('return %d # %s\n' % (counter, node)) + counter += 1 return fd.write(' ' * indent) - fd.write('if %s <= %s:\n' % (['x', 'y'][node.split_x_or_y], - repr(node.split_axis))) + fd.write('if %s <= %s:\n' + % ('xy'[node.partition_axis.direction], + repr(node.partition_axis.origin))) write_if(node.children[0], indent + 1) fd.write(' ' * indent) |