Read OSS

推論エンジンの内側:モデルアーキテクチャとフォワードパス

上級

前提知識

  • 第1回:アーキテクチャ概観
  • 第3回:スケジューラ
  • Transformerアーキテクチャの基礎(アテンション、MLP、埋め込み)
  • Goのリフレクション基礎

推論エンジンの内側:モデルアーキテクチャとフォワードパス

第1〜3回では、CLIからHTTPサーバー、スケジューラへとリクエストが流れ、最終的にランナーサブプロセスが起動されるまでを追いました。今回はそのサブプロセスの中に入ります。--ollama-engine オプション付きでランナーが起動すると、OllamaのGoネイティブ推論エンジンが動き始めます。このエンジンでは、20以上のモデルアーキテクチャが init() 関数によって自己登録され、テンソルの重みはリフレクションを通じてGoの構造体フィールドにバインドされ、フォワードパスはハードウェアに依存しないテンソルバックエンドを使った純粋なGoのメソッド呼び出しとして表現されます。トークンがパイプライン全体をどう流れるか、順を追って見ていきましょう。

アーキテクチャレジストリパターン

OllamaはLlama、Gemma、Qwen、DeepSeekなど、異なるアーキテクチャを持つ多数のモデルをサポートしています。巨大なswitch文で分岐させる代わりに、Goの init() 機構をプラグインレジストリとして活用しています。

model パッケージでは、インターフェースと登録用のマップが定義されています。

model/model.go#L36-L41

type Model interface {
	Forward(ml.Context, input.Batch) (ml.Tensor, error)
	Backend() ml.Backend
	Config() config
}

model/model.go#L101-L110

var models = make(map[string]func(fs.Config) (Model, error))

func Register(name string, f func(fs.Config) (Model, error)) {
	models[name] = f
}

各アーキテクチャは、それぞれの init() 関数内で自身を登録します。たとえば model/models/llama/model.go#L203-L205 の末尾では次のように書かれています。

func init() {
	model.Register("llama", New)
}

これらをつなぐのが model/models/models.go です。このファイルはブランクインポートだけで構成されています。

import (
	_ "github.com/ollama/ollama/model/models/bert"
	_ "github.com/ollama/ollama/model/models/deepseek2"
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/llama"
	_ "github.com/ollama/ollama/model/models/qwen3"
	// ... 20+ more
)
classDiagram
    class Model {
        <<interface>>
        +Forward(Context, Batch) Tensor, error
        +Backend() Backend
        +Config() config
    }
    class Base {
        +Backend() Backend
        +Config() config
    }
    class LlamaModel {
        +TokenEmbedding
        +Layers[]
        +OutputNorm
        +Output
        +Forward()
    }
    class Gemma3Model {
        +Forward()
    }
    class Qwen3Model {
        +Forward()
    }
    Model <|.. LlamaModel
    Model <|.. Gemma3Model
    Model <|.. Qwen3Model
    Base --* LlamaModel
    Base --* Gemma3Model
    Base --* Qwen3Model

ollamarunnermodel/models をインポートすると、すべての init() 登録がトリガーされます。GGUFファイルが読み込まれると、そのメタデータに含まれるアーキテクチャ名(例:"llama")をレジストリで検索し、対応するコンストラクタが取得されます。

ヒント: 新しいモデルアーキテクチャを追加するには、model/models/ 以下に新しいパッケージを作成し、Model インターフェースを実装して、init() 内で model.Register() を呼び出し、models.go にブランクインポートを追加するだけです。他のコードを変更する必要はありません。

リフレクションを使ったモデルの初期化

コンストラクタでモデルの構造体が作成されたあと、GGUFファイルからGoの構造体フィールドへテンソルをバインドする必要があります。これは model/model.go#L175-L200populateFields() によって自動的に行われます。

model/models/llama/model.go#L27-L37 のLlamaモデルの構造体を見てみましょう。

type Model struct {
	model.Base
	tokenizer.Tokenizer

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
}

gguf 構造体タグは、GGUFのテンソル名にマッピングされます。token_embd はGGUFファイル内の token_embd.weight テンソルに対応し、blkblk.0.*blk.1.* などに対応します(スライスのサイズはモデルの block_count から自動的に決まります)。Output フィールドの alt:token_embd は「output.weight が存在しない場合は token_embd.weight にフォールバックする」という意味で、入力と出力の埋め込みを共有するモデルでよく使われるパターンです。

model/model.go#L113-L135New() 関数がロード処理全体を取りまとめています。

func New(modelPath string, params ml.BackendParams) (Model, error) {
	b, err := ml.NewBackend(modelPath, params)
	m, err := modelForArch(b.Config())
	base := Base{b: b, config: m.Config()}
	v := reflect.ValueOf(m)
	v.Elem().Set(populateFields(base, v.Elem()))
	return m, nil
}

