PyTorchを使った効率的な画像セグメンテーション:Part 2

Efficient Image Segmentation with PyTorch Part 2

CNNベースのモデル

この記事は、PyTorchを使用してディープラーニング技術を使って画像セグメンテーションをスクラッチからステップバイステップで実装する4部作のうちの2部作目です。この記事では、ベースラインの画像セグメンテーション畳み込みニューラルネットワーク(CNN)モデルの実装に焦点を当てます。

Naresh Singhとの共同執筆

図1:CNNを使用した画像セグメンテーションの実行結果。上から下に順に、入力画像、正解のセグメンテーションマスク、予測されたセグメンテーションマスク。出典:著者

記事のアウトライン

この記事では、Convolutional Neural Network(CNN)と呼ばれるCNNベースのアーキテクチャSegNetを実装し、入力画像の各ピクセルを猫や犬などの対応するペットに割り当てます。どのペットにも属さないピクセルは、背景ピクセルとして分類されます。Oxford Petsデータセットを使用して、PyTorchでこのモデルを構築およびトレーニングし、成功した画像セグメンテーションタスクを実行するために必要なことを学びます。モデル構築プロセスでは、モデルの各層の役割について詳しく説明します。この記事には、さらなる学習のための研究論文や記事への参照がたくさん含まれています。

この記事全体を通じて、このノートブックのコードと結果を参照します。結果を再現する場合は、ノートブックが適切な時間内に完了するように、GPUが必要になります。

このシリーズの記事

このシリーズは、ディープラーニングの実践とビジョンAIについて学び、堅実な理論と実践的な経験を得たいすべての経験レベルの読者向けです。以下の記事で構成される、4部作の予定です。

  1. 概念とアイデア
  2. CNNベースのモデル(この記事)
  3. Depthwise separable convolutions
  4. A Vision Transformer-based model

Convolution、batch-normalization、ReLUブロックは、ビジョンAIにおける聖三位一体です。CNNベースのビジョンAIモデルでよく使用されます。これらの用語のそれぞれは、PyTorchで実装された異なるレイヤーを表します。畳み込み層は、学習フィルターを入力テンソルに対してクロス相関操作を実行する責任を持ちます。バッチ正規化は、バッチ内の要素をゼロ平均と単位分散に合わせ、ReLUは、入力の正の値のみを保持する非線形活性化関数です。

典型的なCNNは、層が重ねられるにつれて、入力空間の寸法を徐々に減らします。空間寸法を減らす動機は、次のセクションで説明します。この減少は、周辺値を最大または平均などの単純な関数を使用してプールすることによって達成されます。Max-Poolingセクションでさらに詳しく説明します。分類問題では、Conv-BN-ReLU-Poolブロックのスタックの後に、入力が目標クラスの1つに属する確率を予測する分類ヘッドが続きます。Semantic Segmentationなどの一部の問題では、ピクセルごとの予測が必要です。そのような場合には、ダウンサンプリングブロックの後にアップサンプリングブロックのスタックが追加され、出力を必要な空間寸法に投影します。アップサンプリングブロックは、PoolingレイヤーをUn-poolingレイヤーで置き換えたConv-BN-ReLU-Unpoolブロックであり、Un-poolingについて詳しく説明します。

次に、畳み込み層の背後にある動機についてさらに詳しく説明します。

畳み込み

畳み込みは、ビジョンAIモデルの基本的な構成要素です。コンピュータビジョンで頻繁に使用され、エッジ検出、画像のぼかしやシャープ化、エンボス、強調などのビジョン変換を実装するために歴史的に使用されてきました。

  1. エッジ検出
  2. 画像のぼかしやシャープ化
  3. エンボス
  4. 強調

畳み込み演算は、2つの行列の要素ごとの乗算と集計です。図2には、畳み込み演算の例が示されています。

図2:畳み込み演算のイラスト。出典:著者

ディープラーニングの文脈では、n次元のパラメータ行列であるフィルタまたはカーネルを入力より大きいサイズの入力領域に対して畳み込みを行います。これは、フィルタを入力上でスライドさせ、対応する領域に畳み込みを適用することによって実現されます。スライドの範囲はストライドパラメータを使用して構成されます。ストライドが1の場合、カーネルは1つのステップでスライドして次の領域に対して操作します。固定されたフィルタを使用する従来のアプローチとは異なり、ディープラーニングではバックプロパゲーションを使用してデータからフィルタを学習します。

