summaryrefslogtreecommitdiff
path: root/demo.py
diff options
context:
space:
mode:
authorPaweł Redman <pawel.redman@gmail.com>2020-07-25 15:53:29 +0200
committerPaweł Redman <pawel.redman@gmail.com>2020-07-25 16:42:01 +0200
commit26c2bcbb77f9d2914be35e74e14333bc3e729c3e (patch)
tree00b2bb3104bd8cf11dc185d9a17c26348610290f /demo.py
parent9c0dfa59598968ab5be656aac5c4f6f79c50dd61 (diff)
Rename Clustering to Tree and make it a subclass of Node.
Diffstat (limited to 'demo.py')
-rw-r--r--demo.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/demo.py b/demo.py
index eb02202..3c4b933 100644
--- a/demo.py
+++ b/demo.py
@@ -1,34 +1,33 @@
import os
import matplotlib.pyplot as plt
-from colour import (
- SpectralShape, COLOURCHECKER_SDS, ILLUMINANT_SDS, sd_to_XYZ)
+from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ
-from otsu2018 import load_Otsu2018_spectra, Clustering
+from otsu2018 import load_Otsu2018_spectra, Tree
if __name__ == '__main__':
print('Loading spectral data...')
- sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=1)
+ sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50)
shape = SpectralShape(380, 730, 10)
- print('Initializing the clustering...')
- clustering = Clustering(sds, shape)
+ print('Initializing the tree...')
+ tree = Tree(sds, shape)
- print('Clustering...')
- before = clustering.root.total_reconstruction_error()
- clustering.optimise()
- after = clustering.root.total_reconstruction_error()
+ print('Tree...')
+ before = tree.total_reconstruction_error()
+ tree.optimise()
+ after = tree.total_reconstruction_error()
print('Error before: %g' % before)
print('Error after: %g' % after)
print('Saving the dataset...')
os.makedirs('datasets', exist_ok=True)
- clustering.write_python_dataset('datasets/otsu2018.py')
+ tree.write_python_dataset('datasets/otsu2018.py')
print('Plotting...')
- clustering.root.visualise()
+ tree.visualise()
plt.figure()
@@ -40,7 +39,7 @@ if __name__ == '__main__':
plt.plot(sd.wavelengths, sd.values, label='Original')
XYZ = sd_to_XYZ(sd) / 100
- recovered_sd = clustering.reconstruct(XYZ)
+ recovered_sd = tree.reconstruct(XYZ)
plt.plot(recovered_sd.wavelengths, recovered_sd.values,
label='Recovered')