Code
The Code
import os
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1' # avoid weird errors
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import time
from math import log
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.linalg as L
import torch.optim as Opt
import torch.nn as N
import colour
from colour import MSDS_CMFS
from typing import Generator
# Working precision for every tensor in this script (double precision).
FTYPE = torch.float64
# OKLab conversion matrices according to https://bottosson.github.io/posts/oklab/
# M1 maps XYZ -> LMS and M2 maps cube-rooted LMS -> Lab (see xyz_to_oklab).
# Each literal below lists the applied matrix column by column; the trailing .T
# turns it back so that the einsum '...ab,...b->...a' used with these matrices
# computes the ordinary matrix-vector product M @ v.
# .cuda() moves both matrices to the current CUDA device at import time, so
# importing this module requires a working CUDA setup.
M1 = torch.tensor(
    [
        [0.8189330101, 0.0329845436, 0.0482003018],
        [0.3618667424, 0.9293118715, 0.2643662691],
        [-0.1288597137, 0.0361456387, 0.6338517070]
    ], dtype=FTYPE).T.cuda()
M2 = torch.tensor(
    [
        [0.2104542553, 1.9779984951, 0.0259040371],
        [0.7936177850, -2.4285922050, 0.7827717662],
        [-0.0040720468, 0.4505937099, -0.8086757660]
    ], dtype=FTYPE).T.cuda()
def cbrt(x):
    """
    Element-wise real cube root that also accepts negative inputs.

    ``pow`` with a fractional exponent yields NaN for negative bases, so the
    root is taken of the magnitude and the original sign is multiplied back in.

    :param x: Input tensor.
    :return: Tensor of the same shape holding sign(x) * |x| ** (1/3).
    """
    magnitude_root = torch.abs(x).pow(1 / 3)
    return magnitude_root * torch.sign(x)
def xyz_to_oklab(xyz):
    """
    Convert CIE XYZ colours to OKLab (https://bottosson.github.io/posts/oklab/).

    Applies the linear map M1 (XYZ -> LMS), a signed cube root, and then the
    linear map M2 (cube-rooted LMS -> Lab). Leading batch dimensions of
    ``xyz`` are preserved; the trailing axis holds the three components.

    :param xyz: Tensor whose last axis is (X, Y, Z).
    :return: Tensor of the same shape whose last axis is (L, a, b).
    """
    matvec = '...ab,...b->...a'  # M @ v applied along the trailing axis
    return torch.einsum(matvec, M2, cbrt(torch.einsum(matvec, M1, xyz)))
def oklab_to_xyz(lab):
    """
    Convert OKLab colours back to CIE XYZ (inverse of ``xyz_to_oklab``).

    Undoes M2 with a linear solve, cubes the result to undo the cube root, and
    then undoes M1 with a second linear solve:
    https://bottosson.github.io/posts/oklab/

    :param lab: Tensor whose last axis is (L, a, b); any number of leading
        batch dimensions is allowed.
    :return: Tensor of the same shape whose last axis is (X, Y, Z).
    """
    # torch.linalg.solve reads a right-hand side with as many dims as A as a
    # matrix (columns = right-hand sides), not as a batch of vectors. A plain
    # (N, 3) batch would then be rejected, or silently mis-solved when N == 3.
    # Turning each colour into an explicit (3, 1) column lets the (3, 3)
    # matrix broadcast over every leading dimension.
    lmscbrt = L.solve(M2, lab.unsqueeze(-1))
    lms = lmscbrt * lmscbrt * lmscbrt
    # torch.linalg.solve returns a single tensor. Unpacking it as
    # `xyz, _ = ...` (a leftover from the removed torch.solve, which returned
    # (solution, LU)) would split the result along its first axis.
    xyz = L.solve(M1, lms)
    return xyz.squeeze(-1)