では、畳み込みはディープラーニングでどのように支援されていますか?

ディープラーニングでは、畳み込み層を使用して視覚的な特徴を検出します。典型的なCNNモデルには、そのような層のスタックが含まれます。スタックの下部の層は、線やエッジなどの単純な特徴を検出します。スタックを上に移動するにつれて、層はますます複雑な特徴を検出します。スタックの中間層は、線とエッジの組み合わせを検出し、トップ層は車、顔、飛行機などの複雑な形状を検出します。図3は、トレーニング済みモデルのトップとボトムの層の出力を視覚的に示したものです。

図3:畳み込みフィルタが識別するもの。出典:Convolutional Deep Belief Networks for Scalable Unsupervised Learning of Hierarchical Representations

畳み込み層には、入力の小領域に作用する学習可能なフィルタのセットがあり、各領域に対して代表的な出力値を生成します。たとえば、3×3フィルタは、3×3サイズの領域に作用し、領域を代表する値を生成します。フィルタを入力領域に反復適用することにより、出力がスタック内の次の層の入力となります。直感的には、上位層は入力のより広い領域を「見る」ことができます。たとえば、ストライド=1で畳み込み操作を仮定すると、第二層の3×3フィルタは、各セルが入力の3×3サイズの領域に関する情報を含む第一層の出力に作用します。第二層のフィルタは、元の入力の5×5サイズの領域を「見る」ことができます。これを畳み込みの受容野と呼びます。畳み込み層の繰り返し適用により、入力画像の空間次元が徐々に縮小され、フィルタの視野が広がり、複雑な形状を「見る」ことができるようになります。図4は、畳み込みネットワークによる1次元入力の処理を示しています。出力層の要素は、比較的大きな入力チャンクを代表するものです。

図4:カーネルサイズ=3の1次元畳み込みの受容野。3回適用した場合。ストライド=1でパディングなしを仮定する。3回の畳み込みカーネルを連続して適用した後、1つのピクセルが元の入力画像の7ピクセルを見ることができます。出典:著者

畳み込み層がこれらのオブジェクトを検出し、それらの表現を生成できるようになったら、これらの表現を画像分類、画像セグメンテーション、オブジェクト検出および位置決めに使用できます。大まかに言えば、CNNは次の一般的な原則に従います:

  1. 畳み込み層は、出力チャネル数(C)を維持するか、倍増させます。
  2. ストライド=1を使用して空間次元を維持するか、ストライド=2を使用して半分に減らします。
  3. 畳み込みブロックの出力をプールすることで、画像の空間次元を変更することが一般的です。

畳み込み層は、各入力に対してカーネルを独立して適用します。これにより、異なる入力に対して出力が異なる可能性があります。この問題に対処するために、通常、畳み込み層の後にバッチ正規化層が続きます。次のセクションで、その役割を詳しく理解しましょう。

バッチ正規化

バッチ正規化層は、バッチ入力のチャネル値を平均値が0で分散が1になるように正規化します。この正規化は、各チャネルに対して独立に実行され、入力のチャネル値が同じ分布を持つことを保証します。バッチ正規化の利点は以下の通りです:

  1. 勾配が小さすぎるのを防ぎ、トレーニングプロセスを安定化させます。
  2. タスクの収束をより速く達成します。

畳み込み層のスタックだけがあれば、線形変換の連鎖効果のため、単一の畳み込み層ネットワークに等しくなります。言い換えると、線形変換のシーケンスは同じ効果を持つ単一の線形変換で置き換えることができます。直感的に言うと、定数k₁でベクトルを乗算した後に別の定数k₂で乗算すると、定数k₁k₂で乗算したのと同じです。したがって、ネットワークが現実的に深くなるには、彼らの崩壊を防ぐために非線形性が必要です。非線形性として頻繁に使用されるReLUについて次のセクションで説明します。

ReLU

ReLUは、最低入力値を0以上にクリップする簡単な非線形活性化関数です。また、出力を0以上に制限することで、勾配の消失問題にも役立ちます。ReLU層は、ダウンスケーリングサブネットワークで空間次元を縮小するためのプーリング層、または空間次元をアップスケーリングするためのアンプーリング層に続いて通常使用されます。詳細については、次のセクションで説明します。

プーリング

