diff options
author | Paweł Redman <pawel.redman@gmail.com> | 2020-06-20 22:48:47 +0200 |
---|---|---|
committer | Paweł Redman <pawel.redman@gmail.com> | 2020-06-20 22:48:47 +0200 |
commit | 1934d1ce12200e26e9e76da9cd77b4593bc0a0c2 (patch) | |
tree | f397f8d0bd804f5df60424deb86374d27f7f059b /test_diff.py | |
parent | 77221e174cdab64de60450646de77f849b043062 (diff) |
Rewrite the error function w/ analytic derivatives
This greatly speeds up the code and improves convergence. Something's slightly off with the code, though, and the output colors are slightly skewed.
Diffstat (limited to 'test_diff.py')
-rw-r--r-- | test_diff.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/test_diff.py b/test_diff.py new file mode 100644 index 0000000..b634ca6 --- /dev/null +++ b/test_diff.py @@ -0,0 +1,51 @@ +import numpy as np +from scipy.optimize import minimize +from colour import * +from colour.colorimetry import STANDARD_OBSERVER_CMFS, ILLUMINANT_SDS +from colour.models import eotf_inverse_sRGB, sRGB_to_XYZ +from matplotlib import pyplot as plt +from gsoc_common import plot_comparison, error_function, model_sd, D65_xy + +shape = SpectralShape(360, 830, 1) +cmfs = STANDARD_OBSERVER_CMFS["CIE 1931 2 Degree Standard Observer"].align(shape) + +illuminant = SpectralDistribution(ILLUMINANT_SDS["D65"]).align(shape) +illuminant_XYZ = sd_to_XYZ(illuminant) / 100 +wvl = np.linspace(0, 1, len(shape.range())) + +target = np.array([50, -20, 30]) # Some arbitrary Lab coordinates +xs = np.linspace(-10, 10, 500) +h = xs[1] - xs[0] + +# This test checks if derivatives are calculated correctly by comparing them +# to finite differences. +for c_index in range(3): + errors = np.empty(len(xs)) + derrors = np.empty(len(xs)) + + for i, x in enumerate(xs): + c = np.array([1.0, 1, 1]) + c[c_index] = x + + error, derror_dc = error_function( + c, target, wvl, cmfs, illuminant, illuminant_XYZ + ) + + errors[i] = error + derrors[i] = derror_dc[c_index] + + + plt.subplot(2, 3, 1 + c_index) + plt.xlabel("c%d" % c_index) + plt.ylabel("ΔE") + plt.plot(xs, errors) + + plt.subplot(2, 3, 4 + c_index) + plt.xlabel("c%d" % c_index) + plt.ylabel("dΔE/dc%d" % c_index) + + plt.plot(xs, derrors, "k-") + plt.plot(xs[:-1] + h / 2, np.diff(errors) / h, "r:") + + +plt.show() |