推論エンジンの内側:モデルアーキテクチャとフォワードパス
前提知識
- ›第1回:アーキテクチャ概観
- ›第3回:スケジューラ
- ›Transformerアーキテクチャの基礎(アテンション、MLP、埋め込み)
- ›Goのリフレクション基礎
推論エンジンの内側:モデルアーキテクチャとフォワードパス
第1〜3回では、CLIからHTTPサーバー、スケジューラへとリクエストが流れ、最終的にランナーサブプロセスが起動されるまでを追いました。今回はそのサブプロセスの中に入ります。--ollama-engine オプション付きでランナーが起動すると、OllamaのGoネイティブ推論エンジンが動き始めます。このエンジンでは、20以上のモデルアーキテクチャが init() 関数によって自己登録され、テンソルの重みはリフレクションを通じてGoの構造体フィールドにバインドされ、フォワードパスはハードウェアに依存しないテンソルバックエンドを使った純粋なGoのメソッド呼び出しとして表現されます。トークンがパイプライン全体をどう流れるか、順を追って見ていきましょう。
アーキテクチャレジストリパターン
OllamaはLlama、Gemma、Qwen、DeepSeekなど、異なるアーキテクチャを持つ多数のモデルをサポートしています。巨大なswitch文で分岐させる代わりに、Goの init() 機構をプラグインレジストリとして活用しています。
model パッケージでは、インターフェースと登録用のマップが定義されています。
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend
Config() config
}
var models = make(map[string]func(fs.Config) (Model, error))
func Register(name string, f func(fs.Config) (Model, error)) {
models[name] = f
}
各アーキテクチャは、それぞれの init() 関数内で自身を登録します。たとえば model/models/llama/model.go#L203-L205 の末尾では次のように書かれています。
func init() {
model.Register("llama", New)
}
これらをつなぐのが model/models/models.go です。このファイルはブランクインポートだけで構成されています。
import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/deepseek2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/qwen3"
// ... 20+ more
)
classDiagram
class Model {
<<interface>>
+Forward(Context, Batch) Tensor, error
+Backend() Backend
+Config() config
}
class Base {
+Backend() Backend
+Config() config
}
class LlamaModel {
+TokenEmbedding
+Layers[]
+OutputNorm
+Output
+Forward()
}
class Gemma3Model {
+Forward()
}
class Qwen3Model {
+Forward()
}
Model <|.. LlamaModel
Model <|.. Gemma3Model
Model <|.. Qwen3Model
Base --* LlamaModel
Base --* Gemma3Model
Base --* Qwen3Model
ollamarunner が model/models をインポートすると、すべての init() 登録がトリガーされます。GGUFファイルが読み込まれると、そのメタデータに含まれるアーキテクチャ名(例:"llama")をレジストリで検索し、対応するコンストラクタが取得されます。
ヒント: 新しいモデルアーキテクチャを追加するには、
model/models/以下に新しいパッケージを作成し、Modelインターフェースを実装して、init()内でmodel.Register()を呼び出し、models.goにブランクインポートを追加するだけです。他のコードを変更する必要はありません。
リフレクションを使ったモデルの初期化
コンストラクタでモデルの構造体が作成されたあと、GGUFファイルからGoの構造体フィールドへテンソルをバインドする必要があります。これは model/model.go#L175-L200 の populateFields() によって自動的に行われます。
model/models/llama/model.go#L27-L37 のLlamaモデルの構造体を見てみましょう。
type Model struct {
model.Base
tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
}
gguf 構造体タグは、GGUFのテンソル名にマッピングされます。token_embd はGGUFファイル内の token_embd.weight テンソルに対応し、blk は blk.0.*、blk.1.* などに対応します(スライスのサイズはモデルの block_count から自動的に決まります)。Output フィールドの alt:token_embd は「output.weight が存在しない場合は token_embd.weight にフォールバックする」という意味で、入力と出力の埋め込みを共有するモデルでよく使われるパターンです。
model/model.go#L113-L135 の New() 関数がロード処理全体を取りまとめています。
func New(modelPath string, params ml.BackendParams) (Model, error) {
b, err := ml.NewBackend(modelPath, params)
m, err := modelForArch(b.Config())
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
}
populateFields() はリフレクションを使って構造体を再帰的に走査します。gguf タグが付いているフィールドを見つけると、タグの階層からテンソルのフルネームを組み立て、b.Get(name) でバックエンドから取得し、フィールドにセットします。SelfAttention のようなネストした構造体では、親のタグプレフィックスが引き継がれます。たとえば blk.0 と attn_q が組み合わさって blk.0.attn_q.weight になります。
詳細解説:GoによるLlamaのフォワードパス
標準的なTransformerの実装は model/models/llama/model.go#L111-L201 にあります。フォワードパス全体を順に追ってみましょう。
sequenceDiagram
participant F as Forward()
participant E as TokenEmbedding
participant L as Layer[i]
participant SA as SelfAttention
participant MLP as MLP
participant N as OutputNorm
participant O as Output
F->>E: Embed tokens → hiddenState [embed_dim, batch]
loop For each layer i
F->>L: layer.Forward(hiddenState, positions, cache)
L->>L: AttentionNorm (RMSNorm)
L->>SA: SelfAttention.Forward()
SA->>SA: Q = Query.Forward(hidden)
SA->>SA: K = Key.Forward(hidden)
SA->>SA: V = Value.Forward(hidden)
SA->>SA: Q, K = RoPE(Q, positions), RoPE(K, positions)
SA->>SA: attention = ScaledDotProduct(Q, K, V, cache)
SA->>SA: Output.Forward(attention)
L->>L: residual + attention
L->>L: MLPNorm (RMSNorm)
L->>MLP: MLP.Forward()
MLP->>MLP: Gate.Forward(hidden).SILU(Up.Forward(hidden))
MLP->>MLP: Down.Forward(...)
L->>L: residual + mlp_output
end
F->>N: OutputNorm.Forward(hiddenState)
F->>O: Output.Forward(normalized) → logits
model/models/llama/model.go#L183-L201 の Forward() メソッドです。
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
最終レイヤーでの outputs 最適化に注目してください。batch.Outputs はどのポジションのlogitsが必要かを示すテンソルです。最終レイヤーでは hiddenState.Rows(ctx, outputs) によって、MLP計算の前に必要なポジションだけにhidden stateを絞り込みます。最後のトークンのlogitsだけが必要なケースでは、これによって計算量を大幅に削減できます。
model/models/llama/model.go#L119-L138 のSelfAttentionは、Grouped Query Attention(GQA)を伴う標準的なマルチヘッドアテンションのパターンに従っています。numKVHeads は numHeads より少なくでき、nn.Attention() 関数がブロードキャストを処理します。
model/models/llama/model.go#L150-L153 のMLPは、SiLU活性化関数を使ったゲート付きフィードフォワードネットワークです。Goのコードはわずか1行で表現されます。
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
ml.Backend 抽象化レイヤー
すべてのテンソル演算は、ml/backend.go#L16-L120 で定義された Backend インターフェースと Context インターフェースを通じて行われます。
type Backend interface {
Close()
Load(ctx context.Context, progress func(float32)) error
Config() fs.Config
Get(name string) Tensor
NewContext() Context
BackendDevices() []DeviceInfo
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
FromFloats(s []float32, shape ...int) Tensor
Forward(...Tensor) Context
Compute(...Tensor)
Reserve()
Close()
}
バックエンドもモデルアーキテクチャと同じ登録パターンを採用しています。ml/backend.go#L78-L84 の RegisterBackend() がコンストラクタをマップに格納します。現時点では "ggml" バックエンドのみ実装されており、CGo経由でGGMLのCライブラリをラップしています。ただし、インターフェース自体は複数バックエンドへの拡張を想定した設計になっており、将来的にVulkanネイティブやMetalネイティブのバックエンドを追加しても、モデルのコードを一切変更する必要はありません。
Context の Compute() 呼び出しについても重要な点があります。これは演算をその場で即時実行するのではなく、演算をつなぎ合わせて計算グラフを構築し、Compute() が呼ばれた時点でグラフ全体を一括実行します。この仕組みにより、GGMLのグラフ最適化(演算のフュージョンやメモリプランニング)がフォワードパス全体にわたって有効に機能します。
KV Cache:因果的アテンションのメモリ
kvcache/causal.go には、自己回帰的な生成を効率化するKV Cacheが実装されています。新しいトークンを生成するたびに過去のすべてのトークンに対してアテンションを再計算する代わりに、過去のポジションのキーとバリューのテンソルをキャッシュしておきます。
type Causal struct {
DType ml.DType
swaWindowSize int32 // sliding window attention limit
swaMemorySize int32 // memory retention for partial prefix caching
cells []cacheCell // position and sequence metadata per cache slot
cellRanges map[int]cellRange
keys, values map[int]ml.Tensor // per-layer storage
shiftFn shiftFn // for RoPE cache shifting
}
flowchart TD
A["New tokens arrive"] --> B["Find contiguous cache slots"]
B --> C["Store K,V for new positions"]
C --> D["Build attention mask"]
D --> E{"Sliding window?"}
E -->|"yes"| F["Mask limits to swaWindowSize"]
E -->|"no"| G["Full causal mask"]
F --> H["Return K_history, V_history, mask"]
G --> H
H --> I["nn.Attention uses cached K,V"]
Sliding Window Attention(SWA)のサポートも特筆すべき点です。Mistralなどのモデルは、各レイヤーが直近のNトークンにしかアテンションを当てないSWAを採用しています。キャッシュはプレフィックスキャッシングのために swaMemorySize 分のトークンをメモリに保持しつつ、アテンションマスクには swaWindowSize 分のトークンのみを含める形で、この仕組みに対応しています。
キャッシュが満杯になった場合、モデルが提供する Shift 関数がRoPEの位置エンコーディングを調整し、コンテキスト全体の表現を捨てることなく、ウィンドウを前方へ「スライド」させます。
トークンサンプリングパイプライン
フォワードパスでlogitsが出力されたあと、sample/samplers.go のサンプリングパイプラインがトークンを選択します。
flowchart TD
A["Raw logits from model"] --> B{"temperature == 0?"}
B -->|"yes"| C["Greedy: argmax"]
B -->|"no"| D["Top-K: keep K highest logits"]
D --> E["Temperature: scale logits"]
E --> F["Softmax: convert to probabilities"]
F --> G["Top-P: keep cumulative prob ≤ p"]
G --> H["Min-P: keep tokens above threshold"]
H --> I["Weighted random sample"]
I --> J{"Grammar sampler?"}
J -->|"yes"| K["Validate against grammar"]
J -->|"no"| L["Return token"]
K --> L
文法制約付きサンプリングには最適化が施されています。まず上位のトークンが文法的に有効かどうかを確認し、有効であればコストの高い全語彙への文法適用をスキップします。上位トークンが棄却された場合にのみ、全トークンに文法制約を適用して再サンプリングします。
ヒント: 決定論的な出力が必要な場合は
temperature: 0を設定しましょう。確率的なサンプリングをすべてバイパスし、純粋なgreedy decodingが使われます。再現性のある非greedyな出力が必要な場合は、オプションで固定のseedを指定してください。
computeBatch ループ:推論の制御塔
runner/ollamarunner/runner.go#L700-L855 の computeBatch() メソッドは、推論全体を束ねるコアループです。処理の流れは次のとおりです。
- バッチ入力の収集 — アクティブなすべてのシーケンスからトークンを集め、単一のバッチにまとめる
- フォワードパスの実行 —
activeBatch.ctx.ComputeWithNotify()で計算グラフを実行する - logitsの取り出し —
activeBatch.modelOutput.Floats()で生のlogitsを取得する - アクティブな各シーケンスに対して:
- バッチ出力からそのシーケンスのlogitsスライスを取り出す
seq.sampler.Sample(logits)を呼び出してトークンを選択する- EOSトークンを検出したらシーケンスを除去する
- トークンをテキストにデコードして保留中のレスポンスに追加する
- ストップシーケンスを検出したらテキストを切り詰めてシーケンスを除去する
- Unicodeが不完全な場合はバッファリングする
- 完成したテキストをレスポンスチャネルにフラッシュする
runner/ollamarunner/runner.go#L51-L116 の Sequence 構造体は、リクエストごとの状態を管理します。保留中の入力、生成されたレスポンス、サンプラー、ストップシーケンス、タイミングメトリクスなどがここに含まれます。複数のシーケンスが同一バッチ内でモデルのフォワードパス計算を共有しながら並列実行されます。これが OLLAMA_NUM_PARALLEL の動作原理です。
runner/ollamarunner/runner.go#L222-L298 のマルチモーダル入力処理では、プロンプトを [img-N] タグで分割し、テキスト部分をトークナイズします。その後、モデルの EncodeMultimodal() で画像をエンコードし、入力ストリームへ交互に組み込みます。
次回予告
今回は生の入力から出力トークンが生成されるまで、Goネイティブエンジンの全工程を追いました。第5回では、モデルの保存の仕組みに目を向けます。OCIにインスパイアされたblobシステム、Modelfileのパース、GGUFの変換、そしてモデルを配布するpull/pushレジストリプロトコルを取り上げます。