プーリング層は、入力の空間次元を縮小するために使用されます。ストライド=2でプーリングすると、空間次元(H、W)の入力を(H/2、W/2)に変換します。最大プーリングは、ディープCNNで最も一般的に使用されるプーリング技術です。これは、(たとえば)2×2のグリッド内で最大値を出力に投影します。その後、畳み込みと同様にストライドに基づいて2×2プーリングウィンドウを次のセクションにスライドします。これをストライド=2で繰り返すと、入力の高さと幅の半分の出力が得られます。もう1つの一般的に使用されるプーリング層は、最大値ではなく平均値を計算する平均プーリング層です。

プーリング層の逆は、アンプーリング層と呼ばれます。これは、(H、W)次元の入力を受け取り、ストライド=2の場合、(2H、2W)次元の出力に変換します。この変換の必要な要素は、出力セクション内の入力値を投影する場所を選択することです。これを行うには、前の最大プーリング操作で生成されたmax-unpooling-index-mapが必要です。図5は、プーリングとアンプーリング操作の例を示しています。

Figure 5: Max pooling and un-pooling. Source: DeepPainter: Painter Classification Using Deep Convolutional Autoencoders

最大プーリングは、非線形活性化関数の1種と見なすことができます。ただし、ReLUなどの非線形性を置き換えるために使用すると、ネットワークのパフォーマンスに影響すると報告されています。対照的に、平均プーリングは、すべての入力を使用して、その入力の線形結合である出力を生成するため、非線形関数としては考えられません。

これで、深いCNNのすべての基本的な構成要素を網羅しました。次に、それらを組み合わせてモデルを作成しましょう。この演習のために選択したモデルはSegNetと呼ばれます。次に、これについて説明します。

SegNet:CNNベースのモデル

SegNetは、この記事で説明した基本ブロックに基づく深いCNNモデルです。2つの異なるセクションがあります。下部セクション、またはエンコーダーとも呼ばれるものは、入力をダウンサンプリングして入力を表す特徴を生成します。上部デコーダーセクションは、特徴をアップサンプリングしてピクセルごとの分類を作成します。各セクションは、Conv-BN-ReLUブロックのシーケンスで構成されています。これらのブロックは、ダウンサンプリングパスとアップサンプリングパスの両方にプーリングまたはアンプーリングレイヤーを組み込んでいます。図6は、レイヤーの配置を詳しく示しています。SegNetは、エンコーダーでの最大プーリング操作からプーリングインデックスを使用して、デコーダーでの最大アンプーリング操作でコピーする値を決定します。活性化テンソルの各要素は4バイト(32ビット)ですが、2×2の正方形内のオフセットは、実行中のモデルで格納する必要があるため、2ビットだけで格納できます。これは、これらの活性化(またはSegNetの場合はインデックス)が実行中のモデルで格納する必要があるため、使用するメモリの効率がよくなります。

図6:画像セグメンテーションのためのSegNetモデルアーキテクチャ。出典:SegNet:画像セグメンテーションのためのディープコンボリューショナルエンコーダーデコーダーアーキテクチャ

このノートブックには、このセクションのすべてのコードが含まれています。

このモデルには、15.27Mのトレーニング可能なパラメータがあります。

モデルのトレーニングおよび検証中には、次の構成が使用されました。

  1. トレーニングセットにランダムな水平反転とカラージッターデータ拡張を適用して、過学習を防止します
  2. 画像は、アスペクト比を保持しないリサイズ操作で128×128ピクセルにリサイズされます
  3. 画像に入力正規化は適用されず、代わりにバッチ正規化層がモデルの最初の層として使用されます
  4. モデルは、Adamオプティマイザを使用して20エポックトレーニングされ、LRは0.001で、StepLRスケジューラが7エポックごとに学習率を0.7で減衰させます
  5. クロスエントロピー損失関数を使用して、ピクセルをペット、背景、またはペットの境界に属するものとして分類します

モデルは、20回のトレーニングエポック後に検証精度が88.28%に達しました。

私たちは、21枚の検証セットの画像のセグメンテーションマスクを予測する方法を学習しているSegNetモデルを示すgifをプロットしました。

図6:SegNetモデルが検証セットの21枚の画像のセグメンテーションマスクを予測する方法を示すgif。出典:著者

すべての検証メトリックの定義については、このシリーズの第1部を参照してください。

Tensorflowを使用したペット画像のセグメンテーションのための完全畳み込みモデルを見たい場合は、Efficient Deep Learning BookのChapter-4:Efficient Architecturesを参照してください。

モデル学習からの観察

