summaryrefslogtreecommitdiff
path: root/gsoc_common.py
blob: 5d4aea0e2cb437610fb49ad89450bea708328500 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import numpy as np, scipy.optimize as optimize
from colour import *
from colour.difference import delta_E_CIE1976
from colour.colorimetry import *
from colour.plotting import *
from matplotlib import pyplot as plt

# Chromaticity of the D65 white point for the CIE 1931 2 degree observer.
D65_xy = ILLUMINANTS["CIE 1931 2 Degree Standard Observer"]["D65"]
D65 = SpectralDistribution(ILLUMINANT_SDS["D65"])

# The same wavelength grid is used throughout.
wvl = SpectralShape(360, 830, 5)
# wvlp is the same grid mapped linearly onto [0, 1]. The bounds are read back
# from wvl so that they are written in only one place.
wvlp = (wvl.range() - wvl.start) / (wvl.end - wvl.start)

# This is the model of spectral reflectivity described in the article.
def model(wvlp, ccp):
	"""Evaluate the sigmoid-of-a-quadratic reflectance model.

	The quadratic ``yy = ccp[0] * x**2 + ccp[1] * x + ccp[2]`` is squashed
	into [0, 1] by ``1/2 + yy / (2 * sqrt(1 + yy**2))``.

	Parameters
	----------
	wvlp : array_like
		Points at which the model is evaluated (``model_sd`` passes either
		the normalised grid or the raw wavelengths).
	ccp : sequence of float
		The three quadratic coefficients, highest power first.

	Returns
	-------
	ndarray
		Reflectance values, one per point of ``wvlp``.
	"""
	yy = ccp[0] * wvlp**2 + ccp[1] * wvlp + ccp[2]
	# np.hypot(1, yy) equals sqrt(1 + yy**2) but cannot overflow. The naive
	# form gives yy / inf = 0 for |yy| > ~1e154 and so returns 0.5 there
	# instead of 0 or 1.
	return 1 / 2 + yy / (2 * np.hypot(1, yy))

# Create a SpectralDistribution using given coefficients
def model_sd(ccp, primed=True):
	"""Wrap ``model(..., ccp)`` in a SpectralDistribution named "Model".

	With ``primed`` (the default) the model is evaluated on the normalised
	grid ``wvlp``. Otherwise it is evaluated on the raw wavelengths. In both
	cases the values are attached to the raw wavelengths of ``wvl``.
	"""
	# FIXME: don't hardcode the wavelength grid; there should be a way
	#        of creating a SpectralDistribution from the function alone
	if primed:
		sample_points = wvlp
	else:
		sample_points = wvl.range()
	reflectances = model(sample_points, ccp)
	return SpectralDistribution(reflectances, wvl.range(), name="Model")