populateFields() はリフレクションを使って構造体を再帰的に走査します。gguf タグが付いているフィールドを見つけると、タグの階層からテンソルのフルネームを組み立て、b.Get(name) でバックエンドから取得し、フィールドにセットします。SelfAttention のようなネストした構造体では、親のタグプレフィックスが引き継がれます。たとえば blk.0attn_q が組み合わさって blk.0.attn_q.weight になります。

詳細解説:GoによるLlamaのフォワードパス

標準的なTransformerの実装は model/models/llama/model.go#L111-L201 にあります。フォワードパス全体を順に追ってみましょう。

sequenceDiagram
    participant F as Forward()
    participant E as TokenEmbedding
    participant L as Layer[i]
    participant SA as SelfAttention
    participant MLP as MLP
    participant N as OutputNorm
    participant O as Output

    F->>E: Embed tokens → hiddenState [embed_dim, batch]
    loop For each layer i
        F->>L: layer.Forward(hiddenState, positions, cache)
        L->>L: AttentionNorm (RMSNorm)
        L->>SA: SelfAttention.Forward()
        SA->>SA: Q = Query.Forward(hidden)
        SA->>SA: K = Key.Forward(hidden)
        SA->>SA: V = Value.Forward(hidden)
        SA->>SA: Q, K = RoPE(Q, positions), RoPE(K, positions)
        SA->>SA: attention = ScaledDotProduct(Q, K, V, cache)
        SA->>SA: Output.Forward(attention)
        L->>L: residual + attention
        L->>L: MLPNorm (RMSNorm)
        L->>MLP: MLP.Forward()
        MLP->>MLP: Gate.Forward(hidden).SILU(Up.Forward(hidden))
        MLP->>MLP: Down.Forward(...)
        L->>L: residual + mlp_output
    end
    F->>N: OutputNorm.Forward(hiddenState)
    F->>O: Output.Forward(normalized) → logits

model/models/llama/model.go#L183-L201Forward() メソッドです。

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)
		var outputs ml.Tensor
		if i == len(m.Layers)-1 {
			outputs = batch.Outputs
		}
		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
}

最終レイヤーでの outputs 最適化に注目してください。batch.Outputs はどのポジションのlogitsが必要かを示すテンソルです。最終レイヤーでは hiddenState.Rows(ctx, outputs) によって、MLP計算の前に必要なポジションだけにhidden stateを絞り込みます。最後のトークンのlogitsだけが必要なケースでは、これによって計算量を大幅に削減できます。

model/models/llama/model.go#L119-L138 のSelfAttentionは、Grouped Query Attention(GQA)を伴う標準的なマルチヘッドアテンションのパターンに従っています。numKVHeadsnumHeads より少なくでき、nn.Attention() 関数がブロードキャストを処理します。

model/models/llama/model.go#L150-L153 のMLPは、SiLU活性化関数を使ったゲート付きフィードフォワードネットワークです。Goのコードはわずか1行で表現されます。

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

ml.Backend 抽象化レイヤー

すべてのテンソル演算は、ml/backend.go#L16-L120 で定義された Backend インターフェースと Context インターフェースを通じて行われます。

type Backend interface {
	Close()
	Load(ctx context.Context, progress func(float32)) error
	Config() fs.Config
	Get(name string) Tensor
	NewContext() Context
	BackendDevices() []DeviceInfo
}

type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloats(s []float32, shape ...int) Tensor
	Forward(...Tensor) Context
	Compute(...Tensor)
	Reserve()
	Close()
}

バックエンドもモデルアーキテクチャと同じ登録パターンを採用しています。ml/backend.go#L78-L84RegisterBackend() がコンストラクタをマップに格納します。現時点では "ggml" バックエンドのみ実装されており、CGo経由でGGMLのCライブラリをラップしています。ただし、インターフェース自体は複数バックエンドへの拡張を想定した設計になっており、将来的にVulkanネイティブやMetalネイティブのバックエンドを追加しても、モデルのコードを一切変更する必要はありません。

ContextCompute() 呼び出しについても重要な点があります。これは演算をその場で即時実行するのではなく、演算をつなぎ合わせて計算グラフを構築し、Compute() が呼ばれた時点でグラフ全体を一括実行します。この仕組みにより、GGMLのグラフ最適化(演算のフュージョンやメモリプランニング)がフォワードパス全体にわたって有効に機能します。

KV Cache:因果的アテンションのメモリ

kvcache/causal.go には、自己回帰的な生成を効率化するKV Cacheが実装されています。新しいトークンを生成するたびに過去のすべてのトークンに対してアテンションを再計算する代わりに、過去のポジションのキーとバリューのテンソルをキャッシュしておきます。