def line_dist(p1, p2, p):
    """
    Squared perpendicular distance of every 2-D point from every line.

    Line ``j`` passes through ``p1[j]`` and ``p2[j]``. ``p1`` may also be a
    single shared point of shape (2,), which broadcasts against all lines.

    :param p1: (n, 2) or (2,) first point(s) on the lines.
    :param p2: (n, 2) second points on the lines.
    :param p: (m, 2) query points.
    :return: (m, n) tensor whose entry [i, j] is the squared distance of
        ``p[i]`` from line ``j``. A degenerate line (p1 == p2) divides by zero.
    """
    direction = p2 - p1          # (n, 2) direction vector of each line
    rel = p.unsqueeze(1) - p1    # (m, n, 2), or (m, 1, 2) for a shared p1
    # 2-D cross product: |rel x direction| = |rel| * |direction| * sin(angle),
    # so cross**2 / |direction|**2 is the squared perpendicular distance.
    cross = rel[:, :, 0] * direction[:, 1] - rel[:, :, 1] * direction[:, 0]
    length_sq = (direction * direction).sum(dim=1)
    return cross * cross / length_sq
def golden_exponential(a: float) -> Generator[float, None, None]:
    """
    Infinite low-discrepancy stream of Exp(1)-distributed samples.

    A phase starting at ``a`` is advanced by the golden-ratio conjugate
    (1/phi) modulo 1 on every step, giving evenly spread values in [0, 1).
    Each phase ``u`` is mapped through the inverse exponential CDF
    -ln(1 - u).

    :param a: Starting phase (reduced modulo 1 before the first sample).
    :return: Generator yielding non-negative floats forever.
    """
    golden_step = 0.6180339887498948482045868343656381177203091798057628621354486227
    phase = a
    while True:
        phase = phase % 1
        yield -log(1 - phase)
        phase = phase + golden_step
# Shared module-level stream of golden-ratio-spaced exponential samples.
# main()'s train() pulls SAMPLES - 1 values from it on every iteration (via zip),
# so the exponents keep advancing across iterations rather than repeating.
exps = golden_exponential(0)
def save_spectra_to_csv(spectra, illuminant_spectrum, wavelengths=None):
    """
    Write each RGB base spectrum to CSV, both as-is and multiplied by an illuminant.

    For each channel ``c`` in red/green/blue, two files are written into the
    current working directory:

    * ``{c}_spectrum.csv``       - wavelength, spectrum
    * ``{c}_spectrum_light.csv`` - wavelength, spectrum * illuminant_spectrum

    :param spectra: Array of shape (n_wavelengths, 3); column i is channel i.
    :param illuminant_spectrum: Array of length n_wavelengths, multiplied
        element-wise into each channel for the ``_light`` files.
    :param wavelengths: Optional wavelength column. Defaults to 360..830 in
        steps of 1 (471 samples), the grid previously hard-coded here.
    """
    # The wavelength grid is loop-invariant, so build it once.
    wl = np.arange(360, 831) if wavelengths is None else np.asarray(wavelengths)
    for i, channel in enumerate(('red', 'green', 'blue')):
        stem = f"{channel}_spectrum"
        header = f"wavelength | {channel} spectrum"
        channel_spectrum = spectra[:, i]
        np.savetxt(f"{stem}.csv", np.column_stack((wl, channel_spectrum)),
                   delimiter=",", fmt='%.20E', header=header, comments='')
        # Keep ".csv" as the final extension: the old name was
        # f"{filename}_light", i.e. "red_spectrum.csv_light".
        np.savetxt(f"{stem}_light.csv",
                   np.column_stack((wl, channel_spectrum * illuminant_spectrum)),
                   delimiter=",", fmt='%.20E', header=header, comments='')
