diff options
| -rw-r--r-- | clustering.py | 7 | ||||
| -rw-r--r-- | otsu2018.py | 34 | 
2 files changed, 27 insertions, 14 deletions
diff --git a/clustering.py b/clustering.py index 6df6643..d8b1ed9 100644 --- a/clustering.py +++ b/clustering.py @@ -9,7 +9,7 @@ from otsu2018 import load_Otsu2018_spectra, Otsu2018Tree  if __name__ == '__main__':      print('Loading spectral data...') -    sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=7) +    sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=100)      shape = SpectralShape(380, 730, 10)      print('Initializing the tree...') @@ -23,8 +23,9 @@ if __name__ == '__main__':      print('Error before: %g' % before)      print('Error after:  %g' % after) -    print('Saving the dataset...') -    os.makedirs('datasets', exist_ok=True) +    print('Saving the dataset...')     +    if not os.path.exists('datasets'): +        os.makedirs('datasets')      data = tree.to_dataset()      data.to_file('datasets/otsu2018.npz')      data.to_Python_file('datasets/otsu2018.py') diff --git a/otsu2018.py b/otsu2018.py index 884cbcc..6ce2343 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -1,3 +1,5 @@ +from __future__ import print_function +  import numpy as np  import matplotlib.pyplot as plt  import time @@ -60,7 +62,8 @@ class PartitionAxis:          return '%s=%s' % ('yx'[self.direction], repr(self.origin)) -class Colours: +# Python 3: drop the subclassing +class Colours(object):      """      Represents multiple colours: their reflectances, XYZ tristimulus values      and xy coordinates. The cmfs and the illuminant are taken from the parent @@ -137,7 +140,8 @@ class Otsu2018Error(Exception):      pass -class Node: +# Python 3: drop the subclassing +class Node(object):      """      Represents a node in the tree tree.      """ @@ -202,7 +206,9 @@ class Node:              yield self          else:              for child in self.children: -                yield from child.leaves +                # (Python 3) yield from child.leaves +                for leaf in child.leaves: +                    yield leaf      @property      def leaves(self): @@ -411,8 +417,9 @@ class Node:              return          symbols = ['+', '^', '*', '>', 'o', 'v', 'x', '<'] -        plt.plot(*self.colours.xy.T, -                 "k" + symbols[number[0] % len(symbols)], +        # Python 3: plt.plot(*self.colours.xy.T, ... +        plt.plot(self.colours.xy[:, 0], self.colours.xy[:, 1], +                 'k' + symbols[number[0] % len(symbols)],                   label=str(self))          number[0] += 1 @@ -490,7 +497,8 @@ class Otsu2018Tree(Node):          self.k = 1 / (np.sum(self.cmfs.values[:, 1]                        * self.illuminant.values) * self.dw) -        super().__init__(self, Colours(self, sds)) +        # Python 3: super().__init__(...) +        super(Otsu2018Tree, self).__init__(self, Colours(self, sds))      def fast_sd_to_XYZ(self, R):          """ @@ -596,7 +604,8 @@ class Otsu2018Tree(Node):                  break              _print('\nSplit %s into %s and %s along %s.' -                   % (best_leaf, *best_partition, best_axis)) +                   % (best_leaf, best_partition[0], best_partition[1], +                      best_axis))              _print('Error is reduced by %g and is now %g, '                     '%.1f%% of the initial error.'                     % (leaf.reconstruction_error() @@ -615,15 +624,18 @@ class Otsu2018Tree(Node):          """          rows = [] -        leaf_number = 0 +        # (Python 3) leaf_number = 0 +        leaf_number = [0]          symbol_table = {}          def add_rows(node): -            nonlocal leaf_number +            # nonlocal leaf_number              if node.leaf: -                symbol_table[node] = leaf_number -                leaf_number += 1 +                # symbol_table[node] = leaf_number +                # leaf_number += 1 +                symbol_table[node] = leaf_number[0] +                leaf_number[0] += 1                  return              symbol_table[node] = -len(rows)  | 
