diff options
-rw-r--r-- | clustering.py | 7 | ||||
-rw-r--r-- | otsu2018.py | 34 |
2 files changed, 27 insertions, 14 deletions
diff --git a/clustering.py b/clustering.py index 6df6643..d8b1ed9 100644 --- a/clustering.py +++ b/clustering.py @@ -9,7 +9,7 @@ from otsu2018 import load_Otsu2018_spectra, Otsu2018Tree if __name__ == '__main__': print('Loading spectral data...') - sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=7) + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=100) shape = SpectralShape(380, 730, 10) print('Initializing the tree...') @@ -23,8 +23,9 @@ if __name__ == '__main__': print('Error before: %g' % before) print('Error after: %g' % after) - print('Saving the dataset...') - os.makedirs('datasets', exist_ok=True) + print('Saving the dataset...') + if not os.path.exists('datasets'): + os.makedirs('datasets') data = tree.to_dataset() data.to_file('datasets/otsu2018.npz') data.to_Python_file('datasets/otsu2018.py') diff --git a/otsu2018.py b/otsu2018.py index 884cbcc..6ce2343 100644 --- a/otsu2018.py +++ b/otsu2018.py @@ -1,3 +1,5 @@ +from __future__ import print_function + import numpy as np import matplotlib.pyplot as plt import time @@ -60,7 +62,8 @@ class PartitionAxis: return '%s=%s' % ('yx'[self.direction], repr(self.origin)) -class Colours: +# Python 3: drop the subclassing +class Colours(object): """ Represents multiple colours: their reflectances, XYZ tristimulus values and xy coordinates. The cmfs and the illuminant are taken from the parent @@ -137,7 +140,8 @@ class Otsu2018Error(Exception): pass -class Node: +# Python 3: drop the subclassing +class Node(object): """ Represents a node in the tree tree. """ @@ -202,7 +206,9 @@ class Node: yield self else: for child in self.children: - yield from child.leaves + # (Python 3) yield from child.leaves + for leaf in child.leaves: + yield leaf @property def leaves(self): @@ -411,8 +417,9 @@ class Node: return symbols = ['+', '^', '*', '>', 'o', 'v', 'x', '<'] - plt.plot(*self.colours.xy.T, - "k" + symbols[number[0] % len(symbols)], + # Python 3: plt.plot(*self.colours.xy.T, ... + plt.plot(self.colours.xy[:, 0], self.colours.xy[:, 1], + 'k' + symbols[number[0] % len(symbols)], label=str(self)) number[0] += 1 @@ -490,7 +497,8 @@ class Otsu2018Tree(Node): self.k = 1 / (np.sum(self.cmfs.values[:, 1] * self.illuminant.values) * self.dw) - super().__init__(self, Colours(self, sds)) + # Python 3: super().__init__(...) + super(Otsu2018Tree, self).__init__(self, Colours(self, sds)) def fast_sd_to_XYZ(self, R): """ @@ -596,7 +604,8 @@ class Otsu2018Tree(Node): break _print('\nSplit %s into %s and %s along %s.' - % (best_leaf, *best_partition, best_axis)) + % (best_leaf, best_partition[0], best_partition[1], + best_axis)) _print('Error is reduced by %g and is now %g, ' '%.1f%% of the initial error.' % (leaf.reconstruction_error() @@ -615,15 +624,18 @@ class Otsu2018Tree(Node): """ rows = [] - leaf_number = 0 + # (Python 3) leaf_number = 0 + leaf_number = [0] symbol_table = {} def add_rows(node): - nonlocal leaf_number + # nonlocal leaf_number if node.leaf: - symbol_table[node] = leaf_number - leaf_number += 1 + # symbol_table[node] = leaf_number + # leaf_number += 1 + symbol_table[node] = leaf_number[0] + leaf_number[0] += 1 return symbol_table[node] = -len(rows) |