【Python】【画像処理】k-means法で画像を減色

減色された画像の色(代表ベクトル)とグラフの点の色が同じになるように改良した.
matplotlib.pyplotのscatterの色の指定ではc="#000000"のように16進数でRGBの順に指定する必要がある.
OpenCVではGBRの順なので気をつける.

#coding:utf-8
"""
$jupyter notebook
$for python3.x
"""
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from PIL import Image
from PIL import ImageDraw
from random import randint
from os.path import exists
from os import mkdir
from math import fabs
from random import random
from copy import deepcopy
from math import log
from math import sqrt
import cv2
import numpy as np
from collections import namedtuple
import csv

def calc_centroid(vecs):
    g,b,r = 0,0,0
    for v in vecs:
        g+=v[0][0]
        b+=v[0][1]
        r+=v[0][2]
    g/=len(vecs)
    b/=len(vecs)
    r/=len(vecs)
    return (int(round(g)),int(round(b)),int(round(r)))

def rgb2hex(centroids):
    #gbr->rbgの16進数
    hex_color_code = []
    for v in centroids:
        tmp = []
        hex_code = ""
        for e in v:
            if e<16:
                h = "0"+str(hex(e))[2:]
            else:
                h = str(hex(e))[2:]
            tmp.append(h)
        hex_code = tmp[2]+tmp[1]+tmp[0]
        hex_color_code.append("#"+hex_code)
    return hex_color_code
    

def main():
    #特徴ベクトルの型の定義
    vec = namedtuple('vec_info',['GBR','coordinate'])

    #画像を読み込み
    img = cv2.imread("sample.png")
    
    #画像のサイズ
    height = img.shape[0]
    width  = img.shape[1]

    #特徴ベクトルの総数(画素数)
    num_vec = height*width

    #各画素の特徴ベクトルを取得する
    vectors = []
    for y in range(height):
        for x in range(width):
            BGR = img[y,x]
            v = vec(BGR,(y,x))
            vectors.append(v)
    vectors = np.array(vectors)
    feature_vectors = vectors[:,0]
    #feature_vectors = np.reshape(feature_vectors,(1,len(feature_vectors)))
    tmp = []
    for vec in feature_vectors:
        tmp.append(vec)
    feature_vectors = np.array(list(tmp))

    #特徴空間の点をプロット
    feature_vectors_G = feature_vectors[:,0]
    feature_vectors_B = feature_vectors[:,1]
    feature_vectors_R = feature_vectors[:,2]

    #初期プロットの表示
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.scatter3D(feature_vectors_G,feature_vectors_B,feature_vectors_R)
    plt.show()

    #クラスタの個数
    num_cluster = 50
    #k-means法
    km = KMeans(n_clusters=num_cluster,
                init='random',
                n_init=10,
                max_iter=100,
                tol=1e-04,
                random_state=0
                )
    y_km = km.fit_predict(feature_vectors)#y_kmにクラスタの番号が保存される

    #クラス毎に分類する
    #vectorsとfeature_vectorsとy_kmは互いに添字が一致
    CLUSTER = [[] for _ in range(num_cluster)]
    for i,v in enumerate(vectors):
        for which_cluster in range(num_cluster):
            if y_km[i] == which_cluster:
                CLUSTER[which_cluster].append(v)
    
    #セントロイドを求める
    centroids = [[] for _ in range(len(CLUSTER))]
    for i,vecs in enumerate(CLUSTER):
        centroids[i] = calc_centroid(vecs)

    #セントロイドを16進数コードに変換する
    COLOR = rgb2hex(centroids)

    #クラス毎にプロットする
    #COLOR = ['b','g','r','c','m','y','k']
    CLUSTER_COLOR = [[[],[],[]] for _ in range(num_cluster)]
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.set_xlabel('G')
    ax.set_ylabel('B')
    ax.set_zlabel('R')
    for i,vecs in enumerate(CLUSTER):
        for j,v in enumerate(vecs):
            CLUSTER_COLOR[i][0].append(v[0][0])
            CLUSTER_COLOR[i][1].append(v[0][1])
            CLUSTER_COLOR[i][2].append(v[0][2])
        for k in range(3):
            CLUSTER_COLOR[i][k] = np.array(CLUSTER_COLOR[i][k])
    for i in range(len(CLUSTER_COLOR)):
        ax.scatter(CLUSTER_COLOR[i][0],CLUSTER_COLOR[i][1],CLUSTER_COLOR[i][2],c=COLOR[i])
    plt.show()
    
    #画像を変換していく
    for i,vecs in enumerate(CLUSTER):
        for j,v in enumerate(vecs):
            img[v[1][0],v[1][1]] = centroids[i]
    
    #画像を表示
    cv2.imshow("image",img)
    cv2.waitKey(0)#キーを押すと終了
    cv2.destroyAllWindows()    

    #画像を出力
    file_name = "res_"+str(len(CLUSTER))
    cv2.imwrite(file_name+".jpg",img)

    

if __name__ == '__main__':
    main()

f:id:umashika5555:20170503005804p:plainf:id:umashika5555:20170503005805p:plainf:id:umashika5555:20170503005806p:plainf:id:umashika5555:20170503005807p:plain