diff options
| -rw-r--r-- | demo.py | 25 | ||||
| -rw-r--r-- | otsu2018.py | 146 | 
2 files changed, 76 insertions, 95 deletions
@@ -1,34 +1,33 @@  import os  import matplotlib.pyplot as plt -from colour import ( -    SpectralShape, COLOURCHECKER_SDS, ILLUMINANT_SDS, sd_to_XYZ) +from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ -from otsu2018 import load_Otsu2018_spectra, Clustering +from otsu2018 import load_Otsu2018_spectra, Tree  if __name__ == '__main__':      print('Loading spectral data...') -    sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=1) +    sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50)      shape = SpectralShape(380, 730, 10) -    print('Initializing the clustering...') -    clustering = Clustering(sds, shape) +    print('Initializing the tree...') +    tree = Tree(sds, shape) -    print('Clustering...') -    before = clustering.root.total_reconstruction_error() -    clustering.optimise() -    after = clustering.root.total_reconstruction_error() +    print('Tree...') +    before = tree.total_reconstruction_error() +    tree.optimise() +    after = tree.total_reconstruction_error()      print('Error before: %g' % before)      print('Error after:  %g' % after)      print('Saving the dataset...')      os.makedirs('datasets', exist_ok=True) -    clustering.write_python_dataset('datasets/otsu2018.py') +    tree.write_python_dataset('datasets/otsu2018.py')      print('Plotting...') -    clustering.root.visualise() +    tree.visualise()      plt.figure() @@ -40,7 +39,7 @@ if __name__ == '__main__':          plt.plot(sd.wavelengths, sd.values, label='Original')          XYZ = sd_to_XYZ(sd) / 100 -        recovered_sd = clustering.reconstruct(XYZ) +        recovered_sd = tree.reconstruct(XYZ)          plt.plot(recovered_sd.wavelengths, recovered_sd.values,                   label='Recovered') diff --git a/otsu2018.py b/otsu2018.py index b17c518..f573b17 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -64,22 +64,22 @@ class Colours:      """      Represents multiple colours: their reflectances, XYZ tristimulus values      and xy coordinates. The cmfs and the illuminant are taken from the parent -    Clustering. +    tree.      This class also supports partitioning, or creating two smaller instances      of Colours, split along a horizontal or a vertical axis on the xy plane.      """ -    def __init__(self, clustering, reflectances): +    def __init__(self, tree, reflectances):          """          Parameters          ========== -        clustering : Clustering -            The parent clustering. This determines what cmfs and illuminant +        tree : tree +            The parent tree. This determines what cmfs and illuminant              are used in colourimetric calculations.          reflectances : ndarray (n,m)              Reflectances of the ``n`` colours to be stored in this class. -            The shape must match ``Clustering.shape`` with ``m`` points for +            The shape must match ``tree.shape`` with ``m`` points for              each colour.          """ @@ -88,8 +88,8 @@ class Colours:          self.xy = np.empty((reflectances.shape[0], 2))          for i in range(len(self)): -            sd = SpectralDistribution(reflectances[i, :], clustering.wl) -            XYZ = sd_to_XYZ(sd, illuminant=clustering.illuminant) / 100 +            sd = SpectralDistribution(reflectances[i, :], tree.wl) +            XYZ = sd_to_XYZ(sd, illuminant=tree.illuminant) / 100              self.XYZ[i, :] = XYZ              self.xy[i, :] = XYZ_to_xy(XYZ) @@ -130,7 +130,7 @@ class Colours:          return lesser, greater -class ClusteringError(Exception): +class treeError(Exception):      """      Exception used for various errors originating from code in this file.      """ @@ -139,23 +139,23 @@ class ClusteringError(Exception):  class Node:      """ -    Represents a node in the clustering tree. +    Represents a node in the tree tree.      """      _counter = 1 -    def __init__(self, clustering, colours): +    def __init__(self, tree, colours):          """          Parameters          ========== -        clustering : Clustering -            The parent clustering. This determines what cmfs and illuminant +        tree : tree +            The parent tree. This determines what cmfs and illuminant              are used in colourimetric calculations.          colours : Colours              The colours that belong in this node.          """ -        self.clustering = clustering +        self.tree = tree          self.colours = colours          self.children = None @@ -233,16 +233,30 @@ class Node:          M = np.empty((3, 3))          for i in range(3):              R = self.basis_functions[i, :] -            M[:, i] = self.clustering.fast_sd_to_XYZ(R) +            M[:, i] = self.tree.fast_sd_to_XYZ(R)          self.M_inverse = np.linalg.inv(M) -        self.XYZ_mu = self.clustering.fast_sd_to_XYZ(self.mean) +        self.XYZ_mu = self.tree.fast_sd_to_XYZ(self.mean)          self.PCA_done = True +    def _reconstruct_xy(self, XYZ, xy): +        if not self.leaf: +            if xy[self.partition_axis.direction] <= self.partition_axis.origin: +                return self.children[0]._reconstruct_xy(XYZ, xy) +            else: +                return self.children[1]._reconstruct_xy(XYZ, xy) + +        weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu) +        reflectance = np.dot(weights, self.basis_functions) + self.mean +        reflectance = np.clip(reflectance, 0, 1) +        return SpectralDistribution(reflectance, self.tree.wl) +      def reconstruct(self, XYZ):          """ -        Reconstructs a reflectance using data stored in this node. +        Reconstructs the reflectance for the given *XYZ* tristimulus values. +        If this is a leaf, data from this node will be used. Otherwise the +        code will look for the appropriate subnode.          Parameters          ========== @@ -256,10 +270,8 @@ class Node:              Recovered spectral distribution.          """ -        weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu) -        reflectance = np.dot(weights, self.basis_functions) + self.mean -        reflectance = np.clip(reflectance, 0, 1) -        return SpectralDistribution(reflectance, self.clustering.wl) +        xy = XYZ_to_xy(XYZ) +        return self._reconstruct_xy(XYZ, xy)      #      # Optimisation @@ -330,15 +342,15 @@ class Node:          """          partition = self.colours.partition(axis) -        if (len(partition[0]) < self.clustering.min_cluster_size -                or len(partition[1]) < self.clustering.min_cluster_size): -            raise ClusteringError( +        if (len(partition[0]) < self.tree.min_cluster_size +                or len(partition[1]) < self.tree.min_cluster_size): +            raise treeError(                  'partition created parts smaller than min_cluster_size') -        lesser = Node(self.clustering, partition[0]) +        lesser = Node(self.tree, partition[0])          lesser.PCA() -        greater = Node(self.clustering, partition[1]) +        greater = Node(self.tree, partition[1])          greater.PCA()          error = lesser.reconstruction_error() + greater.reconstruction_error() @@ -347,7 +359,7 @@ class Node:      def find_best_partition(self):          """          Finds the best partition of this node. See -        ``Clustering.find_best_partition``. +        ``tree.find_best_partition``.          """          if self.best_partition is not None: @@ -368,7 +380,7 @@ class Node:                  try:                      new_error, partition = self.partition_error(axis) -                except ClusteringError: +                except treeError:                      continue                  if new_error >= error: @@ -380,7 +392,7 @@ class Node:          bar.close()          if self.best_partition is None: -            raise ClusteringError('no partitions are possible') +            raise treeError('no partitions are possible')          return self.best_partition @@ -414,7 +426,7 @@ class Node:          Makes a plot showing the principal components of this node and how          well they reconstruct the source data.          """ -         +          plt.subplot(2, 1, 1)          plt.title(str(self) + ': principal components')          for i in range(3): @@ -433,13 +445,14 @@ class Node:              plt.plot(self.wl, recon.values, 'C%d--' % i) -class Clustering: +class Tree(Node):      """ -    Represents the process of clustering and optimisation. Instances store -    shared data such as cmfs and the illuminant. +    This is an extension of ``Node``. It's meant to represent the root of the +    tree and contains information shared with all the nodes, such as cmfs +    and the illuminant (if any is used).      Operations involving the entire tree, such as optimisation and -    reconstruction, are implemented here. +    reconstruction, are also implemented here.      """      def __init__( @@ -462,7 +475,6 @@ class Clustering:              Illuminant spectral distribution.          """ -        self.sds = sds          self.shape = shape          self.wl = shape.range()          self.dw = self.wl[1] - self.wl[0] @@ -474,13 +486,12 @@ class Clustering:          self.k = 1 / (np.sum(self.cmfs.values[:, 1]                        * self.illuminant.values) * self.dw) -        colours = Colours(self, sds) -        self.root = Node(self, colours) +        super().__init__(self, Colours(self, sds))      def fast_sd_to_XYZ(self, R):          """          Compute the XYZ tristimulus values of a given reflectance. Faster for -        humans, by using cmfs and the illuminant stored in the ''Clustering'', +        humans, by using cmfs and the illuminant stored in the ''tree'',          thus avoiding unnecessary repetition. Faster for computers, by using          a very simple and direct method. @@ -488,7 +499,7 @@ class Clustering:          ----------          R : ndarray              Reflectance with shape matching the one used to construct this -            ``Clustering``. +            ``tree``.          Returns          ------- @@ -499,41 +510,10 @@ class Clustering:          E = self.illuminant.values * R          return self.k * np.dot(E, self.cmfs.values) * self.dw -    def reconstruct(self, XYZ): -        """ -        Finds the appropriate node and reconstructs the reflectance for the -        given XYZ tristimulus values. - -        Parameters -        ========== -        XYZ : ndarray, (3,) -            *CIE XYZ* tristimulus values to recover the spectral distribution -            from. - -        Returns -        ------- -        SpectralDistribution -            Recovered spectral distribution. -        """ - -        xy = XYZ_to_xy(XYZ) - -        def search(node): -            if node.leaf: -                return node - -            if xy[node.partition_axis.direction] <= node.partition_axis.origin: -                return search(node.children[0]) -            else: -                return search(node.children[1]) - -        node = search(self.root) -        return node.reconstruct(XYZ) -      def optimise(self, repeats=8, min_cluster_size=None):          """          Optimise the tree by repeatedly performing optimal partitions of the -        nodes, creating a clustering that minimizes the total reconstruction +        nodes, creating a tree that minimizes the total reconstruction          error.          Parameters @@ -552,31 +532,32 @@ class Clustering:          if min_cluster_size is not None:              self.min_cluster_size = min_cluster_size          else: -            self.min_cluster_size = len(self.root.colours) / repeats // 2 +            self.min_cluster_size = len(self.colours) / repeats // 2          if self.min_cluster_size <= 3:              self.min_cluster_size = 3          t0 = time.time() +          def elapsed():              delta = time.time() - t0              return '%dm%.3fs' % (delta // 60, delta % 60) -        initial_error = self.root.total_reconstruction_error() +        initial_error = self.total_reconstruction_error()          print('Initial error is %g.' % initial_error)          for repeat in range(repeats):              print('\n=== Iteration %d of %d ===' % (repeat + 1, repeats))              best_total_error = None -            total_error = self.root.total_reconstruction_error() +            total_error = self.total_reconstruction_error() -            for i, leaf in enumerate(self.root.leaves): +            for i, leaf in enumerate(self.leaves):                  print('(%s) Optimising %s...' % (elapsed(), leaf))                  try:                      error, axis, partition = leaf.find_best_partition() -                except ClusteringError as e: +                except treeError as e:                      print('Failed: %s' % e)                      continue @@ -597,9 +578,10 @@ class Clustering:              print('\nSplit %s into %s and %s along %s.'                    % (best_leaf, *best_partition, best_axis))              print('Error is reduced by %g and is now %g, ' -                  '%.1f%% of the initial error.' % (leaf.reconstruction_error() -                  - error, best_total_error, 100 * best_total_error -                  / initial_error)) +                  '%.1f%% of the initial error.' +                  % (leaf.reconstruction_error() +                     - error, best_total_error, 100 * best_total_error +                     / initial_error))              best_leaf.split(best_partition, best_axis) @@ -607,7 +589,7 @@ class Clustering:      def write_python_dataset(self, path):          """ -        Write the clustering into a Python dataset compatible with Colour's +        Write the tree into a Python dataset compatible with Colour's          ``colour.recovery.otsu2018`` code.          Parameters @@ -628,7 +610,7 @@ class Clustering:              # Basis functions              fd.write('OTSU_2018_BASIS_FUNCTIONS = [\n') -            for i, leaf in enumerate(self.root.leaves): +            for i, leaf in enumerate(self.leaves):                  for line in (repr(leaf.basis_functions) + ',').splitlines():                      fd.write('    %s\n' % line)              fd.write(']\n\n\n') @@ -636,7 +618,7 @@ class Clustering:              # Means              fd.write('OTSU_2018_MEANS = [\n') -            for leaf in self.root.leaves: +            for leaf in self.leaves:                  for line in (repr(leaf.mean) + ',').splitlines():                      fd.write('    %s\n' % line)              fd.write(']\n\n\n') @@ -667,4 +649,4 @@ class Clustering:                  fd.write('else:\n')                  write_if(node.children[1], indent + 1) -            write_if(self.root, 1) +            write_if(self, 1)  | 
