HoloLab DNN Packagesを利用してUnityでクラス分類の推論を実装する

概要

HoloLab DNN PackagesというUnity Sentisをベースにしたディープラーニングによる画像認識を実装したパッケージを公開しました。 この記事ではHoloLab DNN Packagesを利用したクラス分類の推論をUnityアプリに組み込む方法を紹介します。

blog.hololab.co.jp

クラス分類

学習済みモデルの準備

Unity SentisをベースにしたHoloLab DNN PackagesではONNXフォーマットの学習済みモデルが利用できます。 また、HoloLab DNN Packagesのクラス分類パッケージでは以下の基準を満たす一般的なクラス分類モデルを利用することができます。

  • 入力テンソルの形状はNCHW
  • 出力テンソルの形状は1xClasses

基本的にPyTorchからエクスポートされたONNXであれば動作することが期待できます。 まずはPyTorch(torchvision.models)の学習済みモデルを利用されるとよいでしょう。

pytorch.org

パッケージのインポート

Unityのメニューから[Window]>[Package Manager]を開きます。 Package Managerの左上にある[+]>[Add package from git URL]から以下のURLを入力して[Add]ボタンを押します。 ここではHoloLab DNN Packagesのうち「基本パッケージ」と「クラス分類パッケージ」の2つをインポートしていきます。

  1. 基本パッケージ

    まずは基本パッケージをインポートします。

     https://github.com/HoloLabInc/HoloLabDnnPackages.git?path=packages/jp.co.hololab.dnn.base
    

    基本パッケージのインポート

  2. クラス分類パッケージ

    次にクラス分類パッケージをインポートします。

     https://github.com/HoloLabInc/HoloLabDnnPackages.git?path=packages/jp.co.hololab.dnn.classification
    

    クラス分類パッケージのインポート

これで準備は完了です。

クラス分類の推論を実装

それではHoloLab DNN Packagesを利用してクラス分類の推論を実装していきます。

  1. 名前空間を参照する

    Unity SentisとHoloLab DNN Packagesの名前空間を参照します。 ここではクラス分類を利用するのでHoloLab.DNN.Classificationを参照しています。

     using Unity.Sentis;
     using HoloLab.DNN.Classification;
    
  2. クラスを生成する

    クラス分類のClassificationModelクラスを生成します。 2つのコンストラクタが用意されています。 インスペクタから学習済みモデルをアタッチする場合は前者、ファイルパスから読み込む場合は後者を利用します。

     // for Asset
     [SerializeField] private ModelAsset model_asset = default;
     var model = new ClassificationModel(model_asset);
    
     // for File Path
     [SerializeField] private string model_path = "./model.onnx";
     var model = new ClassificationModel(model_path);
    

    必要であれば利用する学習済みモデルに合わせていくつか設定を行います。 学習済みモデルにSoftmaxが含まれていない場合、後処理としてSoftmaxを適用します。

     model.SetApplySoftmax(true);
    

    学習済みモデルがImageNetデータセット以外のデータセットで学習されている場合、前処理として適切なmean/stdを適用します。

     model.SetInputMean(new Vector3(0.485f, 0.456f, 0.406f));
     model.SetInputStd(new Vector3(0.229f, 0.224f, 0.225f));
    
  3. 画像を準備する

    Texture2Dクラスで入力画像を準備します。 WebCamTextureTexture2D.LoadImage()などお好みの方法で取得します。

     // e.g. Load Image
     var data = File.ReadAllBytes("./image.jpg");
     var texure = new Texture2D(1, 1);
     texture.LoadImage(data);
    
  4. 画像をクロップする

    着目する領域がわかっている場合など入力画像を適切にクロップします。 一例として中央領域を短辺x短辺のサイズでクロップするCrop.CenterCrop()が提供されています。

     // e.g. Crop Texture
     var croped_texture = Crop.CenterCrop(texture);
    
  5. クラスに分類する

    モデルと画像が用意できたらクラスに分類します。 クラスの番号は学習に利用したデータセットと照らし合わせてクラスの名前を取得してください。

    最も其れらしいクラス(Top-1)を取得するにはClassificationModel.Classify()に画像を与えます。 クラスの番号と信頼度を得ることができます。

     // Classify Top-1
     (var class_id, var score) = model.Classify(croped_texture);
    

    信頼度が上位K個のクラスを取得するにはClassificationModel.Classify()に画像とKを与えます。 上位K個のクラスの番号と信頼度のリストを得ることができます。

     // Classify Top-3
     var top_k = 3;
     var classes = model.Classify(croped_texture, top_k);
    

まとめ

Unityで簡単にディープラーニングによる画像認識を実装できるHoloLab DNN Packagesを利用してクラス分類の実装方法を紹介しました。 Unityアプリにディープラーニングによる画像認識を使った機能を組み込んで遊んでみてください。