def main():
    """
    Optimise three base spectra (red, green, blue) for spectral colour mixing.

    The three spectra are the per-wavelength softmax of ``params``, so they
    sum to 1 at every wavelength. Each iteration trains on three terms:

    * ``point_loss``: random non-negative mixtures of the spectra, lit by D65,
      are pulled towards the OKLab colour of the same mixture of the Rec.709
      primaries.
    * ``line_loss``: the mixtures raised to a stream of exponents are pulled
      towards lines through the white point in the OKLab a/b plane.
    * ``tv_loss``: a total-variation term on the spectra.

    Progress is printed and plotted roughly every 0.25 s. The loop runs until
    Ctrl+C, after which the latest spectra are written with
    ``save_spectra_to_csv``. Requires CUDA, ``colour`` and matplotlib.
    """
    # get the color matching functions and the white point
    cmfs = MSDS_CMFS['CIE 1931 2 Degree Standard Observer']
    d65_spectral = colour.SDS_ILLUMINANTS['D65'].align(cmfs.shape)
    d65_XYZ = colour.sd_to_XYZ(d65_spectral, cmfs=cmfs)
    d65_spectral = colour.SDS_ILLUMINANTS['D65'].align(cmfs.shape) / d65_XYZ[1]  # normalize to Y=1
    d65_XYZ = colour.sd_to_XYZ(d65_spectral, cmfs=cmfs)

    # switch to torch and cuda
    m = torch.tensor(cmfs[:], dtype=FTYPE).cuda()
    d65_spectral = torch.tensor(d65_spectral[:], dtype=FTYPE).cuda()
    d65_XYZ = torch.tensor(d65_XYZ, dtype=FTYPE).cuda()
    okw = xyz_to_oklab(d65_XYZ)  # get the whitepoint in OKLab (for D65 that's basically 1 0 0)
    RGB = torch.tensor(  # whitepoint-agnostic Rec.709 (exact)
        [
            [
                [1664/1245, -2368/3735, -256/1245],
                [286/415, -407/1245, -44/415],
                [26/415, -37/1245, -4/415]
            ],
            [
                [-863/2490, 5011/7470, 37/2490],
                [-863/1245, 5011/3735, 37/1245],
                [-863/7470, 5011/22410, 37/7470]
            ],
            [
                [5/498, -55/1494, 95/498],
                [1/249, -11/747, 19/249],
                [79/1494, -869/4482, 1501/1494]
            ]
        ], dtype=FTYPE).cuda()
    # Contract the (3, 3, 3) whitepoint-agnostic tensor with the white point's
    # XYZ to get a 3x3 matrix; later `rand_rgb @ rgb` maps RGB weights to XYZ.
    rgb = torch.einsum("...a,a->...", RGB, d65_XYZ).cuda()  # get RGB with the correct whitepoint and primaries

    # define the initial parameters: Red /, Green /\, Blue \
    # One row per channel, one column per wavelength sample of m.
    params = torch.zeros_like(m, dtype=FTYPE).T.cuda()
    params[0] = torch.linspace(-1, 1, steps=params[0].shape[0]).cuda()
    params[1] = torch.linspace(-1, 1, steps=params[1].shape[0]).abs().mul(-1).add(0.5).cuda()
    params[2] = torch.linspace(1, -1, steps=params[2].shape[0]).cuda()
    params = params.requires_grad_(True)

    # pick an optimizer, register parameters, pick hyperparameters
    optimizer = Opt.AdamW(params=(params,), lr=0.001, weight_decay=1/(1 << 64), amsgrad=True)
    SAMPLES: int = 64  # how many spectra to sample each iteration. Will effectively be squared so if 64 then 64² = 4096

    # define the plot settings and update accordingly
    fig, ax = plt.subplots()
    lines = [ax.plot([], [], color=color, label=f'{color}')[0] for color in ['r', 'g', 'b']]
    ax.set_xlim(0, 471)
    ax.set_ylim(0, 1)
    ax.legend()
    plt.ion()
    plt.show()

    def update_plot(plot_spectra) -> None:
        """
        Redraw the three channel curves from an array of shape (471, 3).
        """
        for i, line in enumerate(lines):
            line.set_data(range(471), plot_spectra[:, i])
        ax.relim()
        ax.autoscale_view()
        plt.draw()
        plt.pause(0.001)

    def train():
        """
        Run one optimisation step.

        Returns the spectra as a NumPy array of shape (n_wavelengths, 3),
        followed by the total, point, line and total-variation losses as floats.
        """
        # sample different colors
        RGB_spectra = params.softmax(dim=0)  # convert the params into the base spectra that sum to 1
        rand_rgb = torch.rand(SAMPLES, 3, dtype=FTYPE).cuda()  # sample {SAMPLES} linear combinations of base spectra
        sample_spectra = (rand_rgb @ RGB_spectra)  # generate the colors
        target_ok = xyz_to_oklab(rand_rgb @ rgb)  # figure out OKLAB coordinates of those target colors
        # generate the spectra corresponding to those colors, as well as higher powers of those spectra
        samples = torch.tensor([sample for _, sample in zip(range(SAMPLES-1), exps)])
        # Rows k*SAMPLES + j hold sample j raised to the k-th exponent (k = 0 is the plain sample).
        sample_spectra = torch.flatten(torch.stack((sample_spectra, *[sample_spectra**s for s in samples])), end_dim=1)
        xyzs = torch.einsum('...ac,...cb->...ab', sample_spectra * d65_spectral, m)  # multiply by illuminant d65
        oks = xyz_to_oklab(xyzs)  # figure out the OKLab coordinates of the spectra
        # calculate losses:
        # square distance between target OKLab coordinates and those calculated for the spectra
        point_loss = N.MSELoss()(oks[:SAMPLES], target_ok)
        # Distance (in the OKLab a/b plane) of each spectrum and its powers from a
        # line through the white point and a target colour. The targets are tiled
        # so that row k*SAMPLES + j lines up with the same row of oks.
        rep_rand_ok = target_ok.repeat(repeats=(SAMPLES, 1))
        # NOTE(review): line_dist returns an (S^2, S^2) matrix of EVERY point's
        # distance to EVERY line, so .mean() averages over all point/line pairs
        # rather than pairing each point with its own target's line (the
        # diagonal), despite the row alignment above. It also materialises
        # SAMPLES**4 entries. Confirm which objective is intended before changing it.
        line_loss = line_dist(okw[1:], rep_rand_ok[:, 1:], oks[:, 1:]).mean()
        tv_loss = (RGB_spectra[..., 1:]-RGB_spectra[..., :-1]).square().sum()
        loss = (
            100000000 * point_loss
            + line_loss
            + tv_loss
        )
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # logging and saving outputs
        latest_spectra = RGB_spectra.T.cpu().detach().numpy()
        latest_loss = loss.item()
        latest_point_loss = point_loss.item()
        latest_line = line_loss.item()
        latest_tv = tv_loss.item()
        return latest_spectra, latest_loss, latest_point_loss, latest_line, latest_tv

    # set up a timer to get regular updates
    iteration = 0
    starttime = time.monotonic()
    curtime = starttime
    # Latest spectra returned by train(). It stays None until the first step
    # completes, so a very early Ctrl+C no longer raises NameError in the handler.
    spectra = None
    while True:  # main training and logging loop
        try:
            spectra, full_loss, ploss, lloss, tvloss = train()
            if time.monotonic() - curtime > 0.25:
                curtime = time.monotonic()
                ips = iteration/(curtime - starttime)
                print(f"i: {iteration}\ti/s: {ips:.3g}\t"
                      f"L: {full_loss:.8E}\t"
                      f"p: {ploss:.8E}\t"
                      f"l: {lloss:.8E}\t"
                      f"tv: {tvloss:.8E}\t"
                      )
                update_plot(spectra)
            iteration += 1
        except KeyboardInterrupt:  # it saves upon cancelling the script (Ctrl + C)
            if spectra is not None:
                save_spectra_to_csv(spectra, (100 * d65_spectral).cpu().detach().numpy())
            else:
                print("Interrupted before the first training step finished; nothing to save.")
            break
# Start the optimisation loop only when run as a script, not when imported.
if __name__ == '__main__':
    main()
I think I have been overcomplicating things. It turns out that if you optimize in OKLab space (rather than XYZ space), you get smooth spectra simply by optimizing the coordinates directly, without any smoothness constraint. The resulting colors also correspond much more closely to the desired target colors, and they perform better in tests overall:
Note the strong purple cast in the blue of the current version; the red also drifts slightly toward purple.
The new version shows very strong adherence to constant hue throughout, with no purple cast. The bluish cast on the far right of each of these (most noticeable in the all-red version) comes from the D65 white point of the light source, whereas this version of the Spectral Branch actually aims for Illuminant E rendering. AFAIK, that conversion is currently done through an ad-hoc hack, which this approach would make unnecessary.