From c0b52c1c841957b27e2dd82a11ecb3b4ff8db265 Mon Sep 17 00:00:00 2001 From: Paweł Redman Date: Tue, 23 Jun 2020 20:08:29 +0200 Subject: Update the tests again and remove some useless ones. --- test_diff.py | 59 ++++++++++++++++++++++++++++++----------------------------- 1 file changed, 30 insertions(+), 29 deletions(-) (limited to 'test_diff.py') diff --git a/test_diff.py b/test_diff.py index 0ee5f58..4429894 100644 --- a/test_diff.py +++ b/test_diff.py @@ -5,46 +5,47 @@ from colour.recovery import error_function_Jakob2019 from matplotlib import pyplot as plt from gsoc_common import plot_comparison + # This test checks if derivatives are calculated correctly by comparing them # to finite differences. if __name__ == "__main__": - shape = SpectralShape(360, 830, 1) - cmfs = STANDARD_OBSERVER_CMFS["CIE 1931 2 Degree Standard Observer"].align(shape) + shape = SpectralShape(360, 830, 1) + cmfs = STANDARD_OBSERVER_CMFS["CIE 1931 2 Degree Standard Observer"].align(shape) - illuminant = SpectralDistribution(ILLUMINANT_SDS["D65"]).align(shape) - illuminant_XYZ = sd_to_XYZ(illuminant) / 100 + illuminant = SpectralDistribution(ILLUMINANT_SDS["D65"]).align(shape) + illuminant_XYZ = sd_to_XYZ(illuminant) / 100 - target = np.array([50, -20, 30]) # Some arbitrary Lab colour - xs = np.linspace(-10, 10, 500) - h = xs[1] - xs[0] + target = np.array([50, -20, 30]) # Some arbitrary Lab colour + xs = np.linspace(-10, 10, 500) + h = xs[1] - xs[0] - # Vary one coefficient at a time - for c_index in range(3): - errors = np.empty(len(xs)) - derrors = np.empty(len(xs)) + # Vary one coefficient at a time + for c_index in range(3): + errors = np.empty(len(xs)) + derrors = np.empty(len(xs)) - for i, x in enumerate(xs): - c = np.array([1.0, 1, 1]) - c[c_index] = x + for i, x in enumerate(xs): + c = np.array([1.0, 1, 1]) + c[c_index] = x - error, derror_dc = error_function_Jakob2019( - c, target, shape, cmfs, illuminant, illuminant_XYZ - ) + error, derror_dc = error_function_Jakob2019( + c, target, shape, cmfs, illuminant, illuminant_XYZ + ) - errors[i] = error - derrors[i] = derror_dc[c_index] + errors[i] = error + derrors[i] = derror_dc[c_index] - plt.subplot(2, 3, 1 + c_index) - plt.xlabel("c%d" % c_index) - plt.ylabel("ΔE") - plt.plot(xs, errors) + plt.subplot(2, 3, 1 + c_index) + plt.xlabel("c%d" % c_index) + plt.ylabel("ΔE") + plt.plot(xs, errors) - plt.subplot(2, 3, 4 + c_index) - plt.xlabel("c%d" % c_index) - plt.ylabel("dΔE/dc%d" % c_index) + plt.subplot(2, 3, 4 + c_index) + plt.xlabel("c%d" % c_index) + plt.ylabel("dΔE/dc%d" % c_index) - plt.plot(xs, derrors, "k-") - plt.plot(xs[:-1] + h / 2, np.diff(errors) / h, "r:") + plt.plot(xs, derrors, "k-") + plt.plot(xs[:-1] + h / 2, np.diff(errors) / h, "r:") - plt.show() + plt.show() -- cgit