summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaweł Redman <pawel.redman@gmail.com>2020-07-25 15:53:29 +0200
committerPaweł Redman <pawel.redman@gmail.com>2020-07-25 16:42:01 +0200
commit26c2bcbb77f9d2914be35e74e14333bc3e729c3e (patch)
tree00b2bb3104bd8cf11dc185d9a17c26348610290f
parent9c0dfa59598968ab5be656aac5c4f6f79c50dd61 (diff)
Rename Clustering to Tree and make it a subclass of Node.
-rw-r--r--demo.py25
-rw-r--r--otsu2018.py146
2 files changed, 76 insertions, 95 deletions
diff --git a/demo.py b/demo.py
index eb02202..3c4b933 100644
--- a/demo.py
+++ b/demo.py
@@ -1,34 +1,33 @@
import os
import matplotlib.pyplot as plt
-from colour import (
- SpectralShape, COLOURCHECKER_SDS, ILLUMINANT_SDS, sd_to_XYZ)
+from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ
-from otsu2018 import load_Otsu2018_spectra, Clustering
+from otsu2018 import load_Otsu2018_spectra, Tree
if __name__ == '__main__':
print('Loading spectral data...')
- sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=1)
+ sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50)
shape = SpectralShape(380, 730, 10)
- print('Initializing the clustering...')
- clustering = Clustering(sds, shape)
+ print('Initializing the tree...')
+ tree = Tree(sds, shape)
- print('Clustering...')
- before = clustering.root.total_reconstruction_error()
- clustering.optimise()
- after = clustering.root.total_reconstruction_error()
+ print('Tree...')
+ before = tree.total_reconstruction_error()
+ tree.optimise()
+ after = tree.total_reconstruction_error()
print('Error before: %g' % before)
print('Error after: %g' % after)
print('Saving the dataset...')
os.makedirs('datasets', exist_ok=True)
- clustering.write_python_dataset('datasets/otsu2018.py')
+ tree.write_python_dataset('datasets/otsu2018.py')
print('Plotting...')
- clustering.root.visualise()
+ tree.visualise()
plt.figure()
@@ -40,7 +39,7 @@ if __name__ == '__main__':
plt.plot(sd.wavelengths, sd.values, label='Original')
XYZ = sd_to_XYZ(sd) / 100
- recovered_sd = clustering.reconstruct(XYZ)
+ recovered_sd = tree.reconstruct(XYZ)
plt.plot(recovered_sd.wavelengths, recovered_sd.values,
label='Recovered')
diff --git a/otsu2018.py b/otsu2018.py
index b17c518..f573b17 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -64,22 +64,22 @@ class Colours:
"""
Represents multiple colours: their reflectances, XYZ tristimulus values
and xy coordinates. The cmfs and the illuminant are taken from the parent
- Clustering.
+ tree.
This class also supports partitioning, or creating two smaller instances
of Colours, split along a horizontal or a vertical axis on the xy plane.
"""
- def __init__(self, clustering, reflectances):
+ def __init__(self, tree, reflectances):
"""
Parameters
==========
- clustering : Clustering
- The parent clustering. This determines what cmfs and illuminant
+ tree : tree
+ The parent tree. This determines what cmfs and illuminant
are used in colourimetric calculations.
reflectances : ndarray (n,m)
Reflectances of the ``n`` colours to be stored in this class.
- The shape must match ``Clustering.shape`` with ``m`` points for
+ The shape must match ``tree.shape`` with ``m`` points for
each colour.
"""
@@ -88,8 +88,8 @@ class Colours:
self.xy = np.empty((reflectances.shape[0], 2))
for i in range(len(self)):
- sd = SpectralDistribution(reflectances[i, :], clustering.wl)
- XYZ = sd_to_XYZ(sd, illuminant=clustering.illuminant) / 100
+ sd = SpectralDistribution(reflectances[i, :], tree.wl)
+ XYZ = sd_to_XYZ(sd, illuminant=tree.illuminant) / 100
self.XYZ[i, :] = XYZ
self.xy[i, :] = XYZ_to_xy(XYZ)
@@ -130,7 +130,7 @@ class Colours:
return lesser, greater
-class ClusteringError(Exception):
+class treeError(Exception):
"""
Exception used for various errors originating from code in this file.
"""
@@ -139,23 +139,23 @@ class ClusteringError(Exception):
class Node:
"""
- Represents a node in the clustering tree.
+ Represents a node in the tree tree.
"""
_counter = 1
- def __init__(self, clustering, colours):
+ def __init__(self, tree, colours):
"""
Parameters
==========
- clustering : Clustering
- The parent clustering. This determines what cmfs and illuminant
+ tree : tree
+ The parent tree. This determines what cmfs and illuminant
are used in colourimetric calculations.
colours : Colours
The colours that belong in this node.
"""
- self.clustering = clustering
+ self.tree = tree
self.colours = colours
self.children = None
@@ -233,16 +233,30 @@ class Node:
M = np.empty((3, 3))
for i in range(3):
R = self.basis_functions[i, :]
- M[:, i] = self.clustering.fast_sd_to_XYZ(R)
+ M[:, i] = self.tree.fast_sd_to_XYZ(R)
self.M_inverse = np.linalg.inv(M)
- self.XYZ_mu = self.clustering.fast_sd_to_XYZ(self.mean)
+ self.XYZ_mu = self.tree.fast_sd_to_XYZ(self.mean)
self.PCA_done = True
+ def _reconstruct_xy(self, XYZ, xy):
+ if not self.leaf:
+ if xy[self.partition_axis.direction] <= self.partition_axis.origin:
+ return self.children[0]._reconstruct_xy(XYZ, xy)
+ else:
+ return self.children[1]._reconstruct_xy(XYZ, xy)
+
+ weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu)
+ reflectance = np.dot(weights, self.basis_functions) + self.mean
+ reflectance = np.clip(reflectance, 0, 1)
+ return SpectralDistribution(reflectance, self.tree.wl)
+
def reconstruct(self, XYZ):
"""
- Reconstructs a reflectance using data stored in this node.
+ Reconstructs the reflectance for the given *XYZ* tristimulus values.
+ If this is a leaf, data from this node will be used. Otherwise the
+ code will look for the appropriate subnode.
Parameters
==========
@@ -256,10 +270,8 @@ class Node:
Recovered spectral distribution.
"""
- weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu)
- reflectance = np.dot(weights, self.basis_functions) + self.mean
- reflectance = np.clip(reflectance, 0, 1)
- return SpectralDistribution(reflectance, self.clustering.wl)
+ xy = XYZ_to_xy(XYZ)
+ return self._reconstruct_xy(XYZ, xy)
#
# Optimisation
@@ -330,15 +342,15 @@ class Node:
"""
partition = self.colours.partition(axis)
- if (len(partition[0]) < self.clustering.min_cluster_size
- or len(partition[1]) < self.clustering.min_cluster_size):
- raise ClusteringError(
+ if (len(partition[0]) < self.tree.min_cluster_size
+ or len(partition[1]) < self.tree.min_cluster_size):
+ raise treeError(
'partition created parts smaller than min_cluster_size')
- lesser = Node(self.clustering, partition[0])
+ lesser = Node(self.tree, partition[0])
lesser.PCA()
- greater = Node(self.clustering, partition[1])
+ greater = Node(self.tree, partition[1])
greater.PCA()
error = lesser.reconstruction_error() + greater.reconstruction_error()
@@ -347,7 +359,7 @@ class Node:
def find_best_partition(self):
"""
Finds the best partition of this node. See
- ``Clustering.find_best_partition``.
+ ``tree.find_best_partition``.
"""
if self.best_partition is not None:
@@ -368,7 +380,7 @@ class Node:
try:
new_error, partition = self.partition_error(axis)
- except ClusteringError:
+ except treeError:
continue
if new_error >= error:
@@ -380,7 +392,7 @@ class Node:
bar.close()
if self.best_partition is None:
- raise ClusteringError('no partitions are possible')
+ raise treeError('no partitions are possible')
return self.best_partition
@@ -414,7 +426,7 @@ class Node:
Makes a plot showing the principal components of this node and how
well they reconstruct the source data.
"""
-
+
plt.subplot(2, 1, 1)
plt.title(str(self) + ': principal components')
for i in range(3):
@@ -433,13 +445,14 @@ class Node:
plt.plot(self.wl, recon.values, 'C%d--' % i)
-class Clustering:
+class Tree(Node):
"""
- Represents the process of clustering and optimisation. Instances store
- shared data such as cmfs and the illuminant.
+ This is an extension of ``Node``. It's meant to represent the root of the
+ tree and contains information shared with all the nodes, such as cmfs
+ and the illuminant (if any is used).
Operations involving the entire tree, such as optimisation and
- reconstruction, are implemented here.
+ reconstruction, are also implemented here.
"""
def __init__(
@@ -462,7 +475,6 @@ class Clustering:
Illuminant spectral distribution.
"""
- self.sds = sds
self.shape = shape
self.wl = shape.range()
self.dw = self.wl[1] - self.wl[0]
@@ -474,13 +486,12 @@ class Clustering:
self.k = 1 / (np.sum(self.cmfs.values[:, 1]
* self.illuminant.values) * self.dw)
- colours = Colours(self, sds)
- self.root = Node(self, colours)
+ super().__init__(self, Colours(self, sds))
def fast_sd_to_XYZ(self, R):
"""
Compute the XYZ tristimulus values of a given reflectance. Faster for
- humans, by using cmfs and the illuminant stored in the ''Clustering'',
+ humans, by using cmfs and the illuminant stored in the ''tree'',
thus avoiding unnecessary repetition. Faster for computers, by using
a very simple and direct method.
@@ -488,7 +499,7 @@ class Clustering:
----------
R : ndarray
Reflectance with shape matching the one used to construct this
- ``Clustering``.
+ ``tree``.
Returns
-------
@@ -499,41 +510,10 @@ class Clustering:
E = self.illuminant.values * R
return self.k * np.dot(E, self.cmfs.values) * self.dw
- def reconstruct(self, XYZ):
- """
- Finds the appropriate node and reconstructs the reflectance for the
- given XYZ tristimulus values.
-
- Parameters
- ==========
- XYZ : ndarray, (3,)
- *CIE XYZ* tristimulus values to recover the spectral distribution
- from.
-
- Returns
- -------
- SpectralDistribution
- Recovered spectral distribution.
- """
-
- xy = XYZ_to_xy(XYZ)
-
- def search(node):
- if node.leaf:
- return node
-
- if xy[node.partition_axis.direction] <= node.partition_axis.origin:
- return search(node.children[0])
- else:
- return search(node.children[1])
-
- node = search(self.root)
- return node.reconstruct(XYZ)
-
def optimise(self, repeats=8, min_cluster_size=None):
"""
Optimise the tree by repeatedly performing optimal partitions of the
- nodes, creating a clustering that minimizes the total reconstruction
+ nodes, creating a tree that minimizes the total reconstruction
error.
Parameters
@@ -552,31 +532,32 @@ class Clustering:
if min_cluster_size is not None:
self.min_cluster_size = min_cluster_size
else:
- self.min_cluster_size = len(self.root.colours) / repeats // 2
+ self.min_cluster_size = len(self.colours) / repeats // 2
if self.min_cluster_size <= 3:
self.min_cluster_size = 3
t0 = time.time()
+
def elapsed():
delta = time.time() - t0
return '%dm%.3fs' % (delta // 60, delta % 60)
- initial_error = self.root.total_reconstruction_error()
+ initial_error = self.total_reconstruction_error()
print('Initial error is %g.' % initial_error)
for repeat in range(repeats):
print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats))
best_total_error = None
- total_error = self.root.total_reconstruction_error()
+ total_error = self.total_reconstruction_error()
- for i, leaf in enumerate(self.root.leaves):
+ for i, leaf in enumerate(self.leaves):
print('(%s) Optimising %s...' % (elapsed(), leaf))
try:
error, axis, partition = leaf.find_best_partition()
- except ClusteringError as e:
+ except treeError as e:
print('Failed: %s' % e)
continue
@@ -597,9 +578,10 @@ class Clustering:
print('\nSplit %s into %s and %s along %s.'
% (best_leaf, *best_partition, best_axis))
print('Error is reduced by %g and is now %g, '
- '%.1f%% of the initial error.' % (leaf.reconstruction_error()
- - error, best_total_error, 100 * best_total_error
- / initial_error))
+ '%.1f%% of the initial error.'
+ % (leaf.reconstruction_error()
+ - error, best_total_error, 100 * best_total_error
+ / initial_error))
best_leaf.split(best_partition, best_axis)
@@ -607,7 +589,7 @@ class Clustering:
def write_python_dataset(self, path):
"""
- Write the clustering into a Python dataset compatible with Colour's
+ Write the tree into a Python dataset compatible with Colour's
``colour.recovery.otsu2018`` code.
Parameters
@@ -628,7 +610,7 @@ class Clustering:
# Basis functions
fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n')
- for i, leaf in enumerate(self.root.leaves):
+ for i, leaf in enumerate(self.leaves):
for line in (repr(leaf.basis_functions) + ',').splitlines():
fd.write(' %s\n' % line)
fd.write(']\n\n\n')
@@ -636,7 +618,7 @@ class Clustering:
# Means
fd.write('OTSU_2018_MEANS = [\n')
- for leaf in self.root.leaves:
+ for leaf in self.leaves:
for line in (repr(leaf.mean) + ',').splitlines():
fd.write(' %s\n' % line)
fd.write(']\n\n\n')
@@ -667,4 +649,4 @@ class Clustering:
fd.write('else:\n')
write_if(node.children[1], indent + 1)
- write_if(self.root, 1)
+ write_if(self, 1)