summaryrefslogtreecommitdiff
path: root/gsoc_common.py
blob: c235375171288605a809e6f95b8fc230b3069b5f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np, scipy.optimize as optimize
from colour import *
from colour.difference import delta_E_CIE1976
from colour.colorimetry import *
from colour.plotting import *
from matplotlib import pyplot as plt


# Illuminant D65: chromaticity coordinates (CIE 1931 2 degree standard
# observer) and its spectral distribution.
D65_xy = ILLUMINANTS["CIE 1931 2 Degree Standard Observer"]["D65"]
D65 = SpectralDistribution(ILLUMINANT_SDS["D65"])

# The same wavelength grid (360 to 830 in steps of 5) is used throughout.
wvl = SpectralShape(360, 830, 5)
# Grid wavelengths rescaled linearly onto [0, 1]. The bounds are taken from
# the grid itself so this normalization cannot drift out of sync with `wvl`
# if the grid above is ever changed.
wvlp = (wvl.range() - wvl.start) / (wvl.end - wvl.start)


# This is the model of spectral reflectivity described in the article.
def model(wvlp, ccp):
    """Evaluate the spectral reflectivity model described in the article.

    A quadratic ``yy = ccp[0] * x**2 + ccp[1] * x + ccp[2]`` is squashed
    into the range [0, 1] by the sigmoid ``1/2 + yy / (2 * sqrt(1 + yy**2))``.

    Parameters
    ----------
    wvlp : float or numpy.ndarray
        Abscissa(e) at which to evaluate the model (typically the
        normalized wavelength grid, but any numeric array works).
    ccp : sequence of float
        The three polynomial coefficients, highest degree first.

    Returns
    -------
    float or numpy.ndarray
        Model values with the same shape as ``wvlp``.
    """
    yy = ccp[0] * wvlp**2 + ccp[1] * wvlp + ccp[2]
    # np.hypot(1, yy) == sqrt(1 + yy**2), but it does not overflow when
    # yy**2 exceeds the float range. The naive form would turn huge |yy|
    # into yy / inf == 0 and return 0.5 instead of saturating at 0 or 1.
    return 1 / 2 + yy / (2 * np.hypot(1, yy))


# Create a SpectralDistribution using given coefficients
def model_sd(ccp, primed=True):
    """Build a SpectralDistribution named "Model" from model coefficients.

    The model is sampled on the module-level wavelength grid ``wvl``.

    Parameters
    ----------
    ccp : sequence of float
        Polynomial coefficients passed straight to ``model``.
    primed : bool
        If true (default), evaluate the model on the normalized grid
        ``wvlp``; otherwise evaluate it on the raw grid wavelengths.

    Returns
    -------
    SpectralDistribution
        Model values indexed by the raw grid wavelengths.
    """
    # FIXME: don't hardcode the wavelength grid; there should be a way
    #        of creating a SpectralDistribution from the function alone
    wavelengths = wvl.range()
    if primed:
        abscissa = wvlp
    else:
        abscissa = wavelengths
    values = model(abscissa, ccp)
    return SpectralDistribution(values, wavelengths, name="Model")


# Makes a comparison plot with SDs and swatches
def plot_comparison(target, matched_sd, label, error, ill_sd, show=True):
    """Plot a target colour against a model spectrum, plus colour swatches.

    The top subplot shows the spectra: the model's, and the target's as well
    when ``target`` is a SpectralDistribution. The bottom subplot shows sRGB
    swatches of the target and the model, titled with the colour difference.

    Parameters
    ----------
    target : SpectralDistribution or array_like
        Either a spectrum, integrated to XYZ under ``ill_sd``, or an XYZ
        triplet already on the same scale as ``sd_to_XYZ(...) / 100``.
    matched_sd : SpectralDistribution
        Model spectrum to compare against ``target``.
    label : str
        Name of the target; used for its swatch and the top subplot title.
    error : float
        Colour difference shown in the bottom title as ``"ΔE = %g"``.
    ill_sd : SpectralDistribution
        Illuminant used when integrating spectra to XYZ.
    show : bool
        Forwarded as ``standalone`` to ``plot_multi_colour_swatches``.
    """
    # isinstance (not an exact type check) so SpectralDistribution
    # subclasses are treated as spectra rather than as XYZ triplets.
    target_is_sd = isinstance(target, SpectralDistribution)

    # sd_to_XYZ output is divided by 100 (colour's 0-100 scale) before the
    # sRGB conversion; clipping keeps out-of-gamut colours displayable.
    if target_is_sd:
        target_XYZ = sd_to_XYZ(target, illuminant=ill_sd) / 100
    else:
        target_XYZ = target
    target_RGB = np.clip(XYZ_to_sRGB(target_XYZ), 0, 1)
    target_swatch = ColourSwatch(label, target_RGB)

    matched_XYZ = sd_to_XYZ(matched_sd, illuminant=ill_sd) / 100
    matched_RGB = np.clip(XYZ_to_sRGB(matched_XYZ), 0, 1)
    matched_swatch = ColourSwatch("Model", matched_RGB)

    axes = plt.subplot(2, 1, 1)
    if target_is_sd:
        plot_multi_sds([target, matched_sd], axes=axes, standalone=False)
    else:
        plot_single_sd(matched_sd, axes=axes, standalone=False)
    # Title is set after plotting so that a title the colour helper may put
    # on these axes cannot replace ours (the plot is not standalone here).
    axes.set_title(label)

    axes = plt.subplot(2, 1, 2)
    # The swatch plot may display the figure itself (standalone=show), so
    # this title must be set before the call.
    axes.set_title("ΔE = %g" % error)
    plot_multi_colour_swatches([target_swatch, matched_swatch],
                               standalone=show, axes=axes)