summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaweł Redman <pawel.redman@gmail.com>2020-07-21 12:58:54 +0200
committerPaweł Redman <pawel.redman@gmail.com>2020-07-21 12:58:54 +0200
commitc5c6b4f2f2640319a51924349055678e2b924514 (patch)
treeda7c77caede944b8287646a138523010419e10dc
parent7630de3416f5e8a9b44def854600f61483388d1b (diff)
Major refactoring.
Better and more consistent naming. Some functions that were used only once were factored in.
-rw-r--r--demo.py4
-rw-r--r--otsu2018.py241
2 files changed, 122 insertions, 123 deletions
diff --git a/demo.py b/demo.py
index a15fa72..86b1c61 100644
--- a/demo.py
+++ b/demo.py
@@ -8,7 +8,7 @@ from otsu2018 import load_Otsu2018_spectra, Clustering
if __name__ == '__main__':
print('Loading spectral data...')
- sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=10)
+ sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50)
shape = SpectralShape(380, 730, 10)
print('Initializing the clustering...')
@@ -16,7 +16,7 @@ if __name__ == '__main__':
print('Clustering...')
before = clustering.root.total_reconstruction_error()
- clustering.do_best_splits(8)
+ clustering.optimise(2)
after = clustering.root.total_reconstruction_error()
print('Error before: %g' % before)
diff --git a/otsu2018.py b/otsu2018.py
index 713da1d..7d3c1eb 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -36,6 +36,28 @@ def load_Otsu2018_spectra(path, every_nth=1):
return np.array(spectra)
+class PartitionAxis:
+ """
+ Represents a horizontal or vertical line, partitioning the 2D space in
+ two half-planes.
+
+ Attributes
+ ----------
+ origin : float
+ The x coordinate of a vertical line or the y coordinate of a horizontal
+ line.
+ direction : int
+ '0' if vertical, '1' if horizontal.
+ """
+
+ def __init__(self, origin, direction):
+ self.origin = origin
+ self.direction = direction
+
+ def __str__(self):
+ return '%s = %g' % ('yx'[self.direction], self.origin)
+
+
class Colours:
"""
Represents multiple colours: their reflectances, XYZ tristimulus values
@@ -75,14 +97,12 @@ class Colours:
"""
return self.reflectances.shape[0]
- def partition(self, x_or_y, axis):
+ def partition(self, axis):
"""
Parameters
==========
- x_or_y : int
- Whether to split according to X or Y coordinates.
- axis : float
- The coordinate that defines where the split happens.
+ axis : PartitionAxis
+ Defines the partition axis.
Returns
=======
@@ -91,7 +111,7 @@ class Colours:
greater : Colours
The right or upper part.
"""
- mask = self.xy[:, x_or_y] <= axis
+ mask = self.xy[:, axis.direction] <= axis.origin
lesser = object.__new__(Colours)
greater = object.__new__(Colours)
@@ -139,6 +159,7 @@ class Node:
self._cached_reconstruction_error = None
self.PCA_done = False
+ self.best_partition = None
# This is just for __str__ and plots
self.number = Node._counter
@@ -156,7 +177,7 @@ class Node:
return self.children is None
- def split(self, children, split_x_or_y, split_i):
+ def split(self, children, partition_axis):
"""
Turns a leaf into a node with the given children.
@@ -164,21 +185,15 @@ class Node:
==========
children : tuple
Two instances of ``Node`` in a tuple.
- split_x_or_y : int
- Split's ``x_or_y``.
- split_axis : float
- Split's ``axis``.
+ partition_axis : PartitionAxis
+ Defines the partition axis.
"""
- if not self.leaf:
- raise RuntimeError(
- 'Node.split called for a node that is not a leaf')
-
self.children = children
- self.split_x_or_y = split_x_or_y
- self.split_axis = self.colours.xy[split_i, split_x_or_y]
+ self.partition_axis = partition_axis
self.colours = None
self._cached_reconstruction_error = None
+ self.best_partition = None
def _leaves_generator(self):
if self.leaf:
@@ -212,7 +227,7 @@ class Node:
cov = np.cov(data.T) / data.shape[0]
v, w = np.linalg.eig(cov)
idx = v.argsort()[::-1]
- w = w[:,idx]
+ w = w[:, idx]
self.basis_functions = np.real(w[:, :3].T)
# TODO: better names
@@ -247,6 +262,10 @@ class Node:
reflectance = np.clip(reflectance, 0, 1)
return SpectralDistribution(reflectance, self.clustering.wl)
+ #
+ # Optimisation
+ #
+
def reconstruction_error(self):
"""
For every colour in this node, its spectrum is reconstructed (using
@@ -292,48 +311,15 @@ class Node:
return sum([child.total_reconstruction_error()
for child in self.children])
- def partition(self, x_or_y, i):
- """
- Splits this node into two and returns them. This operation does not
- affect the node it's used on. ``Node.split`` has to be called (with
- data returned from this method) to actually alter the tree.
-
- Parameters
- ==========
- x_or_y : int
- Whether to split according to X or Y coordinates.
- i : int
- The index of the colour whose coordinates determine where the
- split happens. Cannot be ``len(Node.colours)`` or greater.
-
- Returns
- =======
- lesser : Node
- The left or lower part.
- greater : Node
- The right or upper part.
- """
-
- axis = self.colours.xy[i, x_or_y]
- partition = self.colours.partition(x_or_y, axis)
-
- if len(partition[0]) < 3 or len(partition[1]) < 3:
- raise ClusteringError('partition created parts that are too small')
-
- lesser = Node(self.clustering, partition[0])
- greater = Node(self.clustering, partition[1])
- return lesser, greater
-
- def split_quality(self, x_or_y, i):
+ def partition_error(self, axis):
"""
+ Compute the sum of reconstruction errors of the two nodes created by
+ a given partition of this node.
Parameters
==========
- x_or_y : int
- Whether to split according to X or Y coordinates.
- i : int
- Index of the colour whose coordinates determine where the
- split happens. Cannot be ``len(Node.colours)`` or greater.
+ axis : PartitionAxis
+ Defines the partition axis.
Returns
=======
@@ -343,14 +329,53 @@ class Node:
lesser, greater : tuple
Subnodes created from splitting.
"""
- lesser, greater = self.partition(x_or_y, i)
+ partition = self.colours.partition(axis)
+
+ if len(partition[0]) < 3 or len(partition[1]) < 3:
+ raise ClusteringError(
+ 'partition created parts that are too small for PCA')
+ lesser = Node(self.clustering, partition[0])
lesser.PCA()
+
+ greater = Node(self.clustering, partition[1])
greater.PCA()
error = lesser.reconstruction_error() + greater.reconstruction_error()
return error, (lesser, greater)
+ def find_best_partition(self):
+ """
+ Finds the best partition of this node. See
+ ``Clustering.find_best_partition``.
+ """
+
+ if self.best_partition is not None:
+ return self.best_partition
+
+ best_error = None
+
+ for direction in [0, 1]:
+ for i in range(len(self.colours)):
+ origin = self.colours.xy[i, direction]
+ axis = PartitionAxis(origin, direction)
+
+ try:
+ error, partition = self.partition_error(axis)
+ except ClusteringError:
+ continue
+
+ if best_error is None or error < best_error:
+ self.best_partition = (error, axis, partition)
+
+ delta = error - self.reconstruction_error()
+ print('%10s %3d %10s %g' % (self, i, axis, delta))
+
+ if self.best_partition is None:
+ raise ClusteringError('no partitions are possible')
+
+ return self.best_partition
+
#
# Plotting
#
@@ -466,7 +491,7 @@ class Clustering:
if node.leaf:
return node
- if xy[node.split_x_or_y] <= node.split_a1xis:
+ if xy[node.partition_axis.direction] <= node.partition_axis.origin:
return search(node.children[0])
else:
return search(node.children[1])
@@ -474,75 +499,44 @@ class Clustering:
node = search(self.root)
return node.reconstruct(XYZ)
- def find_best_split(self):
+ def optimise(self, repeats):
"""
- Check every possible split in the entire tree to find the one that will
- reduce the error the most.
-
- Returns
- -------
- best_split : (Node, int, int)
- Tuple representing the best split found. It contains the ``Node``
- that should be split, the split direction (``x_or_y``) and
- the colour index (``i``).
- best_partition : (Node, Node)
- Subnodes to be used as children for the leaf.
-
- Raises
- ------
- ClusteringError
- If the tree has already been split too finely, further splits will
- not be possible and this exception will be raised.
- """
-
- best_new_error = None
- total_error = self.root.total_reconstruction_error()
-
- for leaf in self.root.leaves:
- total_error_minus_leaf = total_error - leaf.reconstruction_error()
-
- for x_or_y in [0, 1]:
- for i in range(len(leaf.colours)):
- try:
- split_error, partition = leaf.split_quality(x_or_y, i)
- except ClusteringError:
- continue
-
- new_error = total_error_minus_leaf + split_error
-
- if best_new_error is None or new_error < best_new_error:
- best_new_error = new_error
- best_split = (leaf, x_or_y, i)
- best_partition = partition
-
- print('%10s %s %4d %g'
- % (leaf, ['x', 'y'][x_or_y], i, new_error))
-
- if best_new_error is None:
- raise ClusteringError('no more splits were possible')
-
- return best_split, best_partition
-
- def do_best_splits(self, repeats):
- """
- Find the best split and perform it, and repeat the operation the
- specified amount of times.
+ Optimise the tree by repeatedly performing optimal partitions of the
+ nodes, creating a clustering that minimizes the total reconstruction
+ error.
Parameters
----------
repeats : int
- Number of splits.
+ Maximum number of splits. If the dataset is too small, this number
+ might not be reached.
"""
for repeat in range(repeats):
- try:
- (leaf, x_or_y, i), partition = self.find_best_split()
- except ClusteringError:
+ best_total_error = None
+ total_error = self.root.total_reconstruction_error()
+
+ for leaf in self.root.leaves:
+ try:
+ error, axis, partition = leaf.find_best_partition()
+ except ClusteringError:
+ print('%s has no partitions' % leaf)
+ continue
+
+ new_total_error = (total_error - leaf.reconstruction_error()
+ + error)
+ if (best_total_error is None
+ or new_total_error < best_total_error):
+ best_total_error = new_total_error
+ best_axis = axis
+ best_leaf = leaf
+ best_partition = partition
+
+ if best_total_error is None:
print('WARNING: only %d splits were possible' % repeat)
break
- print('==== Splitting %s, x_or_y=%d, i=%d ===='
- % (leaf, x_or_y, i))
- leaf.split(partition, x_or_y, i)
+ print('==== Splitting %s along %s ====' % (best_leaf, best_axis))
+ best_leaf.split(best_partition, best_axis)
def write_python_dataset(self, path):
"""
@@ -568,7 +562,6 @@ class Clustering:
fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n')
for i, leaf in enumerate(self.root.leaves):
- leaf._i = i # For use when writing the selection function
for line in (repr(leaf.basis_functions) + ',').splitlines():
fd.write(' %s\n' % line)
fd.write(']\n\n\n')
@@ -586,15 +579,21 @@ class Clustering:
fd.write('def select_cluster_Otsu2018(xy):\n')
fd.write(' x, y = xy\n\n')
+ counter = 0
+
def write_if(node, indent):
+ nonlocal counter
+
if node.leaf:
fd.write(' ' * indent)
- fd.write('return %d # %s\n' % (node._i, node))
+ fd.write('return %d # %s\n' % (counter, node))
+ counter += 1
return
fd.write(' ' * indent)
- fd.write('if %s <= %s:\n' % (['x', 'y'][node.split_x_or_y],
- repr(node.split_axis)))
+ fd.write('if %s <= %s:\n'
+ % ('xy'[node.partition_axis.direction],
+ repr(node.partition_axis.origin)))
write_if(node.children[0], indent + 1)
fd.write(' ' * indent)