diff options
Diffstat (limited to 'demo.py')
-rw-r--r-- | demo.py | 25 |
1 files changed, 12 insertions, 13 deletions
@@ -1,34 +1,33 @@ import os import matplotlib.pyplot as plt -from colour import ( - SpectralShape, COLOURCHECKER_SDS, ILLUMINANT_SDS, sd_to_XYZ) +from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ -from otsu2018 import load_Otsu2018_spectra, Clustering +from otsu2018 import load_Otsu2018_spectra, Tree if __name__ == '__main__': print('Loading spectral data...') - sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=1) + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) shape = SpectralShape(380, 730, 10) - print('Initializing the clustering...') - clustering = Clustering(sds, shape) + print('Initializing the tree...') + tree = Tree(sds, shape) - print('Clustering...') - before = clustering.root.total_reconstruction_error() - clustering.optimise() - after = clustering.root.total_reconstruction_error() + print('Tree...') + before = tree.total_reconstruction_error() + tree.optimise() + after = tree.total_reconstruction_error() print('Error before: %g' % before) print('Error after: %g' % after) print('Saving the dataset...') os.makedirs('datasets', exist_ok=True) - clustering.write_python_dataset('datasets/otsu2018.py') + tree.write_python_dataset('datasets/otsu2018.py') print('Plotting...') - clustering.root.visualise() + tree.visualise() plt.figure() @@ -40,7 +39,7 @@ if __name__ == '__main__': plt.plot(sd.wavelengths, sd.values, label='Original') XYZ = sd_to_XYZ(sd) / 100 - recovered_sd = clustering.reconstruct(XYZ) + recovered_sd = tree.reconstruct(XYZ) plt.plot(recovered_sd.wavelengths, recovered_sd.values, label='Recovered') |