diff options
author | Paweł Redman <pawel.redman@gmail.com> | 2020-07-25 15:53:29 +0200 |
---|---|---|
committer | Paweł Redman <pawel.redman@gmail.com> | 2020-07-25 16:42:01 +0200 |
commit | 26c2bcbb77f9d2914be35e74e14333bc3e729c3e (patch) | |
tree | 00b2bb3104bd8cf11dc185d9a17c26348610290f /demo.py | |
parent | 9c0dfa59598968ab5be656aac5c4f6f79c50dd61 (diff) |
Rename Clustering to Tree and make it a subclass of Node.
Diffstat (limited to 'demo.py')
-rw-r--r-- | demo.py | 25 |
1 files changed, 12 insertions, 13 deletions
@@ -1,34 +1,33 @@ import os import matplotlib.pyplot as plt -from colour import ( - SpectralShape, COLOURCHECKER_SDS, ILLUMINANT_SDS, sd_to_XYZ) +from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ -from otsu2018 import load_Otsu2018_spectra, Clustering +from otsu2018 import load_Otsu2018_spectra, Tree if __name__ == '__main__': print('Loading spectral data...') - sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=1) + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) shape = SpectralShape(380, 730, 10) - print('Initializing the clustering...') - clustering = Clustering(sds, shape) + print('Initializing the tree...') + tree = Tree(sds, shape) - print('Clustering...') - before = clustering.root.total_reconstruction_error() - clustering.optimise() - after = clustering.root.total_reconstruction_error() + print('Tree...') + before = tree.total_reconstruction_error() + tree.optimise() + after = tree.total_reconstruction_error() print('Error before: %g' % before) print('Error after: %g' % after) print('Saving the dataset...') os.makedirs('datasets', exist_ok=True) - clustering.write_python_dataset('datasets/otsu2018.py') + tree.write_python_dataset('datasets/otsu2018.py') print('Plotting...') - clustering.root.visualise() + tree.visualise() plt.figure() @@ -40,7 +39,7 @@ if __name__ == '__main__': plt.plot(sd.wavelengths, sd.values, label='Original') XYZ = sd_to_XYZ(sd) / 100 - recovered_sd = clustering.reconstruct(XYZ) + recovered_sd = tree.reconstruct(XYZ) plt.plot(recovered_sd.wavelengths, recovered_sd.values, label='Recovered') |