# The goal is to minimize the color difference between a given distribution
# and the one computed from the model above.
def error_function(c, target, wvl, cmfs, illuminant, illuminant_XYZ):
	"""Return the CIE 1976 colour difference between the model and a target
	L*a*b* value, together with its gradient with respect to ``c``.

	Parameters
	----------
	c : array_like, shape (3,)
		Coefficients of the quadratic inside the sigmoid (same form as
		``model``).
	target : array_like, shape (3,)
		Target L*a*b* value.
	wvl : ndarray, shape (N,)
		Points at which the model is evaluated. ``jakob_hanika`` passes the
		normalised grid ``wvlp``.
	cmfs : object with ``values`` of shape (N, 3) and ``wavelengths``
		Colour matching functions on the same grid. The spacing is taken
		from the first two wavelengths, so it must be uniform.
	illuminant : object with ``values`` of shape (N,)
		Illuminant spectrum on the same grid.
	illuminant_XYZ : array_like, shape (3,)
		White point used to normalise the tristimulus values. It must be on
		the same scale as the XYZ integrated below.

	Returns
	-------
	(float, ndarray of shape (3,))
		Delta E*ab = ||Lab(c) - target|| and its derivative with respect to
		c. The gradient is returned as zeros when the error is exactly 0.
	"""
	c = np.asarray(c, dtype=float)
	target = np.asarray(target, dtype=float)
	illuminant_XYZ = np.asarray(illuminant_XYZ, dtype=float)

	# Reflectance model and its derivatives with respect to c.
	U = c[0] * wvl**2 + c[1] * wvl + c[2]
	t1 = np.hypot(1, U)  # sqrt(1 + U**2) without overflow for huge |U|
	R = 1 / 2 + U / (2 * t1)
	# dR/dU = 1/(2 t1) - U**2/(2 t1**3) simplifies to 1/(2 t1**3). The
	# simplified form avoids inf/inf and cancellation for large |U|.
	t2 = 1 / (2 * t1**3)
	dR_dc = np.array([wvl**2 * t2, wvl * t2, t2])

	# Reflected spectrum and its derivatives.
	E = illuminant.values * R / 100
	dE_dc = illuminant.values * dR_dc / 100

	# Tristimulus values by a rectangle-rule sum over the grid.
	# NOTE(review): XYZ is not divided by the illuminant's own sum(S * ybar);
	# it relies on illuminant_XYZ being on the same scale. Confirm against
	# how the caller computes it.
	cmfs_values = cmfs.values
	dlambda = cmfs.wavelengths[1] - cmfs.wavelengths[0]
	XYZ = E @ cmfs_values * dlambda                # shape (3,)
	dXYZ_dc = cmfs_values.T @ dE_dc.T * dlambda    # [i, j] = dXYZ_i / dc_j

	# Full CIE 1976 f(t): the cube root above (6/29)**3 and a linear segment
	# below. XYZ_to_Lab, which produces the target, uses the same function.
	delta = 6 / 29
	t = XYZ / illuminant_XYZ
	dt_dc = dXYZ_dc / illuminant_XYZ[:, np.newaxis]
	cube = t > delta**3
	# Clamp so the branch np.where discards never divides by zero at t = 0.
	cbrt_t = np.cbrt(np.maximum(t, delta**3))
	f = np.where(cube, cbrt_t, t / (3 * delta**2) + 4 / 29)
	df_dt = np.where(cube, 1 / (3 * cbrt_t**2), 1 / (3 * delta**2))
	df_dc = df_dt[:, np.newaxis] * dt_dc

	Lab = np.array([
		116 * f[1] - 16,
		500 * (f[0] - f[1]),
		200 * (f[1] - f[2])
	])

	dLab_dc = np.array([
		116 * df_dc[1],
		500 * (df_dc[0] - df_dc[1]),
		200 * (df_dc[1] - df_dc[2])
	])

	diff = Lab - target
	error = np.sqrt(np.sum(diff**2))
	if error == 0:
		# The Euclidean norm has no gradient at 0. Report the minimum as
		# flat instead of returning 0/0 = nan.
		return error, np.zeros(3)

	# d||diff||/dc_i = sum_j dLab_j/dc_i * diff_j / ||diff||
	derror_dc = dLab_dc.T @ diff / error

	return error, derror_dc

