Pytorchで学習済みモデルによる背景削除を実装する

Pytorchで学習済みモデルによる背景削除を実装します.

全国630店舗以上!もみほぐし・足つぼ・ハンドリフレ・クイックヘッドのリラクゼーション店【りらくる】

# 背景削除

Semantic SegmentationによるDeep Image Matting(物体の切り抜き)を実装します。

Semantic Segmentationのカラーマップは、物体には255、背景には0、物体か背景か不明な場所には127と記述します。

# ライブライのインポート

下記のコマンドでライブラリをインポートします。

import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import transforms

# 画像の読み込み

画像を読み込み、DeepLabv3の入力サイズに合わせてリサイズします。

# 画像の読み込み
image_path = 'man.png'
img = cv2.imread(image_path)
# BGR->RGBへ変換
img = img[...,::-1]
h,w,_ = img.shape
# 画像のリサイズ
img = cv2.resize(img,(320,320))

# 学習済みモデルの読み込み

pythorchで学習済みモデル(21クラスのセマンティックセグメンテーション)を読み込みます。

# GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 学習済みモデルの読み込み
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
model = model.to(device)
model.eval();

# 前処理

前処理として、画像のnumpy配列をtensor型にし、正規化します.また,バッチの次元を追加します.

正規化を実施する場合,本来は各チャネル(BGR)でtensor型の平均・標準偏差を算出して入力する.

しかし,今回は学習済みモデルを読み込んでいるため,事前に算出されている平均・標準偏差を公式ページより引用します.

torchvision.models — PyTorch master documentation (opens new window)

前処理

  • 前処理:transforms.Compose()
  • tensor型へ変換(0~1にスケーリング):transforms.ToTensor()
  • 正規化:transforms.Normalize()

バッチ化

元のテンソルを書き換えずに次元を増やして,バッチ化を記述する.

  • unsqueeze(dim):元のテンソルを書き換えずに、次元を増やしたテンソルを返す
  • unsqueeze_(dim):元のテンソルを書き換えて、テンソルの次元を増やす
# 前処理:tensor型へ変換 + 正規化
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 入力画像の情報
print("img-info")
print('channel:', img[0][0])
print('shape:', img.shape)
print("="*40)

# 画像の前処理
input_tensor = preprocess(img)
print("preprocess-info")
print('channel:', input_tensor[0][0][0], input_tensor[1][0][0], input_tensor[2][0][0])
print('shape:', input_tensor.shape)
print("="*40)

# バッチ化(ここではバッチサイズ=1)
input_batch = input_tensor.unsqueeze(0).to(device)
print("batch-info")
print('channel:', input_batch[0][0][0][0], input_batch[0][1][0][0], input_batch[0][2][0][0])
print('shape:', input_batch.shape)

# セグメンテーションの実行

下記の手法でセグメンテーションを実行します.

torch.no_gradはテンソルの勾配の計算を不可にし,メモリの消費を減らします.

そして,21クラス中で最大となるクラスを出力させて,背景とそれ以外を区別してマスキングします.

最大となるクラスの出力は,argmax(0)で出力します.

3次元のnumpy型でのargmax(0)では、2次元目・3次元目での各要素で最大値となる1次元の順序を出力します.

import numpy as np

# 3次元
array3 = np.array([[[8, 4, 1 ], 
                    [7, 2, 12], 
                    [3, 2, 9]],
                   [[6, 11,  3], 
                    [5,  9, 10], 
                    [11, 7, 4]]])

print(array3)
print(array3.shape)
print(array3.argmax(0))

# [[[ 8  4  1]
#   [ 7  2 12]
#   [ 3  2  9]]

#  [[ 6 11  3]
#   [ 5  9 10]
#   [11  7  4]]]
# (2, 3, 3)
# [[0 1 1]
#  [0 1 0]
#  [1 1 0]]

下記のコードで背景削除を実行します.

