Input: The token's embedding vector, a dense numerical representation encoding its meaning and context. The same values are fed to both the gating network (to decide routing) and the active experts (to compute on).
Token embedding
def sort(arr):
768-dim vector (simplified)
W_g · x
Gating network: A single weight matrix W_g multiplied by the embedding produces one score (logit) per expert. TopK masks out all but the top k scores; softmax converts the survivors into probabilities that sum to 1. The gate is a single matrix-vector product, so its compute cost is small next to the experts' feed-forward layers.
Gate
Linear → TopK → Softmax
H = W_g · x + noise
G = softmax(topK(H, 2))
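
The two gate formulas above can be sketched in plain Python. This is a minimal illustration, not the page's own implementation: the function name `gate`, the `noise_std` parameter, the Gaussian form of the noise, and the toy 2-dimensional embedding with 4 experts are all assumptions made for the example.

```python
import math
import random

def gate(x, W_g, k=2, noise_std=0.0, rng=None):
    """Routing weights G = softmax(topK(W_g · x + noise, k)).

    Hypothetical helper: W_g is a list of rows, one row per expert.
    """
    rng = rng or random.Random(0)
    logits = []
    for row in W_g:
        # One logit per expert: the expert's row of W_g dotted with x.
        h = sum(w * xi for w, xi in zip(row, x))
        # Assumed Gaussian noise; skipped entirely when noise_std is 0.
        if noise_std > 0:
            h += rng.gauss(0.0, noise_std)
        logits.append(h)
    # Indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the survivors only; every other expert gets weight 0.
    m = max(logits[i] for i in top)
    exps = {i: math.exp(logits[i] - m) for i in top}
    total = sum(exps.values())
    return [exps[i] / total if i in exps else 0.0 for i in range(len(logits))]

# Toy example: 2-dim embedding, 4 experts, no noise.
x = [1.0, 2.0]
W_g = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.0]]
G = gate(x, W_g, k=2)
print([round(g, 4) for g in G])
```

With these toy values the logits are 1.0, 2.0, 1.5 and -1.0, so only the second and third experts receive a non-zero weight.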
weights
Experts: Only the top-k experts receive the token and run their feed-forward computation. The rest stay completely dormant for this token: no matrix multiplications, no memory bandwidth used. Active experts run in parallel, and their outputs are blended by the routing weights.
Routing probabilities (raw logits vs. softmax weights): Raw logits are the gate scores before normalization; a higher logit means a stronger routing preference. After TopK masking, softmax converts the surviving logits to probabilities. Non-selected experts drop to exactly 0% because they are excluded before softmax, not merely down-weighted.
Expert | Weight (softmax) | Logit | Prob
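
To make the "excluded, not down-weighted" point concrete, the sketch below compares a softmax over all logits with a softmax applied after TopK masking. The logit values and the helper names `softmax` and `mask_to_top_k` are hypothetical; masking with negative infinity is one common way to exclude an entry before softmax and is assumed here.

```python
import math

def softmax(values):
    """Numerically stable softmax; exp(-inf) evaluates to exactly 0.0."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def mask_to_top_k(logits, k):
    """Keep the k largest logits and replace the rest with -inf."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order[:k])
    return [v if i in keep else -math.inf for i, v in enumerate(logits)]

logits = [2.0, 1.0, 0.5, -1.0]               # hypothetical raw gate scores
dense = softmax(logits)                      # every expert gets some weight
sparse = softmax(mask_to_top_k(logits, 2))   # only the top 2 survive

for i, (h, d, s) in enumerate(zip(logits, dense, sparse)):
    print(f"expert {i}: logit={h:+.1f}  dense={d:.4f}  top-2={s:.4f}")
```

In the dense column the weakest expert still has a small positive probability, while in the top-2 column it is exactly zero.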
Output formula: The final token representation is a weighted sum. Each active expert's output vector is multiplied by its routing probability, and the products are summed. The result is a single vector with the same shape as the input, passed to the next transformer layer.
Output = w_1 · E_1(x) + w_2 · E_2(x)
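
The weighted sum can be sketched as follows. The experts here are toy stand-ins that only scale their input (real experts are feed-forward networks), and `moe_output`, `make_expert`, and the routing weights are hypothetical names and values chosen for illustration. The loop runs the active experts one after another for simplicity, whereas the description above has them running in parallel.

```python
def moe_output(x, experts, weights):
    """Blend expert outputs by routing weight, skipping dormant experts."""
    out = [0.0] * len(x)
    for expert, w in zip(experts, weights):
        if w == 0.0:
            continue  # dormant expert: its function is never called
        y = expert(x)
        out = [o + w * yi for o, yi in zip(out, y)]
    return out

calls = []  # records which experts actually ran

def make_expert(name, scale):
    """Toy expert: multiplies every component of x by a fixed scale."""
    def expert(x):
        calls.append(name)
        return [scale * xi for xi in x]
    return expert

experts = [make_expert(f"E{i}", s) for i, s in enumerate([1.0, 2.0, 3.0, 4.0])]
weights = [0.0, 0.75, 0.25, 0.0]   # hypothetical top-2 routing weights
x = [1.0, -2.0, 0.5]

y = moe_output(x, experts, weights)
print(y)       # same length as x
print(calls)   # only the two active experts appear
```

Skipping zero-weight experts before calling them is what keeps the unselected experts free of any computation for this token.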
Active experts: