summaryrefslogtreecommitdiff
path: root/otsu2018.py
diff options
context:
space:
mode:
Diffstat (limited to 'otsu2018.py')
-rw-r--r--otsu2018.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/otsu2018.py b/otsu2018.py
index 6533822..713da1d 100644
--- a/otsu2018.py
+++ b/otsu2018.py
@@ -1,6 +1,5 @@
import numpy as np
import matplotlib.pyplot as plt
-import sklearn.decomposition
from colour import (SpectralDistribution, STANDARD_OBSERVER_CMFS,
ILLUMINANT_SDS, sd_to_XYZ, XYZ_to_xy)
@@ -207,10 +206,14 @@ class Node:
if not self.leaf:
raise RuntimeError('Node.PCA called for a node that is not a leaf')
- pca = sklearn.decomposition.PCA(3)
- pca.fit(self.colours.reflectances)
- self.basis_functions = pca.components_
- self.mean = pca.mean_
+ # https://dev.to/akaame/implementing-simple-pca-using-numpy-3k0a
+ self.mean = np.mean(self.colours.reflectances, axis=0)
+ data = self.colours.reflectances - self.mean
+ cov = np.cov(data.T) / data.shape[0]
+ v, w = np.linalg.eig(cov)
+ idx = v.argsort()[::-1]
+ w = w[:,idx]
+ self.basis_functions = np.real(w[:, :3].T)
# TODO: better names
M = np.empty((3, 3))
@@ -241,6 +244,7 @@ class Node:
weights = np.dot(self.M_inverse, XYZ - self.XYZ_mu)
reflectance = np.dot(weights, self.basis_functions) + self.mean
+ reflectance = np.clip(reflectance, 0, 1)
return SpectralDistribution(reflectance, self.clustering.wl)
def reconstruction_error(self):
@@ -313,7 +317,7 @@ class Node:
axis = self.colours.xy[i, x_or_y]
partition = self.colours.partition(x_or_y, axis)
- if len(partition[0]) <= 5 or len(partition[1]) <= 5:
+ if len(partition[0]) < 3 or len(partition[1]) < 3:
raise ClusteringError('partition created parts that are too small')
lesser = Node(self.clustering, partition[0])