# Finds the parameters for Jakob and Hanika's model
def jakob_hanika(target_XYZ, ill_sd, ill_xy, ccp0=(0, 0, 0), verbose=True, try_hard=True):
	"""Find ``model`` coefficients whose colour under ``ill_sd`` matches a target.

	The first attempt starts directly from ``ccp0``. If it does not reach an
	error below 0.1 and ``try_hard`` is set, the target is instead approached
	in steps from a known-good grey (XYZ = 1/3, 1/3, 1/3). Each step starts
	from the previous step's solution, and a step counts as successful when
	its error is at most 1e-3. If a step fails, the next pass restarts from
	the last successful point. The number of divisions (3, 5, ..., 29)
	grows only when a pass makes no progress at all.

	Parameters
	----------
	target_XYZ : array_like, shape (3,)
		Target tristimulus values, converted to L*a*b* with ``XYZ_to_Lab``
		relative to ``ill_xy`` (presumably scaled so that white has Y = 1;
		confirm against callers).
	ill_sd : SpectralDistribution
		Illuminant spectrum. It is not modified; an aligned copy is used.
	ill_xy : array_like
		Chromaticity of the white point used for the L*a*b* conversion.
	ccp0 : sequence of 3 floats
		Starting coefficients for the direct attempt.
	verbose : bool
		Print progress and the optimiser results.
	try_hard : bool
		Fall back to the stepwise approach when the direct attempt fails.

	Returns
	-------
	(ndarray, float)
		The fitted coefficients and the remaining colour difference.

	Raises
	------
	RuntimeError
		If the stepwise approach does not converge.
	"""
	# An ndarray makes the stepping arithmetic below work for list/tuple
	# input, and keeps the "%s" formatting from unpacking a tuple.
	target_XYZ = np.asarray(target_XYZ, dtype=float)

	# align() resamples a distribution in place, so work on copies. Otherwise
	# the caller's illuminant and colour's shared CMFS dataset are modified.
	ill_sd = ill_sd.copy().align(wvl)
	ill_XYZ = sd_to_XYZ(ill_sd) / 100
	cmfs = STANDARD_OBSERVER_CMFS["CIE 1931 2 Degree Standard Observer"].copy().align(wvl)

	def do_optimize(XYZ, ccp0):
		# Minimise Delta E between the model and XYZ, starting from ccp0.
		Lab = XYZ_to_Lab(XYZ, ill_xy)

		# error_function returns (value, gradient); jac=True lets scipy use
		# both from a single evaluation.
		opt = optimize.minimize(
			error_function, ccp0, (Lab, wvlp, cmfs, ill_sd, ill_XYZ), jac=True,
			method="L-BFGS-B", options={"disp": verbose}
		)
		if verbose:
			print(opt)
		return opt

	# A special case that's hard to solve numerically: the model only reaches
	# zero reflectance as the quadratic tends to -inf, so return a very
	# negative constant term directly.
	if np.allclose(target_XYZ, [0, 0, 0]):
		return np.array([0.0, 0.0, -1e+9]), 0.0

	if verbose:
		print("Trying the target directly, XYZ=%s" % (target_XYZ,))
	opt = do_optimize(target_XYZ, ccp0)
	if opt.fun < 0.1 or not try_hard:
		return opt.x, opt.fun

	good_XYZ = np.array([1/3, 1/3, 1/3])
	good_ccp = (2.1276356, -1.07293026, -0.29583292) # FIXME: valid only for D65

	divisions = 3
	while divisions < 30:
		if verbose:
			print("Trying with %d divisions" % divisions)

		keep_divisions = False
		ref_XYZ = good_XYZ
		ref_ccp = good_ccp

		ccp0 = ref_ccp
		for i in range(1, divisions):
			intermediate_XYZ = ref_XYZ + (target_XYZ - ref_XYZ) * i / (divisions - 1)
			if verbose:
				print("Intermediate step %d/%d, XYZ=%s with ccp0=%s" %
				      (i + 1, divisions, intermediate_XYZ, ccp0))
			opt = do_optimize(intermediate_XYZ, ccp0)
			if opt.fun > 1e-3:
				if verbose:
					print("WARNING: intermediate optimization failed")
				break
			else:
				# Remember the progress so the next pass resumes from here.
				good_XYZ = intermediate_XYZ
				good_ccp = opt.x
				keep_divisions = True

			ccp0 = opt.x
		else:
			# Every step succeeded; the last step was the target itself.
			return opt.x, opt.fun

		# No progress at all in this pass: use finer steps.
		if not keep_divisions:
			divisions += 2

	raise RuntimeError("optimization failed for target_XYZ=%s, ccp0=%s"
	                   % (target_XYZ, ccp0))

# Makes a comparison plot with SDs and swatches
def plot_comparison(target, matched_sd, label, error, ill_sd, show=True):
	"""Plot a target against the fitted model as spectra and colour swatches.

	The top subplot shows the spectra: both when ``target`` is a spectral
	distribution, otherwise only ``matched_sd``. The bottom subplot shows
	sRGB swatches of the target and the model, with ``error`` in its title.

	Parameters
	----------
	target : SpectralDistribution or array_like
		Either the target spectrum, or its XYZ tristimulus values directly
		(on the same scale as ``sd_to_XYZ(...) / 100``).
	matched_sd : SpectralDistribution
		The model spectrum.
	label : str
		Title of the spectrum plot and label of the target swatch.
	error : float
		Colour difference shown in the title of the swatch plot.
	ill_sd : SpectralDistribution
		Illuminant under which both spectra are converted to XYZ.
	show : bool
		Passed as ``standalone`` to the swatch plot.
	"""
	# isinstance (rather than an exact type check) also accepts subclasses of
	# SpectralDistribution, which the XYZ branch could not handle.
	target_is_sd = isinstance(target, SpectralDistribution)

	if target_is_sd:
		target_XYZ = sd_to_XYZ(target, illuminant=ill_sd) / 100
	else:
		target_XYZ = target
	# Out-of-gamut colours are clipped to the displayable sRGB range.
	target_RGB = np.clip(XYZ_to_sRGB(target_XYZ), 0, 1)
	target_swatch = ColourSwatch(label, target_RGB)
	matched_XYZ = sd_to_XYZ(matched_sd, illuminant=ill_sd) / 100
	matched_RGB = np.clip(XYZ_to_sRGB(matched_XYZ), 0, 1)
	matched_swatch = ColourSwatch("Model", matched_RGB)

	axes = plt.subplot(2, 1, 1)
	plt.title(label)
	if target_is_sd:
		plot_multi_sds([target, matched_sd], axes=axes, standalone=False)
	else:
		plot_single_sd(matched_sd, axes=axes, standalone=False)

	axes = plt.subplot(2, 1, 2)
	plt.title("ΔE = %g" % error)
	plot_multi_colour_swatches([target_swatch, matched_swatch],
	                           standalone=show, axes=axes)