From f88a852f659ea3223ce69ac147912b1d346cf7be Mon Sep 17 00:00:00 2001 From: Paweł Redman Date: Sat, 25 Jul 2020 17:22:36 +0200 Subject: Factor out printing progress information. This way using tqdm is optional and not a direct dependency of otsu2018.py. --- clustering.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 clustering.py (limited to 'clustering.py') diff --git a/clustering.py b/clustering.py new file mode 100644 index 0000000..527830b --- /dev/null +++ b/clustering.py @@ -0,0 +1,52 @@ +import os +import matplotlib.pyplot as plt +import tqdm + +from colour import SpectralShape, COLOURCHECKER_SDS, sd_to_XYZ + +from otsu2018 import load_Otsu2018_spectra, Tree + + +if __name__ == '__main__': + print('Loading spectral data...') + sds = load_Otsu2018_spectra('CommonData/spectrum_m.csv', every_nth=50) + shape = SpectralShape(380, 730, 10) + + print('Initializing the tree...') + tree = Tree(sds, shape) + + print('Clustering...') + before = tree.total_reconstruction_error() + tree.optimise(progress_bar=tqdm.tqdm) + after = tree.total_reconstruction_error() + + print('Error before: %g' % before) + print('Error after: %g' % after) + + print('Saving the dataset...') + os.makedirs('datasets', exist_ok=True) + tree.write_python_dataset('datasets/otsu2018.py') + + print('Plotting...') + tree.visualise() + + plt.figure() + + examples = COLOURCHECKER_SDS['ColorChecker N Ohta'].items() + for i, (name, sd) in enumerate(examples): + plt.subplot(2, 3, 1 + i) + plt.title(name) + + plt.plot(sd.wavelengths, sd.values, label='Original') + + XYZ = sd_to_XYZ(sd) / 100 + recovered_sd = tree.reconstruct(XYZ) + plt.plot(recovered_sd.wavelengths, recovered_sd.values, + label='Recovered') + + plt.legend() + + if i + 1 == 6: + break + + plt.show() -- cgit