type Causal struct {
	DType         ml.DType
	swaWindowSize int32       // sliding window attention limit
	swaMemorySize int32       // memory retention for partial prefix caching
	cells         []cacheCell // position and sequence metadata per cache slot
	cellRanges    map[int]cellRange
	keys, values  map[int]ml.Tensor  // per-layer storage
	shiftFn       shiftFn            // for RoPE cache shifting
}
flowchart TD
    A["New tokens arrive"] --> B["Find contiguous cache slots"]
    B --> C["Store K,V for new positions"]
    C --> D["Build attention mask"]
    D --> E{"Sliding window?"}
    E -->|"yes"| F["Mask limits to swaWindowSize"]
    E -->|"no"| G["Full causal mask"]
    F --> H["Return K_history, V_history, mask"]
    G --> H
    H --> I["nn.Attention uses cached K,V"]

Sliding Window Attention(SWA)のサポートも特筆すべき点です。Mistralなどのモデルは、各レイヤーが直近のNトークンにしかアテンションを当てないSWAを採用しています。キャッシュはプレフィックスキャッシングのために swaMemorySize 分のトークンをメモリに保持しつつ、アテンションマスクには swaWindowSize 分のトークンのみを含める形で、この仕組みに対応しています。

キャッシュが満杯になった場合、モデルが提供する Shift 関数がRoPEの位置エンコーディングを調整し、コンテキスト全体の表現を捨てることなく、ウィンドウを前方へ「スライド」させます。

トークンサンプリングパイプライン

フォワードパスでlogitsが出力されたあと、sample/samplers.go のサンプリングパイプラインがトークンを選択します。

flowchart TD
    A["Raw logits from model"] --> B{"temperature == 0?"}
    B -->|"yes"| C["Greedy: argmax"]
    B -->|"no"| D["Top-K: keep K highest logits"]
    D --> E["Temperature: scale logits"]
    E --> F["Softmax: convert to probabilities"]
    F --> G["Top-P: keep cumulative prob ≤ p"]
    G --> H["Min-P: keep tokens above threshold"]
    H --> I["Weighted random sample"]
    I --> J{"Grammar sampler?"}
    J -->|"yes"| K["Validate against grammar"]
    J -->|"no"| L["Return token"]
    K --> L

文法制約付きサンプリングには最適化が施されています。まず上位のトークンが文法的に有効かどうかを確認し、有効であればコストの高い全語彙への文法適用をスキップします。上位トークンが棄却された場合にのみ、全トークンに文法制約を適用して再サンプリングします。

ヒント: 決定論的な出力が必要な場合は temperature: 0 を設定しましょう。確率的なサンプリングをすべてバイパスし、純粋なgreedy decodingが使われます。再現性のある非greedyな出力が必要な場合は、オプションで固定の seed を指定してください。

computeBatch ループ:推論の制御塔

runner/ollamarunner/runner.go#L700-L855computeBatch() メソッドは、推論全体を束ねるコアループです。処理の流れは次のとおりです。

  1. バッチ入力の収集 — アクティブなすべてのシーケンスからトークンを集め、単一のバッチにまとめる
  2. フォワードパスの実行activeBatch.ctx.ComputeWithNotify() で計算グラフを実行する
  3. logitsの取り出しactiveBatch.modelOutput.Floats() で生のlogitsを取得する
  4. アクティブな各シーケンスに対して:
    • バッチ出力からそのシーケンスのlogitsスライスを取り出す
    • seq.sampler.Sample(logits) を呼び出してトークンを選択する
    • EOSトークンを検出したらシーケンスを除去する
    • トークンをテキストにデコードして保留中のレスポンスに追加する
    • ストップシーケンスを検出したらテキストを切り詰めてシーケンスを除去する
    • Unicodeが不完全な場合はバッファリングする
    • 完成したテキストをレスポンスチャネルにフラッシュする

runner/ollamarunner/runner.go#L51-L116Sequence 構造体は、リクエストごとの状態を管理します。保留中の入力、生成されたレスポンス、サンプラー、ストップシーケンス、タイミングメトリクスなどがここに含まれます。複数のシーケンスが同一バッチ内でモデルのフォワードパス計算を共有しながら並列実行されます。これが OLLAMA_NUM_PARALLEL の動作原理です。

runner/ollamarunner/runner.go#L222-L298 のマルチモーダル入力処理では、プロンプトを [img-N] タグで分割し、テキスト部分をトークナイズします。その後、モデルの EncodeMultimodal() で画像をエンコードし、入力ストリームへ交互に組み込みます。

次回予告

今回は生の入力から出力トークンが生成されるまで、Goネイティブエンジンの全工程を追いました。第5回では、モデルの保存の仕組みに目を向けます。OCIにインスパイアされたblobシステム、Modelfileのパース、GGUFの変換、そしてモデルを配布するpull/pushレジストリプロトコルを取り上げます。