import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from numpy import linalg
from skimage.color.colorconv import _prepare_colorarray
from sklearn.metrics import mean_squared_error
# Load the input image and move it into log space for linear unmixing.
# NOTE(review): hard-coded Windows path; consider making it a CLI argument.
img = io.imread('E:\\test1.png')
# PNG files are frequently RGBA, and _prepare_colorarray rejects any input that
# does not have exactly 3 channels, so drop a trailing alpha channel first.
if img.ndim == 3 and img.shape[-1] == 4:
    img = img[..., :3]
# _prepare_colorarray is a private skimage helper; with force_copy=True it
# returns a new floating-point array, so `img` is left untouched for display.
img1 = _prepare_colorarray(img, force_copy=True)
# Clamp zeros (pure black) to a small positive value so log() stays finite.
np.maximum(img1, 1E-6, out=img1)
y = np.log(img1)
# Stain bases: each row is an (approximately unit-length) RGB direction in log
# space. Two-stain models carry an extra third row that completes the basis so
# the 3x3 matrix is invertible.
X1 = np.array([[0.571, 0.584, 0.577], [0.095, 0.258, 0.961], [0.767, 0.576, 0.284]])  # CD8,PanCK,Hema
X2 = np.array([[0.095, 0.258, 0.961], [0.105, 0.758, 0.644], [0.767, 0.576, 0.284]])  # PanCK,PD-L1,Hema
X3 = np.array([[0.571, 0.584, 0.577], [0.767, 0.576, 0.284], [-0.48, 0.808, -0.343]])  # CD8,Hema
X4 = np.array([[0.095, 0.258, 0.961], [0.767, 0.576, 0.284], [-0.553, 0.817, -0.165]])  # PanCK,Hema
X5 = np.array([[0.105, 0.758, 0.644], [0.767, 0.576, 0.284], [-0.218, 0.649, -0.729]])  # PDL1,Hema

_bases = (X1, X2, X3, X4, X5)

# Inverse of every basis.
X1_inv, X2_inv, X3_inv, X4_inv, X5_inv = [linalg.inv(basis) for basis in _bases]

# Per-pixel stain concentrations: project the log image onto each basis.
B1, B2, B3, B4, B5 = [
    y @ inverse for inverse in (X1_inv, X2_inv, X3_inv, X4_inv, X5_inv)
]

# Log image rebuilt from each model's concentrations.
y1, y2, y3, y4, y5 = [
    conc @ basis for conc, basis in zip((B1, B2, B3, B4, B5), _bases)
]
# For every pixel, pick the stain model whose reconstruction y_k is closest to
# y (mean squared error over the 3 channels). Then write that model's
# single-stain images into the matching output arrays. Pixels never assigned
# to a given stain stay white (1.0).
#
# NOTE(review): every X_k is an invertible 3x3 matrix, so
# y_k = y @ inv(X_k) @ X_k equals y up to floating-point rounding. The errors
# compared below are therefore essentially rounding noise, and the model
# choice may not be meaningful. Confirm the intended selection criterion
# (e.g. a 2-stain least-squares fit).
rgb_CD8 = np.ones_like(y)
rgb_PanCK = np.ones_like(y)
rgb_Hema = np.ones_like(y)
rgb_PDL1 = np.ones_like(y)

stain_images = {
    "CD8": rgb_CD8,
    "PanCK": rgb_PanCK,
    "Hema": rgb_Hema,
    "PDL1": rgb_PDL1,
}

# (concentrations, basis, {output stain: row of the basis}) for each model,
# in the same order as the reconstructions y1..y5. The completing third row of
# the two-stain bases is not rendered.
stain_models = [
    (B1, X1, {"CD8": 0, "PanCK": 1, "Hema": 2}),
    (B2, X2, {"PanCK": 0, "PDL1": 1, "Hema": 2}),
    (B3, X3, {"CD8": 0, "Hema": 1}),
    (B4, X4, {"PanCK": 0, "Hema": 1}),
    (B5, X5, {"PDL1": 0, "Hema": 1}),
]

# Vectorised per-pixel MSE, shape (H, W, 5). Each entry is the mean of the
# squared channel differences, as mean_squared_error computes for a 3-vector.
recon_errors = np.stack(
    [np.mean((y - y_k) ** 2, axis=-1) for y_k in (y1, y2, y3, y4, y5)],
    axis=-1,
)
# argmin breaks ties in favour of the lowest-numbered model, instead of
# letting every tie fall through to the last (PDL1) model.
best_model = np.argmin(recon_errors, axis=-1)

for model_index, (conc, basis, stain_rows) in enumerate(stain_models):
    selected = best_model == model_index
    if not selected.any():
        continue
    for stain, row in stain_rows.items():
        # Single-stain log image: concentration times that stain's basis row
        # (equivalent to [0, .., c, .., 0] @ basis), mapped back with exp.
        single = np.exp(conc[selected][:, row, np.newaxis] * basis[row])
        stain_images[stain][selected] = single
# 3x2 grid: original image in slot 0, slot 1 left empty, and one panel per
# reconstructed stain in slots 2-5.
fig, axes = plt.subplots(3, 2, figsize=(8, 7), sharex=True, sharey=True)
ax = axes.ravel()

ax[0].imshow(img)
ax[0].set_title("Original image")

stain_panels = (
    (2, rgb_Hema, "Hematoxylin"),
    (3, rgb_PanCK, "PanCK"),
    (4, rgb_PDL1, "PDL1"),
    (5, rgb_CD8, "CD8"),
)
for slot, stain_rgb, title in stain_panels:
    # Clamp into the displayable [0, 1] range in place before showing.
    np.clip(stain_rgb, 0, 1, out=stain_rgb)
    ax[slot].imshow(stain_rgb)
    ax[slot].set_title(title)

for panel in ax:
    panel.axis('off')
fig.tight_layout()
plt.show()