トレーニング済みモデルが各エポックの後に行う予測の開発に基づいて、以下を観察できます。

  1. モデルは、トレーニングエポック1でも、出力が画像のペットの正しい範囲内に見える程度に学習できます
  2. ボーダーピクセルをセグメンテーションするのは難しいため、等重量の損失関数を使用しており、成功(または失敗)を均等に扱うため、ボーダーピクセルを間違えてもモデルにとって損失は少ないです。この問題を修正するために試すことができる戦略を調べてください。Focal Lossを使用してパフォーマンスを確認してください
  3. モデルは、20回のトレーニングエポック後でも学習し続けているようです。これは、モデルをより長くトレーニングすれば、検証精度を向上させることができる可能性があることを示唆しています
  4. いくつかのグラウンドトゥルーラベル自体を理解するのが難しい場合があります。たとえば、中央行の犬のマスクの最後の列には、犬の体が植物によって隠されている領域の多くの不明なピクセルがあります。これはモデルにとって非常に難しいため、このような例では常に精度の損失が予想されます。ただし、これはモデルがうまく機能していないことを意味するわけではありません。全体的な検証メトリックだけでなく、予測を確認してモデルの動作を把握する必要があります。
図7:多くの不明なピクセルを含むグラウンドトゥルーセグメンテーションマスクの例。どのMLモデルにとっても非常に難しい入力です。出典:著者

結論

このシリーズの第2部では、ビジョンAIのための深層CNNの基本的な構築ブロックについて学びました。PyTorchでScratchからSegNetモデルを実装し、モデルが21の検証画像でトレーニングされた後の成功を視覚化しました。これにより、モデルが出力を正しい範囲内に見えるように十分に早く学ぶことができる速度を理解できます。この場合、最初のトレーニングエポックでも、実際のセグメンテーションマスクに大まかに似たセグメンテーションマスクを見ることができます!

このシリーズの次の部分では、モデルをオンデバイス推論のために最適化し、トレーニング可能なパラメータ(つまりモデルサイズ)を減らしつつ、検証精度をほぼ同じに保ちます。

さらに読む

畳み込みについてはこちらを参照してください:

  1. Joseph Redmonによるワシントン大学での講義「コンピュータビジョンの秘密」は、畳み込みに関する優れたビデオセットを持っており(特に第4章、第5章、および第13章)、非常におすすめです。
  2. 深層学習における畳み込み算術のガイド(強くお勧めします)
  3. https://towardsdatascience.com/computer-vision-convolution-basics-2d0ae3b79346
  4. PyTorchのConv2dレイヤー(ドキュメント)
  5. 畳み込みは何を学ぶのか?
  6. 畳み込み可視化ツール

バッチ正規化についてはこちらを参照してください:

  1. バッチ正規化:Wikipedia
  2. バッチ正規化:Machine learning mastery
  3. PyTorchのBatchNorm2dレイヤーはこちらです。

活性化関数とReLUについてはこちらを参照してください:

  1. ReLU:Machine learning mastery
  2. ReLU:Wikipedia
  3. ReLU:Quora
  4. PyTorchのReLU API

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

AI研究

ペンシルベニア大学の研究者たちは、腎臓のマッチングを改善し、移植片の失敗リスクを減らすための機械学習戦略の開発を行っています

AIは、遺伝子の特定の変異を分析することにより、腎移植のリスクを最小化することで、人々に希望の光をもたらしています。腎...

データサイエンス

人工知能は人間を置き換えるのか?

はじめに 皆さんはご存知のとおり、AIは飛躍的な進歩を遂げ、科学者や一般の人々の想像をとらえています。ニュースやソーシャ...

機械学習

ソフトウェアエンジニアリングの未来 生成AIによる変革

この記事では、Generative AI(およびLarge Language Models)の出現と、それがソフトウェアエンジニアリングの将来をどのよ...

データサイエンス

生成AIモデル:マーチャンダイジング分析のユーザーエクスペリエンス向上

私たちのデータプラットフォームで利用可能なデータについて、ビジネスユーザーが何でも尋ねることができるように、生成型AI...

機械学習

このAIニュースレターは、あなたが必要とするすべてです#62

今週は、METAのコーディングモデルの開発とOpenAIの新しいファインチューニング機能の進展を見てきましたMetaは、Code LLaMA...

機械学習

「人工知能(AI)におけるアナログコンピュータの使用」

アナログコンピュータは、電気の電圧、機械の動き、または流体の圧力などの物理的な量を、解決すべき問題に対応する量に類似...