メインコンテンツへスキップ
  1. 記事一覧/

YouTubeのサムネからホロメンを検出するAIを実装して誰でも使えるようにした

·1839 文字·4 分· loading · loading ·
VTuber SSD PyTorch 機械学習
ゆんたん
著者
ゆんたん
Web3/ML/Data engineer
目次

デモ: https://v.untan.xyz/

YouTubeのサムネイル画像からVTuberグループホロライブのメンバー(ホロメン)を検出する機械学習モデルをSSD300(Single Shot Multibox Detector)で実装した.また,推論を実行して結果を表示するインタフェースをWebサイトとして公開した.

特に工夫した点
#

  • 訓練画像の枚数と各メンバーのラベル数
  • データオーグメンテーション
  • モデルにデータを食わせる順番

訓練画像の枚数と各メンバーのラベル回数
#

YouTubeのサムネイル画像を900枚程度用意した.各画像にはLive 2Dモデルのメンバーが一人以上写っている.各メンバーにつき30回以上ラベル付けされるように画像数を増やした.ラベル付け回数の確認は下図のような画面を実装して確認した.

自作アプリの統計パネル

データオーグメンテーション
#

画像を1/3から3倍のスケールでランダムに拡大・縮小し,拡大した画像に対してはランダムな位置でクロップ,縮小した画像に対してはランダムな位置に配置し黒でパディングする処理を実装した.処理の例を以下に示す.左の画像が元画像,中央と右の画像がそれぞれ拡大・縮小した画像.アスペクト比が変わらないように処理している.

訓練画像の前処理のコードは次のようになっている.

class Transform():
    def __init__(self):
        self.transform = Compose([
            RandomScaleCrop(0.33, 3),
            Resize(300),  # 300x300にリサイズ
            ToTensor(),
        ])

    def __call__(self, image, boxes, labels):
        return self.transform(image, boxes, labels)

モデルにデータを入力する順番
#

「Heの初期化」したモデルにいきなりデータオーグメンテーションした画像を入力すると,ロスが発散してしまい学習が進まない.なので,一旦元画像で100epoch学習させた後に,データオーグメンテーションした画像を学習させるという方法をとった.

訓練画像の収集とラベル付け
#

訓練用画像を910枚,検証用画像を68枚用意した.各画像にはLive 2Dモデルのメンバーが一人以上写っている.34ラベルでラベル付けを行った.ラベル付けアプリはelectronで自作した.

左ペインが訓練用画像,右ペインが学習用画像
自作アプリでラベル付けしている様子

SSDの実装
#

PyTorchで実装した.SSDと前処理クラス・訓練関数の実装は 「つくりながら学ぶ!PyTorchによる発展ディープラーニング」小川雄太郎著を参考に実装した.2.8章にPyTorch >= 1.5では動かないコードがあったが, ここを見て解決した.

学習
#

学習済みモデルを用いずに,「Heの初期値」で初期化した状態から学習した.元画像で100epoch+データオーグメンテーションして400epoch回した.後半の400epochの誤差曲線を図に示す.青線がtrain,黄線がeval.バッチサイズは32.Google Colabで19時間程度要した.学習済みモデルのファイルサイズは108MBになった.

データオーグメンテーションしたときの誤差曲線.400epochあたりで訓練データに対して収束してそう.

推論
#

推論結果の例を以下に示す.推論結果の確信度の閾値は0.6とし,それ以上のボックスを検出したものとした.

うまくいったもの
#

全員正解.
多少ノイズが入っていても正しく検出できている.
右上に小さく写っている白上フブキも検出できている.
髪の色が似ていても区別できている.

うまくいかなかったもの
#

兎田ぺこらが検出されていない.
兎田ぺこら・姫森ルーナ・さくらみこが検出されていない.帽子が検出を妨げているのかもしれない.なぜかしぐれういが桃鈴ねねと誤検出されている.
夜空メル・尾丸ポルカが検出されていない.白上フブキ・姫森ルーナの確信度が低いのは周りと重なって隠れている部分が多いからだと思われる.
犬山たまきを兎田ぺこらと誤検出している.確かに似ているかもしれない.

検出されないもしくは誤検出されるメンバーは,訓練データを増やすと解決すると思われる.領域が小さいために検出されない場合は,訓練データのラベル領域の中央だけ切り取って学習させるなどすると上手くいくと思っている.

Webアプリの実装
#

Google Cloud RunにPyTorchと訓練済みモデルを載せてクラウド関数を作った.Google Cloud Functionsに載せなかった理由は,モデルのファイルサイズがクラウド関数にアップロードできるサイズを超えていたから.

感想
#

  • SSDを用いて画像から2Dモデルを検出するタスクを実践できた.
  • 機械学習モデルをWebアプリとしてデプロイできた.

データオーグメンテーションが上手くいって精度が向上したときはテンションが上がった.