diff options
Diffstat (limited to 'clustering.py')
-rw-r--r-- | clustering.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/clustering.py b/clustering.py new file mode 100644 index 0000000..527830b --- /dev/null +++ b/clustering.py @@ -0,0 +1,52 @@ +import os +import matplotlib.pyplot as plt +import tqdm + +from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ + +from otsu2018 import load_Otsu2018_spectra, Tree + + +if __name__ == '__main__': + print('Loading spectral data...') + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) + shape = SpectralShape(380, 730, 10) + + print('Initializing the tree...') + tree = Tree(sds, shape) + + print('Clustering...') + before = tree.total_reconstruction_error() + tree.optimise(progress_bar=tqdm.tqdm) + after = tree.total_reconstruction_error() + + print('Error before: %g' % before) + print('Error after: %g' % after) + + print('Saving the dataset...') + os.makedirs('datasets', exist_ok=True) + tree.write_python_dataset('datasets/otsu2018.py') + + print('Plotting...') + tree.visualise() + + plt.figure() + + examples = COLOURCHECKER_SDS['ColorChecker N Ohta'].items() + for i, (name, sd) in enumerate(examples): + plt.subplot(2, 3, 1 + i) + plt.title(name) + + plt.plot(sd.wavelengths, sd.values, label='Original') + + XYZ = sd_to_XYZ(sd) / 100 + recovered_sd = tree.reconstruct(XYZ) + plt.plot(recovered_sd.wavelengths, recovered_sd.values, + label='Recovered') + + plt.legend() + + if i + 1 == 6: + break + + plt.show() |