# 勾配計算を不可にし,メモリを節約
with torch.no_grad():
    output = model(input_batch)['out']
    print("output-info")
    print('shape', model(input_batch)['out'].shape)
    output = output[0]
    print('shape', output.shape) # 21クラスでの各要素のスコア
    
# 画像を出力
# 21クラス中で最大スコアのクラスを出力
output = output.argmax(0)
# マスキング画像の生成
mask = output.byte().cpu().numpy()
mask = cv2.resize(mask,(w,h))
mask[mask!=0] = 255
img = cv2.resize(img,(w,h))

# 画像を描画
plt.gray()
plt.figure(figsize=(20,20))
plt.subplot(1,2,1)
plt.imshow(img)
plt.subplot(1,2,2)
plt.imshow(mask);
cv2.imwrite('./mask.png', mask)

# trimapの作成

OpenCVで膨張収縮処理をしてtrimapを生成し、適当な場所に保存します。下の右のようなtrimapが得られます。

# trimapを作成
def gen_trimap(mask,k_size=(5,5),ite=1):
    # 膨張収縮処理
    kernel = np.ones(k_size,np.uint8)
    eroded = cv2.erode(mask,kernel,iterations = ite)
    dilated = cv2.dilate(mask,kernel,iterations = ite)
    # 膨張収縮の差はグレーに着色
    trimap = np.full(mask.shape,128)
    trimap[eroded == 255] = 255
    trimap[dilated == 0] = 0
    return trimap

trimap = gen_trimap(mask,k_size=(10,10),ite=5)
print(trimap[500][300:400])
cv2.imwrite('./trimaps.png', trimap)
plt.figure(figsize=(20,20))
plt.subplot(1,2,1)
plt.imshow(img)
plt.subplot(1,2,2)
plt.imshow(trimap)

# 背景削除の実行

下記の手法で背景削除を実行します.

import numpy as np
import cv2
import matplotlib.pyplot as plt

# 読み込み
img = cv2.imread('./man.png')
img = img[...,::-1]
matte = cv2.imread('./mask.png')
h,w,_ = img.shape
wback = np.full_like(img,255) #white background

# 小数点型に変換
img = img.astype(float)
wback = wback.astype(float)

# 背景削除
matte = matte.astype(float)/255
img = cv2.multiply(img, matte) # 画像の乗算
wback = cv2.multiply(wback, 1.0 - matte) # 画像の乗算
# 白背景とカラー画像を重ねる
outImage = cv2.add(img, wback)
# 画像描画
plt.imshow(outImage/255)

# 画像出力
outImage = outImage[...,::-1]
cv2.imwrite('./remove_bg.png', outImage)

# まとめ

Pytorchで学習済みモデルによる背景削除を実装しました.

次は,論文のコードを実装します.

# 参考サイト

Deep Learningで背景削除をしてみる (opens new window)

poppinace/indexnet_matting (opens new window)

Indices Matter: Learning to Index for Deep Image Matting (opens new window)

Deep Image Matting : 物体の切り抜きを高精度化する機械学習モデル (opens new window)

unsqueeze(dim)とunsqueeze_(dim)の違い【PyTorch】 (opens new window)

【Ptyorch】ToTenserした画像をNormalizationすることに意味はあるのでしょうか (opens new window)

torchvision.models — PyTorch master documentation (opens new window)

PyTorchのtorch.no_grad()とは何か(超個人的メモ) (opens new window)

pytorch_vision_deeplabv3_resnet101 (opens new window)

全国630店舗以上!もみほぐし・足つぼ・ハンドリフレ・クイックヘッドのリラクゼーション店【りらくる】

DockerでQGISがインストールされたコンテナ(Ubuntu20.04)を作成し起動する

DockerでQGISがインストールされたコンテナ(Ubuntu20.04)を作成し起動する

DockerでQGISがインストールされたコンテナ(Ubuntu20.04)を作成し起動します.

Image Mattingで学習済みモデルによる背景削除を実装する

Image Mattingで学習済みモデルによる背景削除を実装する

Image Mattingで学習済みモデルによる背景削除を実装します.