summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--clustering.py7
-rw-r--r--otsu2018.py34
2 files changed, 27 insertions, 14 deletions
diff --git a/clustering.py b/clustering.py
index 6df6643..d8b1ed9 100644
--- a/clustering.py
+++ b/clustering.py
@@ -9,7 +9,7 @@ from otsu2018 import load_Otsu2018_spectra, Otsu2018Tree
if __name__ == '__main__':
print('Loading spectral data...')
- sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=7)
+ sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=100)
shape = SpectralShape(380, 730, 10)
print('Initializing the tree...')
@@ -23,8 +23,9 @@ if __name__ == '__main__':
print('Error before: %g' % before)
print('Error after: %g' % after)
- print('Saving the dataset...')
- os.makedirs('datasets', exist_ok=True)
+ print('Saving the dataset...')
+ if not os.path.exists('datasets'):
+ os.makedirs('datasets')
data = tree.to_dataset()
data.to_file('datasets/otsu2018.npz')
data.to_Python_file('datasets/otsu2018.py')
diff --git a/otsu2018.py b/otsu2018.py
index 884cbcc..6ce2343 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -1,3 +1,5 @@
+from __future__ import print_function
+
import numpy as np
import matplotlib.pyplot as plt
import time
@@ -60,7 +62,8 @@ class PartitionAxis:
return '%s=%s' % ('yx'[self.direction], repr(self.origin))
-class Colours:
+# Python 3: drop the subclassing
+class Colours(object):
"""
Represents multiple colours: their reflectances, XYZ tristimulus values
and xy coordinates. The cmfs and the illuminant are taken from the parent
@@ -137,7 +140,8 @@ class Otsu2018Error(Exception):
pass
-class Node:
+# Python 3: drop the subclassing
+class Node(object):
"""
Represents a node in the tree tree.
"""
@@ -202,7 +206,9 @@ class Node:
yield self
else:
for child in self.children:
- yield from child.leaves
+ # (Python 3) yield from child.leaves
+ for leaf in child.leaves:
+ yield leaf
@property
def leaves(self):
@@ -411,8 +417,9 @@ class Node:
return
symbols = ['+', '^', '*', '>', 'o', 'v', 'x', '<']
- plt.plot(*self.colours.xy.T,
- "k" + symbols[number[0] % len(symbols)],
+ # Python 3: plt.plot(*self.colours.xy.T, ...
+ plt.plot(self.colours.xy[:, 0], self.colours.xy[:, 1],
+ 'k' + symbols[number[0] % len(symbols)],
label=str(self))
number[0] += 1
@@ -490,7 +497,8 @@ class Otsu2018Tree(Node):
self.k = 1 / (np.sum(self.cmfs.values[:, 1]
* self.illuminant.values) * self.dw)
- super().__init__(self, Colours(self, sds))
+ # Python 3: super().__init__(...)
+ super(Otsu2018Tree, self).__init__(self, Colours(self, sds))
def fast_sd_to_XYZ(self, R):
"""
@@ -596,7 +604,8 @@ class Otsu2018Tree(Node):
break
_print('\nSplit %s into %s and %s along %s.'
- % (best_leaf, *best_partition, best_axis))
+ % (best_leaf, best_partition[0], best_partition[1],
+ best_axis))
_print('Error is reduced by %g and is now %g, '
'%.1f%% of the initial error.'
% (leaf.reconstruction_error()
@@ -615,15 +624,18 @@ class Otsu2018Tree(Node):
"""
rows = []
- leaf_number = 0
+ # (Python 3) leaf_number = 0
+ leaf_number = [0]
symbol_table = {}
def add_rows(node):
- nonlocal leaf_number
+ # nonlocal leaf_number
if node.leaf:
- symbol_table[node] = leaf_number
- leaf_number += 1
+ # symbol_table[node] = leaf_number
+ # leaf_number += 1
+ symbol_table[node] = leaf_number[0]
+ leaf_number[0] += 1
return
symbol_table[node] = -len(rows)