From 878ee6ab7efccdcdf4d9c0cf7d848abc97993499 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 5 May 2026 14:13:24 +0000 Subject: [PATCH 1/8] docs(notes): Add TPU deep dive skeleton Set up the chapter structure for notes/tpu-deep-dive.cn.md based on the agreed 22-topic outline, organized by abstraction layer: hardware (Part I), XLA compiler/runtime (Part II), inference adaptation (Part III), cluster orchestration (Part IV), and system-level GPU comparison (Part V), plus three appendices for trade-off lookup, parameter list, and GPU-equivalence glossary. Each chapter has a single-line topic statement and TODO bullets that map directly to points in the source conversation; no prose has been written yet, so this commit contains structure only. Co-authored-by: Claude --- notes/tpu-deep-dive.cn.md | 259 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 259 insertions(+) create mode 100644 notes/tpu-deep-dive.cn.md diff --git a/notes/tpu-deep-dive.cn.md b/notes/tpu-deep-dive.cn.md new file mode 100644 index 0000000..a12bff0 --- /dev/null +++ b/notes/tpu-deep-dive.cn.md @@ -0,0 +1,259 @@ +# TPU 深入笔记:从单芯片到生产集群 + +> **状态**:骨架阶段,章节结构已确认,正文内容待源对话可读后填入。 +> **目的**:给自己留一份从硬件原理一路到推理 / 集群层适配的端到端参考,重点在**为什么这么设计**而非术语解释。 +> **对照对象**:相同问题在 GPU 体系下是怎么做的——每章末尾会有 `↔ GPU` 小节。 + +--- + +## Part I — 硬件层:芯片、互联、封装 + +### 1. 单芯片:MXU、VPU、SPU 与脉动阵列 + +一句话:TPU 把矩阵乘法做成专用电路(MXU),用脉动阵列在芯片内部「传递数据」而不是「传递地址」,省掉了缓存层级。 + +- TODO:MXU / VPU / SPU 三者的角色分工 +- TODO:CISC 指令集的取舍 +- TODO:脉动阵列在矩阵乘法里逐 cycle 的数据流转图(用 ASCII 或 Mermaid) +- TODO:↔ GPU(SM / Tensor Core / 寄存器堆 的对比) + +### 2. 多芯片互联:ICI 与 3D Torus + +一句话:单芯片很强,但真正让 TPU 成为 TPU 的是片间网络——ICI 是物理层、3D Torus 是逻辑拓扑。 + +- TODO:ICI 的位置(链路层、带宽、延迟数量级) +- TODO:3D Torus 的几何(每芯片 6 个邻居、X/Y/Z 三轴) +- TODO:为什么选 Torus 而不是 fat-tree +- TODO:↔ GPU(NVLink / NVSwitch / IB 的位置) + +### 3. OCS:光路交换让拓扑可重配 + +一句话:TPU Pod 在物理层是 star(每机架到 OCS 的若干光纤),逻辑层却被 OCS 接成 torus;MEMS 小镜子在微秒级切换光路。 + +- TODO:MEMS 镜子的工作原理(机械翻转 vs 半导体) +- TODO:4×4×4 机架为什么有 96 根外出光纤(角×3 + 棱×2 + 面心×1 的几何约束) +- TODO:OCS 切分粒度(机架 / Block 而非单芯片)的工程原因 +- TODO:↔ GPU(packet-switched IB 的世界对应不到这个) + +### 4. 集合通信:Ring All-Reduce 在 3D Torus 上的降维 + +一句话:3D Torus 上做 All-Reduce 的诀窍是把它拆成 X/Y/Z 三个 1D 圈分别做,再合起来——这是几何换算法。 + +- TODO:1D 圈的 All-Reduce 步骤 +- TODO:3D 降维的合成顺序与时间复杂度 +- TODO:「圈是 1D 但连线是 3D,剩下的线不就闲置了?」这个问题的回答 +- TODO:NUCA:跨机架大圈中,铜缆 vs 光纤路径的延迟异构 +- TODO:XLA 怎么把 TP(密集通信)映射到短边铜缆、把 DP(稀疏同步)映射到长光环 +- TODO:↔ GPU(NCCL ring/tree、NVLink 全互联下的差异) + +### 5. Host ↔ TPU:PCIe、NUMA 与 multi-host slice + +一句话:TPU 不是独立机器,是挂在 CPU host 上的 PCIe 设备;slice 一旦跨 host,就是天然的分布式系统。 + +- TODO:1:4 / 1:8 的 CPU:TPU 比例从哪来 +- TODO:multi-host slice 的 N:N 映射、SPMD 进程是怎么起的 +- TODO:NUMA 1:4 PCIe 劈管、单 NUMA VM 的切分策略、XLA 自动绑核 +- TODO:LWS / JobSet 在 K8s 层怎么把这套 N:N 表达出来 +- TODO:↔ GPU(HGX 8 卡 + NUMA、GDS 直连存储) + +### 6. 先进封装:算力面积 vs 带宽周长 + +一句话:Die 上面积决定算力(FLOPs),周长决定带宽(HBM 接口);2.5D / 3D 封装是在调和两者的根本张力。 + +- TODO:硅中介层(CoWoS 类)的角色 +- TODO:TSV(Through-Silicon Via)解决了什么 +- TODO:HBM 带宽 vs 算力增长的不对称(这条会在 Ch 19 再呼应) +- TODO:↔ GPU(H100 / B100 的封装方案) + +--- + +## Part II — 编译与运行时:XLA + +### 7. XLA 编译模型:静态调度做减法 + +一句话:XLA 的核心打法是「把所有不确定性在编译期消掉」——算子融合、静态 padding、软件流水线、VLIW 指令包都是这个思路的不同切面。 + +- TODO:HLO → LLO 的层次 +- TODO:算子融合的边界(什么能融、什么不能融) +- TODO:静态 padding 的代价与收益 +- TODO:软件流水线在 systolic array 上的体现 +- TODO:VLIW 指令包的五槽:DMA / MXU / VPU / SPU / **ICI**——为什么 ICI 也是一槽 +- TODO:↔ GPU(动态调度器、SM warp scheduler、CUDA Graph 的对比位置) + +### 8. 
编译时机:JIT、AOT、bucketing、persistent cache + +一句话:静态编译的代价是首跑慢,工业上靠 bucketing + AOT + cache 把这个代价摊薄。 + +- TODO:JIT 触发条件与冷启动延迟 +- TODO:bucketing 的 shape 桶设计 +- TODO:AOT 预编译的部署链路 +- TODO:persistent cache 的实际形态 +- TODO:↔ GPU(PyTorch 动态图 / TorchInductor / Triton AOT 的对照) + +### 9. XLA 拓扑感知映射 + +一句话:编译器知道 3D Torus + OCS 的物理拓扑,所以能把高密度通信映射到短边、低频同步映射到长环。 + +- TODO:拓扑信息怎么传给 XLA +- TODO:TP(短边铜缆)、DP(长环光纤)的自动决策 +- TODO:与 Ch 4 NUCA 部分的呼应 +- TODO:↔ GPU(手工 NCCL group、torch.distributed 的拓扑感知能力) + +--- + +## Part III — 推理层适配(目标 C) + +### 10. 软件栈分叉:vLLM、JetStream、Saxml、GKE + +一句话:TPU 上推理框架不止一个,三家定位不同;GKE 是把它们都装进集群的胶水。 + +- TODO:vLLM-TPU 的位置与裁剪 +- TODO:JetStream 的角色(Google 内外用法的差异) +- TODO:Saxml 的服务模型 +- TODO:三者在 GKE 上的部署形态对比 +- TODO:↔ GPU(vLLM / TGI / TRT-LLM 的对应) + +### 11. PagedAttention 与连续批处理在 TPU 上的适配 + +一句话:GPU 上的动态内存管理(PagedAttention、Continuous Batching、Radix Tree)天生不适合静态编译;TPU 上只能靠 Pallas 写自定义 kernel + 把动态切到张量层。 + +- TODO:PagedAttention 的核心难点:动态 indirection +- TODO:Pallas kernel 在哪一层做适配 +- TODO:连续批处理的实现取舍 +- TODO:Radix Tree(前缀缓存)的适配现状 +- TODO:「用极其廉价的 FLOPs 去消除极其昂贵的 Control Flow」这个原则的展开 +- TODO:↔ GPU(vLLM 原生的 paged 实现) + +### 12. Prefill / Decode 协同与 Chunked Prefill + +一句话:TPU 在 Prefill 强、Decode 弱(HBM 带宽瓶颈),混合执行 + chunked prefill 是用算法补硬件。 + +- TODO:Prefill / Decode 在硬件资源上的差异 +- TODO:Chunked Prefill 的切片策略 +- TODO:静态 1D 展平怎么把变长 batch 套进静态 shape +- TODO:↔ GPU(vLLM 的 continuous batching、SARATHI) + +### 13. KV Cache 与内存层次 + +一句话:GPU 体系里的 RDMA / GDS / KV offload 在 TPU 上有的天生支持、有的不支持、有的只能走 PCIe 后备。 + +- TODO:ICI 天生 bypass host(相当于免费 RDMA)的语义 +- TODO:GDS 类直连存储在 TPU 上为什么没对应物 +- TODO:KV cache offload 的实际路径(HBM → PCIe → host DRAM / SSD) +- TODO:↔ GPU(NCCL over IB、GDS、Mooncake 类 KV pool) + +### 14. Gemini 在 TPU 上的实战妥协 + +一句话:MoE 和投机解码这两个推理优化,在 TPU 上都得改算法去迁就硬件。 + +- TODO:MoE 的 All-to-All 与 3D Torus 的张力(这条会在 Ch 19 再呼应硬件原因) +- TODO:Capacity Factor 的作用(控制每专家上限以变成静态 shape) +- TODO:投机解码 Tree Attention 的 mask 设计 +- TODO:为什么这些妥协在 GPU 上不是必须 + +--- + +## Part IV — 集群层适配(目标 D) + +### 15. K8s 上的 TPU 抽象 + +一句话:K8s 看不见光,所以 OCS 切分必须由独立组件负责,TPU device plugin + 拓扑标签 + Kueue + TPU Provisioner 组成完整链路。 + +- TODO:device plugin 暴露给 kubelet 的资源粒度 +- TODO:拓扑标签(node label)携带的 3D 坐标信息 +- TODO:Kueue gang scheduling 为什么必要(slice 必须整组上) +- TODO:TPU Provisioner 调用 OCS API 的时机 +- TODO:↔ GPU(NVIDIA device plugin、Volcano gang、Topology Aware Scheduling) + +### 16. Multi-host slice 的编排 + +一句话:一个 slice 跨多 host 时,K8s 看到的是 N 个 pod 的协同启动,每个 pod 内的 TPU 又是 4 / 8 个芯片的本地组——这是两层 N:N。 + +- TODO:LWS(LeaderWorkerSet)的语义 +- TODO:JobSet 的角色与 LWS 的关系 +- TODO:调度耦合点:哪一层 fail 会拖整个 slice +- TODO:↔ GPU(MPI Operator、Training Operator、Ray on K8s 的对应) + +--- + +## Part V — 系统对比与权衡(目标 B 集中点) + +### 17. 编程模型链:从单卡到多机的指令链 + +一句话:GPU 是「单卡 CUDA → 多卡 NCCL → 多机 IB/RDMA」三段;TPU 是「SPMD → ICI(VLIW 第五槽)」一段,编译器统管。 + +- TODO:CUDA → NCCL → IB/RDMA 的语义跳转点 +- TODO:SPMD + ICI 的「无缝」是怎么做到的 +- TODO:两套模型对故障半径、调度灵活度的影响 + +### 18. 成本 / 能效 + +一句话:MFU 和 Tokens/$ 这两个指标是衡量真实账面差异的杠杆,不是芯片峰值算力。 + +- TODO:MFU(Model FLOPs Utilization)的定义与典型水位 +- TODO:Tokens per dollar 的计算口径 +- TODO:Midjourney 案例($2.1M → <$700K,对话里给出的数字) +- TODO:Character.AI 案例 +- TODO:「峰值 TFLOPs 不等于实际产出」这个 trade-off 集中在这 + +### 19. TPU 的硬件劣势与权衡 + +一句话:每个静态调度的优势都对应一个不擅长的工作负载——MXU 粒度大、SPU 弱、3D Torus All-to-All 拥塞、HBM 带宽 vs 算力失衡。 + +- TODO:MXU 128×128 粒度对小矩阵的浪费 +- TODO:SPU 弱在哪些场景下成为瓶颈 +- TODO:MoE All-to-All 与 3D Torus 的根本不匹配 +- TODO:HBM 带宽 vs 峰值算力增长的剪刀差(呼应 Ch 6、Ch 12) +- TODO:每条劣势的 trade-off:换来了什么 + +--- + +## 附录 + +### A. 
Trade-off 速查表 + +按设计决策维度横切,每条 trade-off 链接回原章节。计划维度: + +- 静态 vs 动态(编译 vs 调度) +- 密度 vs 灵活(MXU 大 vs 小、Torus vs fat-tree) +- 算力 vs 带宽(封装 / HBM) +- 集中 vs 分布(编译器 vs 运行时调度器) + +### B. 数字 / 参数清单 + +所有数字都注明「源自原对话」。计划条目: + +- TPU v4: 4096 芯片 +- TPU v5p: 8960 芯片 +- MXU: 128×128 +- HBM 带宽数量级:TODO +- Midjourney: $2.1M → <$700K +- Character.AI: TODO +- CPU:TPU 比例:1:4 或 1:8 +- 4×4×4 机架: 96 根光纤 + +### C. 术语 ↔ GPU 等价物对照 + +纯查询用。计划条目: + +- ICI ↔ NVLink + IB(合并对应) +- SPMD ↔ NCCL collective +- Pallas ↔ Triton +- OCS ↔ (无对应) +- XLA ↔ TorchInductor / TensorRT +- MXU ↔ Tensor Core +- VPU ↔ CUDA Core +- HLO ↔ FX Graph +- TPU Provisioner ↔ (无对应,最接近的是 Slurm topology + 手工 NCCL) + +--- + +## 写作日志(让作者验收用) + +### 主动取舍(待原文读取后填) + +- TODO + +### 外部补充(Claude 加,原文未提及) + +- TODO,每条会用 `> **[补充 — Claude 加]** ...` 在正文中标注 From 76e2773d1634eb09d39c37f3c15093f2a62e99bc Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 5 May 2026 14:20:32 +0000 Subject: [PATCH 2/8] docs(notes): Fill Part I (hardware layer) of TPU deep dive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace skeleton TODOs with full prose for chapters 1-6: - Single chip: MXU/VPU/SPU systolic array, 2x2 walk-through, VLIW packs - Inter-chip: ICI and 3D Torus topology - OCS: MEMS mirror physics, 96-fiber geometry, rack-level slicing - Collective communication: Ring All-Reduce, dim-partitioned 3D, NUCA - Host integration: PCIe, multi-host slicing (LWS/JobSet), NUMA - Advanced packaging: 2.5D/3D, silicon interposer, TSV Each chapter ends with a GPU-comparison subsection. Three "[补充 — Claude 加]" callouts mark facts added beyond the source for user review. Parts II-V and appendices to follow in subsequent commits on this branch. Co-authored-by: Claude --- notes/tpu-deep-dive.cn.md | 504 +++++++++++++++++++++++++------------- 1 file changed, 339 insertions(+), 165 deletions(-) diff --git a/notes/tpu-deep-dive.cn.md b/notes/tpu-deep-dive.cn.md index a12bff0..4fb52c0 100644 --- a/notes/tpu-deep-dive.cn.md +++ b/notes/tpu-deep-dive.cn.md @@ -1,8 +1,8 @@ # TPU 深入笔记:从单芯片到生产集群 -> **状态**:骨架阶段,章节结构已确认,正文内容待源对话可读后填入。 -> **目的**:给自己留一份从硬件原理一路到推理 / 集群层适配的端到端参考,重点在**为什么这么设计**而非术语解释。 -> **对照对象**:相同问题在 GPU 体系下是怎么做的——每章末尾会有 `↔ GPU` 小节。 +> **目的**:给自己留一份从硬件原理一路到推理 / 集群层适配的端到端参考,重点在**为什么这么设计**而非术语解释。 +> **对照对象**:每章末尾会有一个 `↔ GPU` 小节做就近对比。 +> **关于"补充"**:正文中所有 `> **[补充 — Claude 加]** ...` 标记的内容都是原对话里没有、我补的,请你过一眼决定要不要保留。 --- @@ -10,250 +10,424 @@ ### 1. 单芯片:MXU、VPU、SPU 与脉动阵列 -一句话:TPU 把矩阵乘法做成专用电路(MXU),用脉动阵列在芯片内部「传递数据」而不是「传递地址」,省掉了缓存层级。 +**一句话**:TPU 把矩阵乘法做成专用电路(MXU),用脉动阵列在芯片内部「传递数据」而不是「传递地址」,省掉了缓存层级。 -- TODO:MXU / VPU / SPU 三者的角色分工 -- TODO:CISC 指令集的取舍 -- TODO:脉动阵列在矩阵乘法里逐 cycle 的数据流转图(用 ASCII 或 Mermaid) -- TODO:↔ GPU(SM / Tensor Core / 寄存器堆 的对比) +#### 1.1 脉动阵列:绕开冯·诺依曼瓶颈 -### 2. 多芯片互联:ICI 与 3D Torus +通用架构里每次 ALU 算加法都要去寄存器或 L1 取数、写回。LLM 推理动辄要几百 GB/s 的矩阵乘吞吐,这种「算一下查一次内存」的模式直接被带宽锁死。 -一句话:单芯片很强,但真正让 TPU 成为 TPU 的是片间网络——ICI 是物理层、3D Torus 是逻辑拓扑。 +TPU 的破局思路是把计算单元排成二维网格(v4 的 MXU 是 128×128 个 MAC),让数据像心脏泵血一样流过: -- TODO:ICI 的位置(链路层、带宽、延迟数量级) -- TODO:3D Torus 的几何(每芯片 6 个邻居、X/Y/Z 三轴) -- TODO:为什么选 Torus 而不是 fat-tree -- TODO:↔ GPU(NVLink / NVSwitch / IB 的位置) +- **权重驻留**:算之前先把权重锁进每个 MAC 单元里 +- **数据脉动**:激活值从左侧、上方流入,逐 cycle 向右、向下推进 +- **邻居硬连线**:每个单元算完乘加,结果**直接通过物理走线**传给下一个邻居,不写寄存器 -### 3. 
OCS:光路交换让拓扑可重配 +数据穿过整个阵列的过程中被复用了几百到上千次,期间一次都不碰 SRAM。这就是为什么同样的硅面积,TPU 能塞下比 GPU 多得多的纯算力单元。 -一句话:TPU Pod 在物理层是 star(每机架到 OCS 的若干光纤),逻辑层却被 OCS 接成 torus;MEMS 小镜子在微秒级切换光路。 +#### 1.2 一个 2×2 的具体例子 -- TODO:MEMS 镜子的工作原理(机械翻转 vs 半导体) -- TODO:4×4×4 机架为什么有 96 根外出光纤(角×3 + 棱×2 + 面心×1 的几何约束) -- TODO:OCS 切分粒度(机架 / Block 而非单芯片)的工程原因 -- TODO:↔ GPU(packet-switched IB 的世界对应不到这个) +设 $Y = X \times W$,激活和权重都是 2×2,MXU 是 2×2 网格,**权重驻留**模式: -### 4. 集合通信:Ring All-Reduce 在 3D Torus 上的降维 +``` +权重位置: PE(1,1)=W11 PE(1,2)=W12 + PE(2,1)=W21 PE(2,2)=W22 -一句话:3D Torus 上做 All-Reduce 的诀窍是把它拆成 X/Y/Z 三个 1D 圈分别做,再合起来——这是几何换算法。 +激活流入: X 的行从左侧进入,第二行比第一行晚一个 cycle("倾斜"进入) +部分和方向: 从上向下传递 +``` -- TODO:1D 圈的 All-Reduce 步骤 -- TODO:3D 降维的合成顺序与时间复杂度 -- TODO:「圈是 1D 但连线是 3D,剩下的线不就闲置了?」这个问题的回答 -- TODO:NUCA:跨机架大圈中,铜缆 vs 光纤路径的延迟异构 -- TODO:XLA 怎么把 TP(密集通信)映射到短边铜缆、把 DP(稀疏同步)映射到长光环 -- TODO:↔ GPU(NCCL ring/tree、NVLink 全互联下的差异) +逐 cycle 看: -### 5. Host ↔ TPU:PCIe、NUMA 与 multi-host slice +- **Cycle 1**:$X_{11}$ 进入 PE(1,1),算出 $X_{11} \times W_{11}$,向下传部分和,向右传 $X_{11}$。 +- **Cycle 2**:$X_{11}$ 流到 PE(1,2);$X_{21}$ 进入 PE(2,1),把上方传来的 $X_{11} \times W_{11}$ 加上自己的 $X_{21} \times W_{21}$ —— 这一刻,PE(2,1) 向下输出的就是 $Y_{11}$ 的最终值。 +- **Cycle 3**:右下角 PE(2,2) 收齐两路输入,输出 $Y_{12}$。 +- **Cycle 4+**:剩余结果从阵列底部依次「滴落」,整个过程没有任何寄存器查表。 -一句话:TPU 不是独立机器,是挂在 CPU host 上的 PCIe 设备;slice 一旦跨 host,就是天然的分布式系统。 +#### 1.3 宏观架构:MXU 之外还需要谁 -- TODO:1:4 / 1:8 的 CPU:TPU 比例从哪来 -- TODO:multi-host slice 的 N:N 映射、SPMD 进程是怎么起的 -- TODO:NUMA 1:4 PCIe 劈管、单 NUMA VM 的切分策略、XLA 自动绑核 -- TODO:LWS / JobSet 在 K8s 层怎么把这套 N:N 表达出来 -- TODO:↔ GPU(HGX 8 卡 + NUMA、GDS 直连存储) +只有 MXU 跑不了完整网络。一个 TPU 核心里: -### 6. 先进封装:算力面积 vs 带宽周长 +| 组件 | 职责 | +|---|---| +| **MXU** | 稠密矩阵乘加,算力主力 | +| **VPU** | 向量运算:激活函数(GeLU/ReLU/Softmax)、LayerNorm 等没法压成矩阵乘的部分 | +| **SPU** | 标量与控制流:循环、分支、地址计算。算力很弱,但负责"指挥" | +| **Unified Buffer** | 几十 MB 的片上 SRAM,用来暂存从 HBM 搬进来的激活值和中间结果 | + +指令集走 **CISC** 风格:一条 `MatrixMultiply` 就能触发 MXU 做几千次 MAC,把指令解码和调度的开销压到极低。 + +#### 1.4 SPU 与 VLIW:硬件不动,控制流谁管 + +MXU 自己不知道该算什么、数据从哪来。这套调度由 SPU + **VLIW(超长指令字)** 驱动: + +XLA 编译器把多个无冲突的操作打包成一条长指令。一条指令同时包含四件事: + +``` +[ DMA ] | [ MXU ] | [ VPU ] | [ SPU 控制流 ] +``` + +举例:`[DMA 把下一块激活搬进 UB] | [MXU 算当前块] | [VPU 对上一块结果做 ReLU] | [SPU 更新地址]` + +SPU 的工作模式: -一句话:Die 上面积决定算力(FLOPs),周长决定带宽(HBM 接口);2.5D / 3D 封装是在调和两者的根本张力。 +- **不等结果就发指令**:把控制信号塞进 MXU 的 FIFO 队列,只要队列没满就一直往前跑 +- **同步靠栅栏**:需要等 MXU 算完才能继续时,插一条 `WAIT_MXU_DONE`,硬件级阻塞 +- **不参与计算**:SPU 算力极弱,但因为只做循环、地址、栅栏,开销极低 -- TODO:硅中介层(CoWoS 类)的角色 -- TODO:TSV(Through-Silicon Via)解决了什么 -- TODO:HBM 带宽 vs 算力增长的不对称(这条会在 Ch 19 再呼应) -- TODO:↔ GPU(H100 / B100 的封装方案) +整个时序在程序运行前就 100% 确定。 + +#### 1.5 ↔ GPU + +| 维度 | TPU | GPU | +|---|---|---| +| 矩阵单元 | MXU 128×128,单个巨型脉动阵列 | Tensor Core 16×8×16,散布在每个 SM 里 | +| 标量/向量单元 | SPU 弱、VPU 中等 | CUDA Core 多、强分支预测 | +| 数据缓冲 | Unified Buffer(软件显式 DMA 管理) | L1/L2 Cache(硬件自动预取) | +| 指令模型 | VLIW,编译期排好 | SIMT,运行时 Warp Scheduler 切换 | +| 隐藏访存延迟 | 静态流水线 | 海量并发线程切换 | + +**Trade-off**:TPU 把灵活性交给编译器,硬件做傻瓜执行;GPU 把灵活性留在硬件,软件相对简单但硅面积花在调度上。 --- -## Part II — 编译与运行时:XLA +### 2. 多芯片互联:ICI 与 3D Torus + +**一句话**:单芯片很强,但真正让 TPU 成为 TPU 的是片间网络——ICI 是物理层,3D Torus 是逻辑拓扑。 + +#### 2.1 ICI:网络做进硅片边缘 -### 7. 
XLA 编译模型:静态调度做减法 +每颗 TPU 的硅片边缘集成了一组高速接口叫 **ICI(Inter-Core Interconnect)**,特性: -一句话:XLA 的核心打法是「把所有不确定性在编译期消掉」——算子融合、静态 padding、软件流水线、VLIW 指令包都是这个思路的不同切面。 +- **没有外部交换机**:Pod 内传数据不经过任何传统网络设备,从一颗芯片的光模块直接跳到另一颗的光模块 +- **延迟极低且确定**:避开了交换机缓冲、路由表、拥塞控制等软件开销,XLA 能精确算出数据从 Pod 一端到另一端的纳秒数 -- TODO:HLO → LLO 的层次 -- TODO:算子融合的边界(什么能融、什么不能融) -- TODO:静态 padding 的代价与收益 -- TODO:软件流水线在 systolic array 上的体现 -- TODO:VLIW 指令包的五槽:DMA / MXU / VPU / SPU / **ICI**——为什么 ICI 也是一槽 -- TODO:↔ GPU(动态调度器、SM warp scheduler、CUDA Graph 的对比位置) +> **[补充 — Claude 加]** 原对话没明说 ICI 链路的具体带宽数字,公开资料显示 v4 单链路是 4.5 TB/s 量级(每芯片所有 6 个方向加起来),不确定要不要写进笔记,请你定。 -### 8. 编译时机:JIT、AOT、bucketing、persistent cache +#### 2.2 3D Torus:六邻居 + 首尾相连 -一句话:静态编译的代价是首跑慢,工业上靠 bucketing + AOT + cache 把这个代价摊薄。 +每颗芯片在 (X, Y, Z) 三维空间里有一个坐标,通过 ICI 连到 ±X, ±Y, ±Z 六个方向的邻居。**关键在于环回**:走到某个维度的边缘时,连线会绕回该维度的起点形成一个环,所以 Pod 里没有"边缘节点"。 -- TODO:JIT 触发条件与冷启动延迟 -- TODO:bucketing 的 shape 桶设计 -- TODO:AOT 预编译的部署链路 -- TODO:persistent cache 的实际形态 -- TODO:↔ GPU(PyTorch 动态图 / TorchInductor / Triton AOT 的对照) +为什么选 Torus 而不是 fat-tree: -### 9. XLA 拓扑感知映射 +- **冗余路径多**:一条线断了可以绕另一维度过去 +- **集合通信友好**:All-Reduce / All-Gather 这种环形通信在物理上就是直跑环,不需要路由计算 -一句话:编译器知道 3D Torus + OCS 的物理拓扑,所以能把高密度通信映射到短边、低频同步映射到长环。 +代价: -- TODO:拓扑信息怎么传给 XLA -- TODO:TP(短边铜缆)、DP(长环光纤)的自动决策 -- TODO:与 Ch 4 NUCA 部分的呼应 -- TODO:↔ GPU(手工 NCCL group、torch.distributed 的拓扑感知能力) +- **网络直径线性**:随着规模 $N$ 增长,最长跳数是 $O(N^{1/3})$,而 fat-tree 是 $O(\log N)$ +- **不规则点对点会拥塞**:MoE 的 All-to-All 这种乱序通信在 Torus 上会出现链路热点,第 19 章会展开 + +#### 2.3 ↔ GPU + +| 维度 | TPU | GPU | +|---|---|---| +| 节点内 | ICI(直集成在芯片边) | NVLink / NVSwitch(独立交换芯片) | +| 节点间 | ICI + OCS(同一套,纯光) | InfiniBand + NIC + 交换机 | +| 拓扑 | 3D Torus | Fat-tree / 多层 Spine-Leaf | +| 延迟特性 | 静态可预测 | 动态有抖动 | + +**Trade-off**:Torus 在规整集合通信上是降维打击,但面对动态稀疏通信(MoE)就吃亏;fat-tree 反过来。 --- -## Part III — 推理层适配(目标 C) +### 3. OCS:光路交换让拓扑可重配 -### 10. 软件栈分叉:vLLM、JetStream、Saxml、GKE +**一句话**:TPU Pod 在物理层是 star(每机架到 OCS 的若干光纤),逻辑层却被 OCS 接成 torus;MEMS 小镜子在微秒级切换光路。 -一句话:TPU 上推理框架不止一个,三家定位不同;GKE 是把它们都装进集群的胶水。 +#### 3.1 MEMS 镜子:纯反射、不读包 -- TODO:vLLM-TPU 的位置与裁剪 -- TODO:JetStream 的角色(Google 内外用法的差异) -- TODO:Saxml 的服务模型 -- TODO:三者在 GKE 上的部署形态对比 -- TODO:↔ GPU(vLLM / TGI / TRT-LLM 的对应) +普通以太网/InfiniBand 交换机走的是 **O-E-O**(光-电-光):光纤进来 → 转电信号 → 芯片读包头 → 查路由表 → 缓存排队 → 转回光信号出去。这套流程带来延迟、抖动,且交换芯片自己耗电极大。 -### 11. PagedAttention 与连续批处理在 TPU 上的适配 +OCS(Google 内部代号 Palomar)从 v4 开始引入,走纯 **O-O-O**: -一句话:GPU 上的动态内存管理(PagedAttention、Continuous Batching、Radix Tree)天生不适合静态编译;TPU 上只能靠 Pallas 写自定义 kernel + 把动态切到张量层。 +- 内部是封闭充惰性气体的空腔 +- 输入光纤的激光打在第一面 MEMS 微镜上(头发丝大小,靠静电引力偏转) +- 反射到对面第二面 MEMS 镜 +- 第二面镜把光"矫正"成水平角度,射进通往目标 TPU 的光纤 +- 全程**没有数字芯片,没有缓存,不解析数据包,没有带宽上限** -- TODO:PagedAttention 的核心难点:动态 indirection -- TODO:Pallas kernel 在哪一层做适配 -- TODO:连续批处理的实现取舍 -- TODO:Radix Tree(前缀缓存)的适配现状 -- TODO:「用极其廉价的 FLOPs 去消除极其昂贵的 Control Flow」这个原则的展开 -- TODO:↔ GPU(vLLM 原生的 paged 实现) +镜子在传输时是**完全静止**的——只在切换"道岔"时短暂调一下角度。所以 OCS 对数据是**透明**的(Data Agnostic):不管激光以 100 Gbps 还是 800 Gbps 闪烁,它只管反射。光纤升级了,OCS 不用换。 -### 12. Prefill / Decode 协同与 Chunked Prefill +#### 3.2 切分粒度:机架级,不是芯片级 -一句话:TPU 在 Prefill 强、Decode 弱(HBM 带宽瓶颈),混合执行 + chunked prefill 是用算法补硬件。 +**OCS 不能任意挑单颗芯片建环**。一个机架内的 64 颗 TPU(4×4×4 基础块)之间是用粗壮的 DAC 铜缆焊死的,便宜且距离短。机架对外暴露的接口才插光模块、拉光纤进 OCS。 -- TODO:Prefill / Decode 在硬件资源上的差异 -- TODO:Chunked Prefill 的切片策略 -- TODO:静态 1D 展平怎么把变长 batch 套进静态 shape -- TODO:↔ GPU(vLLM 的 continuous batching、SARATHI) +所以 OCS 的"乐高积木"最小单位是一个 4×4×4 机架。要 256 芯就把 4 个机架的光纤拼起来,要 1024 芯就拼 16 个,以此类推。 -### 13. 
KV Cache 与内存层次 +#### 3.3 96 根线:4×4×4 机架的几何 -一句话:GPU 体系里的 RDMA / GDS / KV offload 在 TPU 上有的天生支持、有的不支持、有的只能走 PCIe 后备。 +这是个挺优雅的几何题。表面 56 颗芯片向外暴露的接口数怎么算?按位置分: -- TODO:ICI 天生 bypass host(相当于免费 RDMA)的语义 -- TODO:GDS 类直连存储在 TPU 上为什么没对应物 -- TODO:KV cache offload 的实际路径(HBM → PCIe → host DRAM / SSD) -- TODO:↔ GPU(NCCL over IB、GDS、Mooncake 类 KV pool) +| 位置 | 数量 | 每颗出几根线 | 小计 | +|---|---|---|---| +| 8 个角 | 8 | 3(暴露 X、Y、Z 三方向) | 24 | +| 12 条棱(每条减去两端的角,剩 2 颗) | 24 | 2 | 48 | +| 6 个面(每面减去边角,剩 2×2=4 颗) | 24 | 1 | 24 | -### 14. Gemini 在 TPU 上的实战妥协 +合计 **96 根光纤**。也可以按面算验证:每个面是 4×4=16 个对外接口,6 个面 ×16 = **96**,吻合。 -一句话:MoE 和投机解码这两个推理优化,在 TPU 上都得改算法去迁就硬件。 +机房里这 96 根线不是一根根独立拉,而是用 MPO/MTP 高密度并行光缆(一根粗线含 16 或 32 芯光纤),从机架顶部"瀑布"式汇聚到中央的 OCS 网络机架。 -- TODO:MoE 的 All-to-All 与 3D Torus 的张力(这条会在 Ch 19 再呼应硬件原因) -- TODO:Capacity Factor 的作用(控制每专家上限以变成静态 shape) -- TODO:投机解码 Tree Attention 的 mask 设计 -- TODO:为什么这些妥协在 GPU 上不是必须 +#### 3.4 OCS 与 3D Torus 的关系 ---- +**3D Torus 是逻辑拓扑形态,OCS 是变形金刚的关节**: -## Part IV — 集群层适配(目标 D) +- v2/v3 时代:拓扑是物理硬连线,机架 A→B→C→A 焊死,谁坏一颗整个区域瘫痪 +- v4/v5 时代:物理走线变成 **star**(所有机架光纤都汇聚到 OCS),由 OCS 内部的镜子动态"折叠"出一个 3D Torus 闭环 -### 15. K8s 上的 TPU 抽象 +切片场景: -一句话:K8s 看不见光,所以 OCS 切分必须由独立组件负责,TPU device plugin + 拓扑标签 + Kueue + TPU Provisioner 组成完整链路。 +| 你要 | OCS 怎么做 | +|---|---| +| 单机架 64 芯闭环 | 把这 96 根线在自己内部互折(左 16↔右 16,前↔后,上↔下) | +| 4×4×8(128 芯,跨 2 个机架) | X、Y 轴各自机架内闭环;Z 轴上把机架 A 顶面(16 根)接给机架 B 底面,反之亦然 | +| 绕过坏芯片 | 借隔壁机架一根线过来,强行维持 3D 环的逻辑完整性 | -- TODO:device plugin 暴露给 kubelet 的资源粒度 -- TODO:拓扑标签(node label)携带的 3D 坐标信息 -- TODO:Kueue gang scheduling 为什么必要(slice 必须整组上) -- TODO:TPU Provisioner 调用 OCS API 的时机 -- TODO:↔ GPU(NVIDIA device plugin、Volcano gang、Topology Aware Scheduling) +#### 3.5 物理边界:Pod 内 -### 16. Multi-host slice 的编排 +激光在光纤和自由空间反射会衰减,OCS 网络的物理覆盖上限就是一个 Pod。v5p Pod 是 8960 颗芯片。**Pod 是光互联的绝对边界**。 -一句话:一个 slice 跨多 host 时,K8s 看到的是 N 个 pod 的协同启动,每个 pod 内的 TPU 又是 4 / 8 个芯片的本地组——这是两层 N:N。 +跨 Pod 的"通信"必须走传统的数据中心以太网(DCN),延迟和抖动都没法跟 ICI 比,会破坏 XLA 的静态时钟。所以现实中: -- TODO:LWS(LeaderWorkerSet)的语义 -- TODO:JobSet 的角色与 LWS 的关系 -- TODO:调度耦合点:哪一层 fail 会拖整个 slice -- TODO:↔ GPU(MPI Operator、Training Operator、Ray on K8s 的对应) +- **同一次计算(Model Parallelism)**:绝不跨 Pod +- **服务调度(Load Balancing)**:跨 Pod 没问题——用户 A 完整路由给 Pod 1,用户 B 给 Pod 2,Pod 之间只通过 HTTP/gRPC 负载均衡器协调 + +#### 3.6 ↔ GPU + +GPU 集群里**没有这个东西**。NVIDIA 的 NVSwitch 和 InfiniBand 都是 packet-switched 电交换,OCS 是 circuit-switched 光交换,理念完全不同。这条算是 TPU 体系最独特的设计。 + +> **[补充 — Claude 加]** 微软 Azure 在部分 AI 集群里也开始试用 Lumen 提供的 OCS,但量级和 Google TPU 不在一个层次。这条不写进正文,仅供你参考。 --- -## Part V — 系统对比与权衡(目标 B 集中点) +### 4. 
集合通信:Ring All-Reduce 在 3D Torus 上的降维 + +**一句话**:3D Torus 上做 All-Reduce 的诀窍是把它拆成 X/Y/Z 三个 1D 圈分别做,再合起来——这是几何换算法。 + +#### 4.1 1D Ring All-Reduce:Reduce-Scatter + All-Gather + +设 4 颗 TPU 首尾成环,每颗算出一个 400 MB 张量,要把 4 份逐元素相加然后让每颗都拿到完整结果。 + +直接发给 TPU 0 算会瞬间挤爆它的网络。XLA 的做法是**切碎 + 流水线接力**。把 400 MB 切成 4 个 100 MB(块 A、B、C、D)。 + +**阶段一:Reduce-Scatter(每颗最终持有一个块的完整总和)** + +| Cycle | 动作 | +|---|---| +| 1 | 每颗芯片把自己的某个块往右传:TPU0→A→TPU1,TPU1→B→TPU2,TPU2→C→TPU3,TPU3→D→TPU0。**4 条线(含环回)全部满载** | +| 2 | TPU 1 把收到的 A 与本地 A 用 VPU 加起来变成"部分和 A",继续往右传 | +| 3 | 再传一次。TPU 3 收到流转一圈的部分和 A,加上自己的 A,得到全网完整 A | + +阶段结束时,TPU 0 持有完整 B,TPU 1 持有完整 C,TPU 2 持有完整 D,TPU 3 持有完整 A。 + +**阶段二:All-Gather(把各自的 1/4 答案分享出去)** + +3 步接力广播,环上四个完整块并行飞奔。每颗芯片的 Unified Buffer 拼齐 A/B/C/D 后,DMA 才把最终 400 MB 刷回 HBM。 + +#### 4.2 关键硬件细节:网络是缓存的延伸 + +TPU 把网络变成了 SRAM 的直接延伸: + +- 数据从发送方 Unified Buffer 抽出 → 光信号穿过 ICI → 接收方 Unified Buffer +- VPU 直接从 Unified Buffer 取数做加法,**全程不碰 HBM** +- 发送时同时附带一个极小的 **Sync Token** + +GPU 的对照流程:HBM → PCIe → NIC → 交换机 → NIC → PCIe → HBM → 计算核心。每一跳都涉及内存读写,HBM 带宽被通信吃掉一大块。 + +#### 4.3 不靠全局时钟,靠硬件信号量 + +数据中心规模上维持几百颗芯片在纳秒级共享物理时钟,物理学上做不到(光速延迟、时钟漂移)。TPU 用的是 **XLA 静态时间表 + 硬件信号量异步握手**。 + +接收端的链式反应: + +1. ICI 物理收完数据后**自动**把对应的硬件信号量 +1(SPU 完全不参与) +2. SPU 读 XLA 的指令,看到 `WAIT Semaphore_X >= 1` 就阻塞 VPU +3. 信号量翻 1 的瞬间,WAIT 放行,VPU 像弹簧一样起跳算加法 +4. 算完触发下一条 DMA 把结果发出去,附新的 Sync Token,把自己的信号量清零 + +宏观看就像几千颗芯片长在同一个齿轮上,但实际上每颗芯片只盯着自己面前的几个硬件信号灯。XLA 提前把每一步算力匹配得严丝合缝,所以没有死锁、也没人空转太久。 -### 17. 编程模型链:从单卡到多机的指令链 +#### 4.4 3D 降维:拆成 X、Y、Z 三个 1D 圈 -一句话:GPU 是「单卡 CUDA → 多卡 NCCL → 多机 IB/RDMA」三段;TPU 是「SPMD → ICI(VLIW 第五槽)」一段,编译器统管。 +如果在 4×4×4 机架里把 64 颗串成一个长 64 跳的单圈,灾难: -- TODO:CUDA → NCCL → IB/RDMA 的语义跳转点 -- TODO:SPMD + ICI 的「无缝」是怎么做到的 -- TODO:两套模型对故障半径、调度灵活度的影响 +- 数据要跑 63 跳才转一圈 +- 每颗芯片有 6 根线,只用了 1 根接收 + 1 根发送,**剩下 4 根全闲** -### 18. 成本 / 能效 +正确做法是 **多维正交环形同步**:把 3D 任务拆成三次并行的 1D 任务。 -一句话:MFU 和 Tokens/$ 这两个指标是衡量真实账面差异的杠杆,不是芯片峰值算力。 +| 阶段 | 干什么 | 并行度 | +|---|---|---| +| X 轴同步 | 16 条平行 X 轴小环(长 4)同时跑 Reduce-Scatter + All-Gather | 16 | +| Y 轴同步 | 同步好的数据再切,16 条 Y 轴小环跑一次 | 16 | +| Z 轴同步 | 16 条 Z 轴小环跑一次 | 16 | -- TODO:MFU(Model FLOPs Utilization)的定义与典型水位 -- TODO:Tokens per dollar 的计算口径 -- TODO:Midjourney 案例($2.1M → <$700K,对话里给出的数字) -- TODO:Character.AI 案例 -- TODO:「峰值 TFLOPs 不等于实际产出」这个 trade-off 集中在这 +总跳数:4 + 4 + 4 = **12 跳**,远少于 64 跳。 -### 19. TPU 的硬件劣势与权衡 +光看每个阶段似乎只有 1/3 的线在跑,但 XLA 用 **流水线(Pipelining)**:当矩阵切片 1 在走 Y 轴时,切片 2 已经在走 X 轴。宏观上 6 根线全部满负荷发光。 -一句话:每个静态调度的优势都对应一个不擅长的工作负载——MXU 粒度大、SPU 弱、3D Torus All-to-All 拥塞、HBM 带宽 vs 算力失衡。 +#### 4.5 大圈是必然:跨机架的 Z 轴怎么缝合 -- TODO:MXU 128×128 粒度对小矩阵的浪费 -- TODO:SPU 弱在哪些场景下成为瓶颈 -- TODO:MoE All-to-All 与 3D Torus 的根本不匹配 -- TODO:HBM 带宽 vs 峰值算力增长的剪刀差(呼应 Ch 6、Ch 12) -- TODO:每条劣势的 trade-off:换来了什么 +4×4×8(128 芯,跨 2 机架)切片里,X/Y 还是长 4 小圈,Z 轴变成长 **8 的大圈**。物理形态: + +``` +[ 机架 A 的 Z 轴 ] [ 机架 B 的 Z 轴 ] +TPU(Z=0) — TPU(Z=1) — TPU(Z=2) — TPU(Z=3) TPU(Z=4) — TPU(Z=5) — TPU(Z=6) — TPU(Z=7) + ^ | | + | ← OCS 跨机架光路缝合 ← +— OCS — TPU(Z=4) | + +———————————————————————————————————————— OCS 环回 ———————————————————————————————+ +``` + +满规模 v5p Pod(8960 芯,比例可能是 16×16×35),最长边的 Z 轴会形成 **35 跳的大圈**。光物理传输的延迟叠加就足以让上层 MXU 等到饥饿。 + +#### 4.6 NUCA:铜缆和光路的延迟异构 + +8 跳 Z 轴大圈里,2 跳是光路(跨机架),6 跳是铜缆(机架内)。两种介质: + +- **带宽必须严格一致**:流水线吞吐取决于最细那截管子。所以 TPU 设计时把光模块的调制速率和铜线 SerDes 速率对齐了 +- **延迟必然不一致**:铜缆纳秒级(个位数到十位数),光路要走 E-O 转换 → 几十米光纤 → OCS 反射 → O-E 转换,到几百纳秒级别 + +带宽同构 + 延迟异构 → 圈在物理上不是完美对称的。这种现象叫 **NUCA(Non-Uniform Communication Architecture)**。 + +业务层有两种化解办法: + +1. **流水线稳态掩盖**:刚启动时光路那两跳的延迟会让流水线出现微小气泡,但稳态后吞吐由带宽决定,几百纳秒的启动延迟在持续高吞吐下被淹没 +2. 
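> **[补充 — Claude 加]** 4.1 / 4.4 节的接力步骤可以用一个单进程的纯 Python 模拟跑出来(块缩成 1 个元素,不涉及任何 TPU / ICI 接口),只验证两件事:Reduce-Scatter 与 All-Gather 各走 N−1 步、结束后每颗芯片都持有完整总和;第一步的发送方向与 4.1 节表格一致(TPU0→A→TPU1 …):

```python
def ring_all_reduce(chip_data):
    """chip_data[i][b]:芯片 i 的第 b 个块(这里每块只有 1 个元素)。"""
    n = len(chip_data)
    bufs = [list(chunks) for chunks in chip_data]
    steps = 0

    # 阶段一:Reduce-Scatter —— 第 s 步,芯片 i 把块 (i-s) mod n 发给右邻居并累加
    for s in range(n - 1):
        outgoing = [(i, (i - s) % n, bufs[i][(i - s) % n]) for i in range(n)]
        for src, blk, payload in outgoing:
            dst = (src + 1) % n
            bufs[dst][blk] = [a + b for a, b in zip(bufs[dst][blk], payload)]
        steps += 1
    # 此刻芯片 i 持有块 (i+1) mod n 的完整总和(对应正文:TPU0 持有 B、TPU3 持有 A)

    # 阶段二:All-Gather —— 把各自那份完整块沿环接力广播
    for s in range(n - 1):
        outgoing = [(i, (i + 1 - s) % n, bufs[i][(i + 1 - s) % n]) for i in range(n)]
        for src, blk, payload in outgoing:
            bufs[(src + 1) % n][blk] = list(payload)
        steps += 1
    return bufs, steps

n = 4
data = [[[float(chip * 10 + blk)] for blk in range(n)] for chip in range(n)]
result, steps = ring_all_reduce(data)
expected = [[float(sum(chip * 10 + blk for chip in range(n)))] for blk in range(n)]
assert steps == 2 * (n - 1) and all(buf == expected for buf in result)
print(steps, expected)   # 6 步,4 颗芯片全部拿到同一份完整总和
```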
**XLA 的拓扑感知映射**:见下节 9.1 + +#### 4.7 ↔ GPU + +| 维度 | TPU 3D Torus | GPU NCCL on NVLink+IB | +|---|---|---| +| 算法 | Multi-dim Ring All-Reduce | Ring 或 Tree(可选) | +| 网络直径 | $O(N^{1/3})$ | $O(\log N)$ | +| 同步机制 | 硬件信号量 + 静态调度 | 软件 spin-wait + flag in HBM | +| 数据路径 | UB → ICI → UB → VPU | HBM → PCIe → NIC → IB → NIC → PCIe → HBM → SM | +| 计算资源占用 | VPU 顺手做加法,MXU 不停 | NCCL kernel 抢 SM 资源 | + +**Trade-off**:环形 + 静态调度对规则集合通信极致优化,但牺牲了任意点对点的灵活性。 --- -## 附录 +### 5. Host ↔ TPU:PCIe、NUMA 与 multi-host slice + +**一句话**:TPU 不是独立机器,是挂在 CPU host 上的 PCIe 设备;slice 一旦跨 host,就是天然的分布式系统。 + +#### 5.1 物理形态:CPU 是包工头,TPU 是工人 + +每个机架上跑的是标准 x86 服务器主板(Intel/AMD),普通 DDR 内存。TPU 通过 PCIe 总线连接到 CPU 旁边。**典型配比是 1 个 CPU host 管 4 颗或 8 颗 TPU**,这是物理上焊死的比例。 + +分工: + +- **CPU 干**:Linux 系统、Kubelet、HTTP/gRPC 接客、Python/PyTorch、vLLM 调度器(Radix Tree、PagedAttention 页表)、XLA 编译(CPU 密集) +- **TPU 干**:纯执行 CPU 编译好的机器码做矩阵乘加 + +一次 Decode Step 的全过程: + +1. CPU 在系统内存里拼好 metadata(页表指针等) +2. PCIe DMA 把数据 + 新 token embedding 拷贝到 TPU HBM +3. CPU 给 TPU 发指令"按第 5 号编译图开干" +4. TPU 闭关算 +5. PCIe 把 logits 拉回 CPU 内存 +6. CPU 做采样(Argmax、Top-P 等) + +#### 5.2 双网隔离:DCN 与 ICI + +Pod 内部有两套**完全独立**的物理网络: + +| 网络 | 谁用 | 介质 | 用途 | +|---|---|---|---| +| **DCN(数据中心网络)** | Host CPUs | 普通以太网交换机 | CPU 之间聊天,控制面同步 | +| **ICI(芯片互联)** | TPUs | OCS + 3D Torus 专用光纤 | TPU 之间狂飙数据,数据面 | -### A. Trade-off 速查表 +控制面(K8s 协调、Pod 启停)走 DCN,CPU 间用 gRPC。数据面(All-Reduce、KV 同步)完全不经 CPU。 -按设计决策维度横切,每条 trade-off 链接回原章节。计划维度: +#### 5.3 Multi-host slice:N:N 映射 -- 静态 vs 动态(编译 vs 调度) -- 密度 vs 灵活(MXU 大 vs 小、Torus vs fat-tree) -- 算力 vs 带宽(封装 / HBM) -- 集中 vs 分布(编译器 vs 运行时调度器) +很多人以为申请 64 芯 v4-64 切片会拿到一台带超大 CPU 的巨无霸 VM。**不是这样**。 -### B. 数字 / 参数清单 +物理上是 **16 台 Host VM**(每台 4 颗 TPU),16×4=64: -所有数字都注明「源自原对话」。计划条目: +- 16 台 VM 通过 DCN 相连 +- 64 颗 TPU 通过 OCS 组成 3D Torus +- 每台 VM 上跑一个 Kubelet +- K8s 调度 16 个 Pod,分别落在 16 台 VM 上 +- 推理代码(vLLM/JetStream)在 16 个 CPU 上同时启动,跑同一份 Python 代码(**SPMD**) +- 通常用 **LeaderWorkerSet (LWS)** 或 **JobSet** 表达:1 个 Leader Pod 对外暴露 API,15 个 Worker Pod 配合 +- Leader 收到请求后通过 DCN 广播给 Workers,每个 CPU 给自己脚下的 4 颗 TPU 下指令 +- 64 颗 TPU 算出结果后由 Leader 汇总返回 -- TPU v4: 4096 芯片 -- TPU v5p: 8960 芯片 -- MXU: 128×128 -- HBM 带宽数量级:TODO -- Midjourney: $2.1M → <$700K -- Character.AI: TODO -- CPU:TPU 比例:1:4 或 1:8 -- 4×4×4 机架: 96 根光纤 +没有"超级 CPU 管 64 颗 TPU"这回事。大切片本质上是多个"小 CPU + 小 TPU"通过 K8s gang scheduling 强行绑定的分布式联邦。 -### C. 
术语 ↔ GPU 等价物对照 +#### 5.4 NUMA:双路 CPU 的 PCIe 劈管 -纯查询用。计划条目: +现代服务器主板通常是双路 CPU(CPU 0 和 CPU 1)。一半的 PCIe 通道连 CPU 0,另一半连 CPU 1。8 颗 TPU 主机上,TPU 0~3 挂 CPU 0,TPU 4~7 挂 CPU 1。 -- ICI ↔ NVLink + IB(合并对应) -- SPMD ↔ NCCL collective -- Pallas ↔ Triton -- OCS ↔ (无对应) -- XLA ↔ TorchInductor / TensorRT -- MXU ↔ Tensor Core -- VPU ↔ CUDA Core -- HLO ↔ FX Graph -- TPU Provisioner ↔ (无对应,最接近的是 Slurm topology + 手工 NCCL) +跨 NUMA 灾难:CPU 0 想写 TPU 4 的 HBM,数据必须先走 UPI 总线到 CPU 1,再走 PCIe 到 TPU 4。延迟剧增,可用带宽腰斩。 + +具体卡在哪些环节: + +- **输入流水线**:Tokenize 在 CPU 0 但任务下到 TPU 4,每次 in-feed 跨 NUMA +- **KV Cache offload**:TPU 0 的 KV swap 到了 CPU 1 管的 DDR 槽条 +- **权重加载**:百 GB 级的权重 DMA 跨 NUMA,冷启动慢 + +Google Cloud 的解法: + +- **单 NUMA VM 切分**:申请 4 颗 TPU 实例时,hypervisor 直接把物理机切两半。给你的 VM **只**包含 CPU 0 + CPU 0 的内存 + 挂在 CPU 0 下的 4 颗 TPU。看不见 CPU 1 那半边 +- **XLA 自动绑核**:8 颗 TPU 实例无法切分时,XLA Runtime(PJRT)读取 PCIe 树拓扑,自动把喂 TPU 0~3 的线程 pin 到 CPU 0 核心,喂 TPU 4~7 的 pin 到 CPU 1 核心 + +#### 5.5 ↔ GPU + +| 维度 | TPU | GPU | +|---|---|---| +| 物理挂载 | PCIe 挂在 host CPU 旁 | 同上 | +| 比例 | 1:4 或 1:8 焊死 | HGX 通常 1:8 | +| Multi-host 编排 | LWS / JobSet + K8s gang | MPI Operator / Training Operator | +| NUMA 处理 | 单 NUMA VM 切分 + XLA PJRT 自动绑核 | 多用 `numactl` 手动绑 + NCCL 拓扑感知 | +| 双网 | DCN + ICI 严格分开 | 控制面 + 数据面通常都走 IB | + +**Trade-off**:TPU 的 N:N 编排让单机故障域变小,但运维上看到的"一个推理服务"实际是多个 K8s 资源的协奏。 --- -## 写作日志(让作者验收用) +### 6. 先进封装:算力面积 vs 带宽周长 + +**一句话**:Die 上面积决定算力(FLOPs),周长决定带宽(HBM 接口);2.5D / 3D 封装是在调和两者的根本张力。 + +#### 6.1 内存墙的物理起源 + +- **算力 ∝ 面积**:MXU 是二维的,面积稍大乘加单元就 $O(N^2)$ 爆炸增长(64×64 → 128×128,面积翻倍,算力翻 4 倍) +- **传统带宽 ∝ 周长**:HBM 带宽取决于芯片边缘引脚数量,是 $O(N)$ 线性增长 -### 主动取舍(待原文读取后填) +面积平方级、周长线性级——**带宽永远跟不上算力**。这就是"内存墙"的物理根源。 -- TODO +#### 6.2 封装演进的三步 -### 外部补充(Claude 加,原文未提及) +| 代际 | 名字 | 解决了什么 | +|---|---|---| +| 第一代 | Wire Bonding(金线键合) | 芯片正面朝上,靠边缘金线接基板,受限于周长 | +| 第二代 | Flip-Chip(倒装芯片) | 芯片翻过来,整个面种满 C4 锡球,从"边长"扩到"面积"——但 PCB 主板走线精度跟不上(线宽几十微米) | +| 第三代 | **2.5D 硅中介层(Silicon Interposer / CoWoS)** | 在芯片和主板之间垫一块硅片,用 EUV 光刻在硅片上画**纳米级**走线 | +| 第四代 | **3D 封装 + TSV(Through-Silicon Via)** | 直接在硅片上垂直打几万个微米级孔灌铜,把多层芯片像盖楼一样叠起来 | -- TODO,每条会用 `> **[补充 — Claude 加]** ...` 在正文中标注 +#### 6.3 硅中介层为什么有用 + +普通 PCB 板上 1 mm 宽度大概挤 10 根线;硅中介层上 1 mm 能挤上千根。GPU/TPU 与 HBM 的连接通道从"双向 4 车道"变成"双向 4096 车道"。HBM3 的位宽就是这么来的。 + +HBM 自己内部也是 3D 结构:多层内存 die 用 TSV 垂直叠起来,所以单个 HBM stack 能给到 Tbps 级带宽。 + +#### 6.4 算力过剩问题与硬件妥协 + +v4 时代算力涨太快,HBM 带宽跟不上,跑 Decode 时 MFU 经常跌到个位数,MXU 在等米下锅。Google 的硬件级补救: + +- **v5e(推理芯片)故意缩小 MXU**,把 compute / bandwidth 比例调回更健康的区间 +- 牺牲峰值 FLOPs 换性价比,承认 Decode 是 memory-bound 不是 compute-bound + +> **[补充 — Claude 加]** 这一节里 v4 时代 MFU 跌到"个位数"是源对话原话,但没给出具体数据点和场景。我没有别的可信来源验证,建议你之后核实下要不要保留。 + +#### 6.5 ↔ GPU + +NVIDIA H100/B100 也用 CoWoS(同样台积电方案),技术路线一样。差异在 die 比例:H100 有 50 MB L2 Cache、5 TB/s HBM 带宽;TPU 不堆硬件 cache,把那块面积让给 MXU。 + +**Trade-off**:H100 用大 cache 容忍随机访问;TPU 用大 MXU 跑稠密。两条路对应不同的工作负载假设。 + +--- From ee7a6818813b0cff9dcd986386194f2fb3145de7 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 6 May 2026 02:13:43 +0000 Subject: [PATCH 3/8] docs(notes): Fill Parts II and III of TPU deep dive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part II (XLA compiler & runtime, chapters 7-9): - Compilation model: operator fusion, static padding, software pipelining, VLIW 5-slot packs (single-chip and multi-chip variants) - Compilation timing: JIT timeline, bucketing + AOT + persistent cache pipeline, why no universal precompiled library - Topology-aware mapping: TP on short copper rings, DP on long optical rings Part III (inference adaptation, chapters 10-14): - Software stack split: 
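> **[补充 — Claude 加]** 3.3 节的「96 根」可以用十几行纯 Python 直接数出来交叉验证(三种数法必须一致)。这只是几何计算,和任何 TPU / OCS API 无关;把 `N` 改成别的边长也成立,对外光纤数恒等于 6·N²:

```python
# 4×4×4 机架:每颗芯片 6 个 ICI 方向,指向立方体外部的方向各需要一根对外光纤
N = 4
dirs = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]

# 数法一:逐芯片数"邻居落在立方体外"的方向
external = sum(
    1
    for x in range(N) for y in range(N) for z in range(N)
    for dx, dy, dz in dirs
    if not (0 <= x + dx < N and 0 <= y + dy < N and 0 <= z + dz < N)
)

# 数法二:按位置分类(8 角 ×3、12 棱 ×2、6 面心 ×1)
by_position = 8 * 3 + 12 * (N - 2) * 2 + 6 * (N - 2) ** 2 * 1

# 数法三:每个面 N×N 个对外接口,共 6 个面
by_face = 6 * N * N

assert external == by_position == by_face == 96
print(external, by_position, by_face)   # 96 96 96
```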
vLLM (lift-and-shift), JetStream (TPU-native), Saxml (JAX legacy) - PagedAttention adaptation: control-plane/data-plane split with XLA pool + vLLM block tables + Pallas custom kernel - Prefill/Decode coordination: arithmetic intensity gap, static-bus continuous batching, chunked prefill, 1D static flatten for mixed steps - KV / memory hierarchy: ICI as native RDMA-bypass, no GDS equivalent, KV offload via PCIe to host DDR - Gemini practical compromises: MoE Capacity Factor for static routing, Tree Attention for tensorized speculative decoding One additional "[补充 — Claude 加]" callout in Chapter 13 about Mooncake-style separated KV pools (not in source). Part IV, V, appendices, and English mirror still pending. Co-authored-by: Claude --- notes/tpu-deep-dive.cn.md | 506 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 506 insertions(+) diff --git a/notes/tpu-deep-dive.cn.md b/notes/tpu-deep-dive.cn.md index 4fb52c0..d8a4a88 100644 --- a/notes/tpu-deep-dive.cn.md +++ b/notes/tpu-deep-dive.cn.md @@ -431,3 +431,509 @@ NVIDIA H100/B100 也用 CoWoS(同样台积电方案),技术路线一样。 **Trade-off**:H100 用大 cache 容忍随机访问;TPU 用大 MXU 跑稠密。两条路对应不同的工作负载假设。 --- + +## Part II — 编译与运行时:XLA + +### 7. XLA 编译模型:把不确定性消灭在编译期 + +**一句话**:XLA 的核心打法是「把所有不确定性在编译期消掉」——算子融合、静态 padding、软件流水线、VLIW 指令包都是这个思路的不同切面。 + +#### 7.1 算子融合 + +最经典的例子:一层 `MatMul + Bias + ReLU`。 + +朴素执行(GPU eager 模式)会: +1. 算 MatMul → 写回 HBM +2. 读出来加 Bias → 写回 HBM +3. 读出来做 ReLU → 写回 HBM + +XLA 把三步融合成一个计算块: + +``` +HBM → UB(MatMul 输入) → MXU 算 MatMul → 流出瞬间送进 VPU → VPU 做 Bias + ReLU → HBM +``` + +**节省了 2/3 的 HBM 读写**。这种融合在 LLM 里普遍存在,比如 Attention 后面紧跟的 dropout/scale/mask 通常都被融进同一个块。 + +#### 7.2 静态 Padding + +MXU 是 128×128 的硬连线阵列。如果你的矩阵是 100×100,XLA 不会让硬件去处理边界——硬件根本不支持。XLA 直接在编译期把矩阵 padding 到 128×128(多余位置填 0)。 + +代价:白算 28 行 × 28 列 ≈ 50% 的边界算力。 +收益:阵列保持满速流动,不用停下来做边界判断。 + +在脉动阵列里,**让硬件全速跑空格远比中途停下来快**。 + +#### 7.3 软件流水线 + +XLA 在编译期就算好每次 DMA 搬数需要多少 cycle,静态生成指令:MXU 算第 N 块时,DMA 已经在搬第 N+1 块的权重。计算和访存完美重叠。 + +#### 7.4 VLIW 五槽:单芯片 + 跨芯片同构 + +单芯片版指令包: + +``` +[ DMA ] | [ MXU ] | [ VPU ] | [ SPU 控制流 ] +``` + +多芯片协作版多了一槽——**ICI 网络槽**: + +``` +[ DMA ] | [ MXU ] | [ VPU ] | [ ICI 网络引擎 ] | [ SPU 控制流 ] +``` + +关键洞察:**对 TPU 来说,跨芯片传数据(ICI)和片内搬砖(DMA)是平级的**——都是 VLIW 指令字上的一个开关。XLA 可以做到 MXU 算乘法的同时,ICI 把上一层的梯度传给隔壁芯片,**计算和跨节点通信在时钟周期级别完美重叠**。 + +这是 GPU 体系做不到的。GPU 的跨节点通信要触发 CPU 中断、构建数据包、过 IB 交换机,是另一套子系统。 + +#### 7.5 一个具体的伪汇编例子 + +计算 $C = A \times B$,$A$ 是 256×128,$B$ 是 128×256。MXU 是 128×128。 + +XLA 在编译期把 $A$ 切上下两块 ($A_0, A_1$),$B$ 切左右两块 ($B_0, B_1$),分解成 4 个 128×128 子任务。所有 HBM 地址在编译期硬编码(无运行时指针计算)。 + +执行 $C_{00} = A_0 \times B_0$ 的 VLIW 流: + +``` +Instruction 1(预热加载权重): + [DMA] LOAD_HBM_TO_UB (Src: HBM_B0, Dst: UB_B) + [MXU] NOP + [VPU] NOP + [SPU] WAIT_DMA_DONE + +Instruction 2(权重驻留): + [DMA] NOP + [MXU] LOAD_UB_TO_WEIGHT_REG (Src: UB_B) + [VPU] NOP + [SPU] WAIT_MXU_DONE + +Instruction 3(核心计算 + 流水线预取): + [DMA] LOAD_HBM_TO_UB (Src: HBM_A0, Dst: UB_A) + [DMA_ASYNC] LOAD_HBM_TO_UB (Src: HBM_B1, Dst: UB_B_Next) ← 提前预取下一块 + [MXU] MATMUL_STREAM_ACT (Src: UB_A, Dest: Accumulator_C00) + [VPU] NOP + [SPU] WAIT_MXU_DONE + +Instruction 4(融合 ReLU): + [DMA] NOP + [MXU] NOP + [VPU] READ_ACCUM_AND_RELU_AND_STORE (Src: Accumulator_C00, Dst: UB_C00) + [SPU] WAIT_VPU_DONE + +Instruction 5(写回 HBM): + [DMA] STORE_UB_TO_HBM (Src: UB_C00, Dst: HBM_C00) + [MXU] NOP + [VPU] NOP + [SPU] JUMP_TO_NEXT_BLOCK ← 准备算 C01 +``` + +注意 Instruction 3:**DMA 在搬下一块的同时 MXU 在算**。如果 XLA 算错任何一个 cycle,要么 UB 溢出,要么 MXU 空转。这种精度只有静态编译才能给。 + +#### 7.6 ↔ GPU + +| 维度 | TPU XLA | GPU | +|---|---|---| +| 调度 | 编译期静态 | 运行期动态(Warp Scheduler) | +| 隐藏访存延迟 
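> **[补充 — Claude 加]** 7.1 的算子融合可以在任何装了 `jax` 的机器上(CPU 后端即可)直观看一眼——下面的最小示例只展示「Python 写三步、XLA 拿到的是一张图」,与 TPU 实机无关;`lower(...).as_text()` 的输出格式随 JAX 版本变化,仅作查看融合结果的提示:

```python
import jax
import jax.numpy as jnp

def layer(x, w, b):
    y = x @ w                    # MatMul
    y = y + b                    # Bias
    return jnp.maximum(y, 0.0)   # ReLU

fused = jax.jit(layer)           # 交给 XLA:三个算子进入同一个待融合的计算块

x = jnp.ones((128, 256))
w = jnp.ones((256, 512))
b = jnp.ones((512,))

print(fused(x, w, b).shape)             # (128, 512);首次调用触发 JIT 编译
print(jax.make_jaxpr(layer)(x, w, b))   # 编译器看到的图:dot_general + add + max 三个原语
# 想看 XLA 侧的 HLO,可以试 fused.lower(x, w, b).as_text()
```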
| 静态流水线 | 海量并发线程切换 | +| 指令格式 | VLIW 多槽并发 | SIMT 单指令多线程 | +| 算子融合 | XLA 自动 | TorchInductor / TVM / Triton 半自动 | + +**Trade-off**:XLA 的"全局上帝视角"在调度规整工作负载时无敌,但只要算子图里出现动态形状,就要重新编译。 + +--- + +### 8. 编译时机:JIT、AOT、bucketing、persistent cache + +**一句话**:静态编译的代价是首跑慢,工业上靠 bucketing + AOT + cache 把这个代价摊薄。 + +#### 8.1 JIT 在 PyTorch/XLA 上的真实时间线 + +vLLM 等推理框架在 TPU 上启动后的真实流程: + +| 阶段 | 动作 | XLA 状态 | +|---|---|---| +| 1. 初始化 | 加载权重到 HBM | **未编译**。只有 nn.Module 对象和权重张量 | +| 2. Tracing | 第一次请求触发 `model.forward()`,PyTorch/XLA 用 **Lazy Tensor** 不真正计算,只记录 DAG | 在画图,生成 HLO IR | +| 3. 触发编译 | 代码读 logits 做采样时遇到同步屏障(如 `xm.mark_step()`) | **立刻编译**:算子融合 + 静态地址 + 指令排布 → TPU Executable。耗时几秒到几十秒 | +| 4. 执行 + 缓存 | TPU 跑出结果(毫秒级),编译产物存内存中的 Compiler Cache | Cache key 包含计算图结构和所有输入 shape | +| 5. 后续请求 | 同 shape 命中 cache,跳过编译 | 直接复用 | + +**关键澄清**:权重数值不进编译产物。XLA 只关心权重的 shape 和 dtype,把权重视为"静态显存地址指针"。换微调过的 LoRA 权重、换同架构的另一个模型都不需要重编。 + +#### 8.2 工业级解法:bucketing + AOT + 持久化 cache + +生产环境绝不能让第一个用户等 30 秒编译。流程如下: + +**Step 1:限制并离散化 bucket** + +算法团队 profiling 后定一组离散桶: + +``` +BS_Buckets = [1, 2, 4, 8, 16, 32, 64] +SeqLen_Buckets = [128, 512, 1024, 2048, 4096, 8192] +``` + +运行时来一个 BS=5 的请求,padding 到 BS=8 进对应桶。 + +**Step 2:CI/CD 阶段 AOT 预热** + +构建 Docker 镜像或发布制品时加一个 warmup 环节:拉起含真实 TPU 拓扑的 CI 节点,遍历 `BS × SeqLen` 所有组合发假请求触发编译。 + +**Step 3:持久化 cache 打进镜像** + +用 `XLA_FLAGS="--xla_dump_to=/path/to/cache"` 把编译产物落盘。流水线最后把这几百 MB 到几 GB 的 cache 文件**直接打进 release 镜像**或挂在分布式存储里。 + +线上 vLLM/JetStream 实例启动时读这个 cache,命中桶就**毫秒级下发硬件**。 + +#### 8.3 为什么没有"天下大同"的预编译库 + +一个 XLA Executable 绑定的不只是模型 shape,还包括以下**致命变量**——任意一个变了缓存就失效: + +| 变量 | 影响 | +|---|---| +| **物理硬件拓扑** | v5e-8(一维环)的编译产物给不了 v5p-32(3D Torus),XLA 把走哪根光纤、延迟多少都算死了 | +| **并行切分策略** | TP 切 Attention 还是 FFN?PP 怎么切?这些 SPMD 注解必须编译前确定 | +| **编译器版本** | XLA / LLVM 后端更新频繁,旧 cache 大概率校验失败 | +| **模型结构微调** | 加一层 adapter、改 RoPE base 频率→常量折叠结果变→HLO hash 变→cache 全废 | + +所以每个团队都得自己维护一套**模型分发 + 缓存预热**流水线。基础设施大版本发版背后必然伴随大规模自动化重新编译。 + +#### 8.4 ↔ GPU + +GPU 也有这套问题(PyTorch 2.x 的 `torch.compile` / Inductor / TensorRT),但程度轻得多:因为 GPU 硬件能在运行时容忍 shape 变化(动态调度),编译失败时还能 fallback 到 eager。TPU 没有这个 fallback,编译失败 = 服务失败。 + +--- + +### 9. XLA 拓扑感知映射 + +**一句话**:编译器知道 3D Torus + OCS 的物理拓扑,所以能把高密度通信映射到短边、低频同步映射到长环。 + +第 4.6 节讲过 NUCA:跨机架的 8 跳大圈里,6 跳铜缆 + 2 跳光路,带宽同构延迟异构。XLA 在编译时把不同性质的并行策略塞进不同性质的拓扑。 + +#### 9.1 TP 走小圈,DP 走大圈 + +| 并行策略 | 通信特征 | 映射到 | +|---|---|---| +| **张量并行 (TP)** | 步步为营,每个线性层都要同步激活值。**对延迟极敏感** | X 轴或 Y 轴的 4 / 8 短铜环 | +| **数据并行 (DP)** | 秋后算账,每个 step(甚至累积几个 step)才同步一次梯度。梯度矩阵大,对带宽要求高,对单次延迟容忍 | Z 轴的 35 节点大光环 | +| **流水线并行 (PP)** | 阶段间传 activation,频次中等 | 通常分给中等长度的边 | +| **专家并行 (EP)** | 动态 All-to-All(MoE) | Torus 上吃亏,详见第 19 章 | + +DP 走大圈时,35 跳的物理延迟被流水线稳态掩盖,并且底层计算单元可以利用同步等待时间做下一个 step 的前向计算(Compute-Communication Overlap)。 + +#### 9.2 拓扑信息怎么传给 XLA + +K8s 给 Node 打的标签(如 `cloud.google.com/gke-tpu-topology: 4x4x4`)携带了切片的几何形状。XLA Runtime 启动时读这些 + PCIe sysfs 信息,构造拓扑图。然后根据用户的 SPMD 切分注解把通信组映射到具体的 ICI 链路。 + +**结论**:纯 K8s 调度只看节点存活,真正的高性能 AI 调度看的是微秒级的光电物理边界。 + +#### 9.3 ↔ GPU + +GPU 体系里类似的事情靠手工 NCCL group 配置 + `torch.distributed` 的拓扑感知 API。NCCL 知道 NVLink/IB 的层级,但 fat-tree 本身近似对称,拓扑感知的优化空间没 Torus 那么大。 + +--- + +## Part III — 推理层适配(目标 C) + +### 10. 
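> **[补充 — Claude 加]** 8.2 的「BS=5 的请求 padding 到 BS=8 进桶」落到代码上大致是下面这个形状。示意草稿:桶的取值沿用上文例子,`pick_bucket` / `pad_to_bucket` 是为说明而虚构的函数名,不是 vLLM / JetStream 的真实实现:

```python
import bisect

BS_BUCKETS = [1, 2, 4, 8, 16, 32, 64]
SEQLEN_BUCKETS = [128, 512, 1024, 2048, 4096, 8192]

def pick_bucket(value, buckets):
    """取第一个 >= value 的桶;超过最大桶说明请求超限,只能拒绝或截断。"""
    idx = bisect.bisect_left(buckets, value)
    if idx == len(buckets):
        raise ValueError(f"{value} 超过最大桶 {buckets[-1]}")
    return buckets[idx]

def pad_to_bucket(num_reqs, seq_len):
    """返回 (命中的编译缓存 key, 要补的 dummy 请求数, 要补的 padding token 数)。"""
    bs_bucket = pick_bucket(num_reqs, BS_BUCKETS)
    len_bucket = pick_bucket(seq_len, SEQLEN_BUCKETS)
    return (bs_bucket, len_bucket), bs_bucket - num_reqs, len_bucket - seq_len

print(pad_to_bucket(num_reqs=5, seq_len=900))   # ((8, 1024), 3, 124)
```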
软件栈分叉:vLLM、JetStream、Saxml、GKE + +**一句话**:TPU 上推理框架不止一个,三家定位不同;GKE 是把它们都装进集群的胶水。 + +#### 10.1 GKE 为什么死磕 vLLM on TPU + +vLLM 是开源推理事实上的"Linux"——绝大多数客户在 GPU 上用 PyTorch + vLLM 写好了业务代码(API 封装、调度、自定义 prompting)。GKE 想卖 TPU(v5e/v5p 极具性价比),最大阻力是迁移成本: + +> 如果客户得改代码才能上 TPU,他们就跑了。 + +所以 Google 的策略是 **Lift and Shift**:让客户原本的 `vllm serve` 命令换个基础镜像就在 TPU 上跑起来,PyTorch 调用被自动路由到 PyTorch/XLA。 + +技术底座:vLLM 官方代码库已包含 TPU backend,PagedAttention 在 TPU 静态图上的水土不服由 Google 工程师用 Pallas 写的自定义 kernel 补齐。 + +#### 10.2 三家的位置 + +| 框架 | 定位 | 目标客群 | +|---|---|---| +| **vLLM** | 生态兼容王,"代码不想改" | 创业公司、多云客户、GPU 迁移 | +| **JetStream** | TPU 性能榨汁机 | 大厂、高并发推理,愿意为性能改框架 | +| **Saxml** | JAX 生态历史重型武器 | 深度绑 JAX 的存量大客户、特殊大规模切分 | + +#### 10.3 JetStream 凭什么比 vLLM 快 20%-50% + +JetStream 是 Google Cloud + XLA 团队联合主导,专为 v5 系列定制。它**不去硬凑动态分页**,而是完美契合 TPU 的静态编排哲学: + +- 极深度的连续 Batching 优化 +- 大量 XLA 算子融合 +- 直接跟编译器协同设计,没有 PyTorch 这层间接 + +代价:API 不像 vLLM 那么开箱即用,对 PyTorch 生态的支持需要专门做。 + +#### 10.4 Saxml 为什么靠后 + +最早伴随 Pax / Seqio 一起诞生,深度绑 JAX。带浓厚的 Google 内部基础设施味道,外部上手门槛高,PyTorch 支持滞后。在公有云推广优先级靠后。 + +#### 10.5 ↔ GPU + +| TPU | GPU | +|---|---| +| vLLM-TPU(带 Pallas) | vLLM 原生 | +| JetStream | TensorRT-LLM(NVIDIA 自家 + 特化) | +| Saxml | (无完全对应;DeepSpeed Inference 算半个) | +| Pallas 写 kernel | Triton / CUDA C++ | + +--- + +### 11. PagedAttention 与连续批处理在 TPU 上的适配 + +**一句话**:GPU 上的动态内存管理(PagedAttention、Continuous Batching、Radix Tree)天生不适合静态编译;TPU 上靠 Pallas 写自定义 kernel + 把动态切到张量层来适配。 + +#### 11.1 没有 Pallas 之前:静态连续分配的低效 + +早期 TPU 推理(T5、早期 Pax)走的是"强迫症"路线: + +- XLA 在编译期按 `[Max_BS, Max_SeqLen, Hidden_Dim]` 一次性挖出连续 KV Cache 池 +- 假设 max_seq=4096,每个请求锁死 4096 个 Token 的 HBM 空间 +- 用户请求只有 100 个 Token?剩下 3996 个 slot 全部空跑(浪费 97% 显存) +- HBM 被无效 padding 占满 → batch size 上不去 → MXU 计算裕量充足但请求接不进 → **被内存墙卡死了算力** + +Google 早期靠"钞能力"扛——Pod 总 HBM 大、任务长度可控(翻译/搜索)、算法团队把桶切得极细——硬挺过去了。但长文本和多轮对话普及后这条路走不通。 + +#### 11.2 现代分治:XLA 建池子,vLLM 记账,Pallas 按图索骥 + +vLLM on TPU 现代架构是**控制面(CPU)+ 数据面(TPU)分离**: + +| 角色 | 在哪 | 干什么 | +|---|---|---| +| **XLA** | TPU | 在 HBM 里分配一个一维化的巨大块张量 `[Num_Total_Blocks, Block_Size, Head_Dim]`(比如 10 万个物理块,每块 16 个 Token)。**XLA 不知道里面装的是谁的数据** | +| **vLLM** | CPU | 维护 Radix Tree 和所有内存页表(Block Tables)。每个 step 把当前活跃请求的页表打包成整数 Tensor 喂给 XLA 图 | +| **Pallas Kernel** | TPU | XLA 图里的一个 Custom Call 节点,接收页表后执行底层间接寻址(Gather),把零散物理块拼进 Unified Buffer 做 Attention | + +**物理 HBM 还是 XLA 圈的全局张量,但 XLA 不再管理里面的内容。** vLLM 当调度员每个 cycle 把"寻址地图"发过去,TPU 上的 Pallas 算子照着地图取数据。 + +#### 11.3 三把刀的逐项落地 + +**PagedAttention(分页注意力)** + +- TPU 阵痛:XLA 极度讨厌动态指针寻址,每次 Attention 都查页表会让 DMA 剧本乱套 +- 解法:Pallas Kernel 在硬件寄存器层面手写页表查找 + 离散 Gather +- 结论:完全支持,HBM 碎片问题解决,batch size 提升 + +**Continuous Batching(连续批处理 / Inflight Batching)** + +- GPU 玩法:1D 展平,调度器随时踢掉完成的、塞进新的,绝对动态 +- TPU 玩法(**静态大巴模式**): + - XLA 预编译 `Batch_Size = 256` 的固定图,相当于 256 座大巴永远绕圈 + - 某请求生成到 EOS → vLLM 把座位标空 → 下个 step 把新请求的第一个 Decode Token 塞进刚空出来的索引 + - TPU 只看到完美 `[256, 1, D]` 张量,不知道索引 5 上一毫秒是 A、这一毫秒是 B + - 不够 256 真实请求时用 Dummy Token(全零)填满 + +**Radix Tree(前缀缓存)** + +- 在 TPU 上**完美适用**:本质是 CPU 端的调度算法 +- 命中前缀时 vLLM 只需修改下发的 Block Table 让逻辑块指针指向已有物理块 +- TPU 底层 Pallas 不知道是复用,按地图正常取数即可 + +#### 11.4 Google 自家也用同样的思想 + +JetStream / Saxml 也实现了等价机制(内部叫 **Blocked Attention** 或内置在底层的 **FlashAttention-TPU** 算子里),同样用 Pallas 写。所以 GKE 上跑 vLLM 还是 JetStream,**显存管理思想殊途同归**:HBM 维护离散 Block Pool + CPU 维护页表 + 计算时把页表传给底层 Kernel 做 Gather。 + +#### 11.5 一个核心哲学:FLOPs 换 Control Flow + +**用极其廉价的 FLOPs(算力),去消除极其昂贵的 Control Flow(控制流)。** + +这条贯穿 TPU 整个推理适配。Tree Attention(第 14 章)也是同一思想——把 if-else 编码成 Mask 矩阵,宁可多算多扔,也不让硬件停下来做分支判断。 + +#### 11.6 ↔ GPU + +| 维度 | TPU | GPU | +|---|---|---| +| KV 分页 
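> **[补充 — Claude 加]** 11.2 / 11.3 的「XLA 建池子、vLLM 记账、kernel 按页表做 gather」可以用一个缩小版的 NumPy 草图把数据面画清楚。真实实现是 Pallas 在 UB / 寄存器层面做同样的间接寻址,这里只示意索引关系(用 `np.take` 代替自定义 kernel):

```python
import numpy as np

NUM_BLOCKS, BLOCK_SIZE, HEAD_DIM = 16, 4, 8        # 缩小版的全局 KV block pool
kv_pool = np.arange(NUM_BLOCKS * BLOCK_SIZE * HEAD_DIM, dtype=np.float32)
kv_pool = kv_pool.reshape(NUM_BLOCKS, BLOCK_SIZE, HEAD_DIM)   # XLA 只认识这个静态形状

# CPU 侧页表:某请求的 10 个历史 token 散落在物理块 3、7、12(最后一块只用了 2 个槽位)
block_table = np.array([3, 7, 12])
context_len = 10

gathered = np.take(kv_pool, block_table, axis=0)               # [3, BLOCK_SIZE, HEAD_DIM]
kv_for_attention = gathered.reshape(-1, HEAD_DIM)[:context_len]
print(kv_for_attention.shape)                                  # (10, 8)
```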
| XLA 池子 + Pallas Custom Call | vLLM 原生 PagedAttention | +| 调度灵活度 | 固定 batch 桶 + Dummy padding | 动态 1D 展平 | +| Custom kernel 工具 | Pallas | Triton / CUDA | + +--- + +### 12. Prefill / Decode 协同与 Chunked Prefill + +**一句话**:TPU 在 Prefill 强、Decode 弱(HBM 带宽瓶颈),混合执行 + chunked prefill 是用算法补硬件。 + +#### 12.1 算术强度决定 TPU 体感 + +判断硬件适合不适合一个 workload,看 **算术强度 = FLOPs / Byte**(每读写一个字节内存能做多少次浮点运算)。 + +| 阶段 | 数学形式 | 算术强度 | 瓶颈 | TPU 体感 | +|---|---|---|---|---| +| **Prefill** | GEMM(矩阵 × 矩阵,权重被几千 token 复用) | 极高 | Compute-Bound | MXU 跑得极爽 | +| **Decode** | GEMV(矩阵 × 向量,权重读出来只算 1 个 token 就丢) | 极低 | Memory-Bound | MXU 大量饥饿停机 | + +所以 TPU 骨子里是 **Prefill 怪物**,Decode 阶段是被按在地上摩擦后强行优化出来的。 + +#### 12.2 Decode 阶段 TPU 怎么搞 Continuous Batching + +**Decode 阶段不需要给 token 设桶**——每个请求当前要算的新 token 长度恒为 1。设的是 **Batch Size 的桶**。 + +预编译 BS=256 的 Decode 图: + +- 静态输入矩阵 `[256, 1, Hidden_Dim]` +- 200 个真实请求 → 前 200 槽位放真 token,后 56 槽位塞 Dummy Token(全零) +- TPU MXU 算出 256 个结果 +- CPU 调度器只把前 200 个真实结果拿走发用户,后 56 个丢弃 + +**核心难点:256 个请求的历史长度都不一样怎么办?** + +CPU 还要传两个静态大小的整数数组: + +``` +context_lengths : 形状 [256],记录每个请求真实历史长度,如 [105, 3042, 12, ...] + Dummy 槽位的长度填 0 +block_tables : 形状 [256, Max_Blocks],每个请求的 KV 页表 +``` + +Pallas 算子按 `context_lengths` 做循环边界(或 mask),按 `block_tables` 去 HBM gather 对应历史 KV,做 Attention。 + +#### 12.3 Prefill 的难处:每个请求 Prompt 长度差异巨大 + +Decode 可以拼成 `[256, 1]` 的整齐方块,但 Prefill 不行:有人 prompt 100,有人 3000,怎么塞进静态图? + +两种方案: + +- **分桶**:长度 100 → padding 到 128 桶;3000 → 切到 4096 桶 +- **Chunked Prefill**:编译固定长度的 Prefill 图(如 chunk_size=512)。长 1000 的 prompt 切两块 512,分两次塞进同一个 `[1, 512]` 槽位计算 + +#### 12.4 Prefill / Decode 混合:静态 1D 展平 + +最前沿的做法:把 Prefill 长序列和 Decode 单 token 揉进**同一个 step**。 + +XLA 编译时设两个静态上限: + +``` +Max_Total_Tokens = 1024 # 一个 step 最多吞 1024 个 token +Max_Seqs = 256 # 最多 256 个并发序列 +``` + +输入张量从 3D 的 `[Batch, Seq, D]` 展平成 2D 的 `[1024, D]`。 + +CPU 端拼装: + +``` +请求 A (Prefill, 切下 chunk=512) → 数组前 512 位 +请求 B~Z (Decode, 200 个 token) → 紧接 200 位 + 共 712 位 +Dummy Token × 312 → 补齐到 1024 +``` + +**MXU 阶段**:对脉动阵列来说不存在身份差别。一个巨大 `[1024, D] × [D, 4D]` 矩阵乘法全速冲过去,1024 个 token 的 Q/K/V 一次性算出。 + +**Attention 阶段**:到这里就糊弄不过去了—— + +| Token 类型 | 该怎么 Attend | +|---|---| +| Prefill 的 512 个 | **互相**做 Attention,生成新 KV 写回页表 | +| Decode 的 200 个 | 各自用自己的 1 个 Token 去 attend 自己的历史 KV Cache | +| Dummy 的 312 个 | 跳过 | + +CPU 同时传 metadata:`seq_lens = [512, 1, 1, ..., 0, 0]`。Pallas 算子在底层解析 metadata,对 Prefill 块走 FlashAttention 逻辑(Q 向量在 UB 里互相点乘 + 写新 KV),对 Decode 块走 PagedAttention(gather 历史 KV),对 Dummy 块跳过。 + +#### 12.5 GPU vs TPU 混合的差异 + +| | GPU | TPU | +|---|---|---| +| 拼装方式 | 动态:712 → kernel 收 712;下次 850 → 收 850 | 静态:712 → 强行加 312 个 dummy → 1024 | +| 硬件成本 | 调度器开销 | Padding 的 MXU cycle | +| 软件成本 | kernel 灵活性高 | Pallas metadata 路由复杂 | + +**Trade-off**:Padding 浪费一小部分 MXU cycle 是可控的,但避免了重新编译 XLA 图的灾难,同时控制延迟抖动。两害相权取其轻。 + +#### 12.6 ↔ GPU + +GPU 玩"网约车":712 个乘客就发 712 座的车,绝不拉空座。 +TPU 玩"高铁直达专列":班次到点就发,不够人就放假人。 + +--- + +### 13. 
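> **[补充 — Claude 加]** 12.4 的「静态 1D 展平」在 CPU 端的拼装逻辑,按上文数字(1024 / 512 / 200)直接写出来是下面这样。示意草稿:只演示输入张量和 metadata 怎么拼,不涉及真实的 tokenizer、页表或 XLA 下发:

```python
import numpy as np

MAX_TOTAL_TOKENS = 1024   # 一个 step 最多吞的 token 数(编译期定死)
MAX_SEQS = 256            # 最多并发序列数
HIDDEN = 8                # 真实是几千维,这里缩小

prefill_chunk = np.ones((512, HIDDEN), dtype=np.float32)        # 请求 A 的一个 512-token chunk
decode_tokens = np.full((200, HIDDEN), 2.0, dtype=np.float32)   # 200 个 decode 请求,各 1 个新 token

used = len(prefill_chunk) + len(decode_tokens)                  # 712
dummy = np.zeros((MAX_TOTAL_TOKENS - used, HIDDEN), dtype=np.float32)

flat_input = np.concatenate([prefill_chunk, decode_tokens, dummy])   # 形状永远是 [1024, HIDDEN]
seq_lens = [512] + [1] * 200 + [0] * (MAX_SEQS - 201)                # 每个并发序列一个条目,交给 Pallas 做路由

assert flat_input.shape == (MAX_TOTAL_TOKENS, HIDDEN)
assert len(seq_lens) == MAX_SEQS and sum(seq_lens) == used
print(flat_input.shape, used)   # (1024, 8) 712
```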
KV / 内存层次 + +**一句话**:GPU 体系里的 RDMA / GDS / KV offload 在 TPU 上有的天生支持、有的不支持、有的只能走 PCIe 后备。 + +#### 13.1 三件事的对照 + +| 优化技术 | GPU | TPU | +|---|---|---| +| **跨节点通信绕开 CPU**(GPU: RDMA over IB) | GPUDirect RDMA + IB 网卡 | **天生集成**:ICI 网络控制器直接做进 TPU 硅片,根本不需要外部网卡。CPU 完全不知情,零 CPU 周期 | +| **直读存储到加速卡**(GPU: GPUDirect Storage) | NVMe → PCIe → GPU VRAM,绕过 CPU 内存 | **不支持**。TPU 没有直接读硬盘 / 外部网络的接口,必须 CPU 中转:GCS / PD → Host VM DDR → PCIe DMA → TPU HBM | +| **KV Cache offload 到 host DRAM**(GPU: vLLM CPU swap) | HBM → PCIe → CPU DDR | **完全适用**。vLLM Block Manager 在 host CPU 上跑,HBM 满了触发 PCIe DMA 把 KV Block 搬到 host DDR | + +#### 13.2 ICI 比 RDMA 更彻底 + +GPU 的 RDMA:数据 → PCIe → NIC → IB 网络 → NIC → PCIe → VRAM。绕过 CPU 内存,但还是要离开 GPU 芯片走外部网卡。 + +TPU 的 ICI:数据 → 芯片边光模块 → 光纤 → 对端光模块 → 芯片。**根本不离开芯片硅片体系**。每秒几百 GB 的网络风暴里 host CPU 完全不知情。 + +#### 13.3 直读存储缺失为什么能忍 + +冷启动加载权重要把几百 GB 数据从 GCS 拉到 TPU HBM,必须 host CPU 当搬运工。但因为大模型部署是 multi-host 的(如 16 台 VM),16 台机器 CPU **并发**从 GCS 下载权重的不同 shard,总网络带宽极大,工程上可接受。 + +#### 13.4 KV Offload 性能特征 + +PCIe 带宽相对 ICI 来说**很窄**(v4 PCIe Gen4 x16 ≈ 64 GB/s 双向,而 ICI 单链路就能上几百 GB/s)。所以频繁 swap 会显著拖性能。这是防 OOM 的保底策略,不是首选。 + +#### 13.5 ↔ GPU + +GPU 的优势在 GDS(直读存储),TPU 的优势在 ICI(更彻底的网络旁路)。两边各占一半。 + +> **[补充 — Claude 加]** 业界开始有 **Mooncake 类的"分离式 KV pool"**(KV Cache 单独服务化、跨节点共享),目前主要在 GPU 体系。TPU 上对应的工作没看到公开方案。这条不写进正文,仅供你 cluster TL 视角参考。 + +--- + +### 14. Gemini 在 TPU 上的实战妥协 + +**一句话**:MoE 和投机解码这两个推理优化,在 TPU 上都得改算法去迁就硬件。 + +#### 14.1 静态化 MoE:Capacity Factor + +MoE 的天然问题是**动态路由**——你不知道下一个 token 会去找哪个 expert。XLA 不允许动态形状。 + +Gemini 团队的解法: + +- 给每个 expert 设一个严格的 **Capacity Factor**(容量因子 / 静态槽位大小),比如规定每个 expert 一个 step 最多接 64 个 token +- 路由过来的 token 不足 64 个 → 塞 Dummy Padding 凑齐 +- 路由过来的 token 超过 64 个 → 多出的**直接丢弃**(Token Dropping),或者强制走兜底通用网络 + +通过这种粗暴的截断 + 填充,MoE 的动态网络被强制拍平成 XLA 喜欢的静态计算图。 + +代价:偶尔丢 token,模型质量会受影响。Google 在 Gemini 训练时调 Capacity Factor 平衡丢弃率和算力开销。 + +#### 14.2 张量化投机解码:Tree Attention + +传统投机解码:小模型猜 K 个 token → if-else 判断大模型是否同意 → 拒绝则回退。这种 if-else 流让 TPU VLIW 流水线崩溃。 + +Gemini 的解法(**并行验证**): + +- 小模型生成 K 个猜测 token 后,大模型把这 K 个拼成一个 1D 向量 +- 设计一个特殊的 **Tree Attention Mask**(树状注意力掩码),让不相关的节点互相看不到(Mask 乘 0) +- 大模型在一次前向传播里**用一个矩阵乘法一口气把 K 个 token 的概率全部验证完** +- 把猜对的那条路径挑出来,其他废弃 + +这把"分支代码"变成了"小规模 Prefill 矩阵运算"。TPU 的 MXU 又狂喜了。 + +#### 14.3 为什么这些妥协在 GPU 上不是必须 + +- GPU 的 SIMT 调度器擅长 if-else(虽然分支发散会浪费 lane,但比 TPU 强得多) +- GPU 的动态显存可以容忍 expert 容量浮动 +- 所以 GPU 上跑原始版 MoE 路由 + 原始版投机解码也能工作 + +GPU 路线图里也在朝 Tree Attention 等张量化方向走,但不是被硬件逼的,是为了榨更多性能。 + +#### 14.4 一个有趣的趋势 + +正因为硬件极度讨厌分支(TPU 完全不能容忍,GPU 也不喜欢 host-device 高频同步),算法工程师在**绞尽脑汁把控制流(逻辑分支)改写成数据流(矩阵 mask)**。Tree Attention、Masked Attention、谓词执行 (Predication) 都是这个思路的不同形态。 + +核心哲学:**全算再扔比 if-else 便宜**。计算规模有限的浅层分支(投机解码 3-5 步、Causal Mask 半个矩阵)这个套路超划算;但深层嵌套(10+ 层条件树)就 $O(2^N)$ 爆炸。所以新一代芯片设计在追 **硬件原生稀疏支持**——掩码全 0 时电路在物理层面跳过乘加运算(时钟门控断电),既不写 if-else 也不耗能。 + +#### 14.5 ↔ GPU + +| 优化 | TPU 必须改 | GPU 可以不改 | +|---|---|---| +| MoE | Capacity Factor 静态化 | 动态路由可工作 | +| 投机解码 | Tree Attention 张量化 | if-else 也能跑(性能差点) | + +--- From a40bd9a7e53627b0492e52e7b81980b36c35c3df Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 6 May 2026 02:17:49 +0000 Subject: [PATCH 4/8] docs(notes): Fill Parts IV, V, and appendices of TPU deep dive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part IV (cluster orchestration, chapters 15-16): - K8s TPU abstraction: device plugin, topology labels, slice CRD, Kueue gang scheduling, TPU Provisioner calling OCS API, OCS-based failure self-healing - Multi-host slice orchestration: LWS/JobSet, SPMD startup mode, failure-radius analysis 
Part V (system comparison, chapters 17-19): - Programming model chains: GPU's three-segment (CUDA -> NCCL -> IB/RDMA) vs TPU's single-segment (SPMD + ICI as VLIW slot 5), with concrete pseudo-instruction streams and the intersection/highway/factory analogies - Cost and efficiency: NVIDIA Tax breakdown, MFU comparison, Performance/Watt, real-world TCO cases (Midjourney, Character.AI, Waymark) - TPU hardware weaknesses: MXU granularity penalty, weak SPU and speculative-decoding pain, MoE All-to-All congestion on Torus, Decode big-batch dead-ends (KV OOM + token dispersion), HBM bandwidth/compute imbalance, what each weakness buys Appendices: - A: Trade-off cheat sheet indexed by design dimension - B: Numbers and parameters list (all annotated as "from source") - C: TPU-to-GPU terminology mapping Writing log: - 10-item intentional-cuts list for user review - 5 external "[补充 — Claude 加]" callouts collected with chapter references Chinese version is now content-complete. English mirror to follow. Co-authored-by: Claude --- notes/tpu-deep-dive.cn.md | 456 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 456 insertions(+) diff --git a/notes/tpu-deep-dive.cn.md b/notes/tpu-deep-dive.cn.md index d8a4a88..cce2c92 100644 --- a/notes/tpu-deep-dive.cn.md +++ b/notes/tpu-deep-dive.cn.md @@ -937,3 +937,459 @@ GPU 路线图里也在朝 Tree Attention 等张量化方向走,但不是被硬 | 投机解码 | Tree Attention 张量化 | if-else 也能跑(性能差点) | --- + +## Part IV — 集群层适配(目标 D) + +### 15. K8s 上的 TPU 抽象 + +**一句话**:K8s 看不见光,所以 OCS 切分必须由独立组件负责,TPU device plugin + 拓扑标签 + Kueue + TPU Provisioner 组成完整链路。 + +#### 15.1 资源暴露:从芯片到 Node + +物理上 TPU 芯片不直接是 Node。每颗 TPU(或 4/8 颗一组的主板)挂在一台 Host VM 上,VM 上跑 Kubelet。 + +- **TPU Device Plugin**:Kubelet 加载,向 API Server 汇报扩展资源 `google.com/tpu: 4` +- **拓扑标签**:仅知道有几颗 TPU 不够,TPU Controller 给 Node 打详细 label: + +```yaml +cloud.google.com/gke-tpu-topology: 2x2x4 # 当前节点属于一个 16 芯切片 +cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice +``` + +K8s etcd 里这些带 label 的 Node 组成逻辑资源池。 + +#### 15.2 用户接口:Slice CRD + +不能写普通 Deployment,要用包装好的 Job 或专用 `TPUSlice` CRD: + +```yaml +nodeSelector: + cloud.google.com/gke-tpu-topology: 4x4x4 +resources: + requests: + google.com/tpu: 64 +``` + +意思:「我要 64 颗 TPU,物理连线必须是 4×4×4 闭环网络。」 + +#### 15.3 Gang Scheduling:Kueue + TPU Provisioner + +原生 `kube-scheduler` 做不到「要么同时拿到 64 个特定拓扑节点,要么一个都不要」。这是**帮派调度(Gang Scheduling)**的活,必须靠 Kueue 这类批处理调度器。 + +完整拉起一个 64 芯切片的链路: + +| Step | 谁干 | 干什么 | +|---|---|---| +| 1 | Kueue | 拦截 Job,发现要 4×4×4。资源池凑不齐 → Pending | +| 2 | Kueue / Cluster Autoscaler | 决定调度时通过 TPU Provisioner 触发 | +| 3 | **TPU Provisioner** | **绕过 K8s 控制面**,直接调底层数据中心的 OCS 硬件 API:「转镜子,给我拼 4×4×4 Torus」 | +| 4 | OCS | 几秒内调好镜片角度,光路锁定 | +| 5 | Kubelet | 探测到 ICI 链路接通,更新 Node Label | +| 6 | Kueue | 确认硬件就绪,把 64 个 Pod 一次性绑定到 64 个 Node | + +**关键**:TPU Provisioner 是 K8s 之外、连到数据中心物理层的独立组件。这是「K8s 看不见光」必然要做的妥协。 + +#### 15.4 故障自愈:OCS 绕过坏点 + +长达几个月的预训练或高可用推理服务里,TPU 硬件(HBM 降级、光模块烧毁)故障必然发生。 + +| 步骤 | 动作 | +|---|---| +| 1 | Kubelet 健康检查脚本发现硬件报错,向 API Server 报 Node 故障 | +| 2 | K8s 终止整个 Job 的 64 个 Pod | +| 3 | TPU Controller 标记坏芯片为 Unhealthy,从可用池剔除 | +| 4 | Controller 再呼叫 OCS:「绕过那个坏点,从池子里拉一颗新芯片,重新调几面镜子,再缝合一个 4×4×4 环」 | +| 5 | 光路缝合完毕,Job 重启,加载上一个 checkpoint | + +整个从硬件损坏到重新调度通常**几分钟内**全自动闭环。 + +#### 15.5 ↔ GPU + +| 维度 | TPU on GKE | GPU on K8s | +|---|---|---| +| Device Plugin | TPU Device Plugin | NVIDIA Device Plugin | +| 拓扑感知 | 标签 + TPU Provisioner 调 OCS | NVIDIA Topology Aware Scheduling、`gpu-feature-discovery` | +| Gang Scheduling | Kueue | Volcano、Kueue、KubeFlow | +| 物理拓扑重配 | OCS 动态拼接 Torus | (无对应;NVLink 是死的,IB 
fat-tree 不需要重配) | + +**Trade-off**:TPU 的拓扑硬约束让 K8s 抽象更复杂(必须拉个 Provisioner 出来),但换来动态切片能力。GPU 简单很多,因为没光要管。 + +--- + +### 16. Multi-host slice 的编排 + +**一句话**:一个 slice 跨多 host 时,K8s 看到的是 N 个 Pod 的协同启动,每个 Pod 内的 TPU 又是 4 / 8 个芯片的本地组——这是两层 N:N。 + +#### 16.1 LWS 与 JobSet + +Kubernetes 官方为了表达这种 multi-host 算力,提供两个 API: + +- **LeaderWorkerSet (LWS)**:一个 Leader Pod + 多个 Worker Pod,整组生命周期一致。Leader 通常对外暴露推理 API +- **JobSet**:多组 Job 的协同管理,更通用 + +64 芯 v4-64 切片对应: + +``` +LWS: + size: 16 # 16 台 Host VM + leader: + replicas: 1 + containers: # 跑 vLLM API + 调度器 + worker: + replicas: 15 + containers: # 跑同一份 Python 代码(SPMD) +``` + +#### 16.2 SPMD 启动模式 + +16 个 Pod 跑**同一份 Python 代码**(SPMD = Single Program Multiple Data)。代码内部读环境变量(`LWS_LEADER_ADDRESS`、`LWS_WORKER_INDEX`)确定身份,然后: + +- Leader:起 HTTP/gRPC 服务接客 +- 所有 Pod:通过 PyTorch/XLA 或 JAX 初始化分布式 group,自动建立 ICI 通信 +- 所有 Pod 上的 CPU 同时调度自己脚下 4 颗 TPU 做计算 +- 64 颗 TPU 通过 ICI 完成 All-Reduce 等集合通信 +- Leader 收集结果返回客户 + +#### 16.3 调度耦合点:哪一层 fail 会拖整个 slice + +| 故障层 | 后果 | +|---|---| +| 单颗 TPU 物理坏 | 整个 ICI 环网断 → 64 个 Pod 集合通信挂起 → Job 失败 → 走 15.4 自愈流程 | +| 单台 Host VM 网卡或 Kubelet 故障 | 该 Pod 上的 4 颗 TPU 失联 → 同上 | +| DCN(CPU 间以太网)抖动 | Leader↔Worker 控制面同步延迟 → 推理服务延迟抖动,但数据面(ICI)不受影响 | +| 单台 VM CPU 过载 | 该 Pod 喂数据慢 → 拖整个 step(最慢的那台决定整体速度) | + +故障半径基本上是 **slice 级别**——任何一个 host fail 都让 64 颗 TPU 无法继续。所以 LWS / JobSet 用 gang scheduling,要么全活要么全死。 + +#### 16.4 ↔ GPU + +| 维度 | TPU multi-host | GPU multi-host | +|---|---|---| +| 编排 API | LWS / JobSet | MPI Operator、Training Operator、Ray on K8s | +| 启动模式 | SPMD | MPI / NCCL group | +| 故障半径 | Slice 级(光环必须完整) | 通常 job 级(fat-tree 容忍单 NIC 故障) | +| Leader 角色 | 通常做 API gateway | MPI rank 0 或 PS-Worker 中的 PS | + +**Trade-off**:TPU 的 SPMD + LWS 模型简洁,但故障半径大;GPU 的 MPI 模型灵活但配置繁琐。 + +--- + +## Part V — 系统对比与权衡(目标 B 集中点) + +### 17. 
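> **[补充 — Claude 加]** 16.2 的「同一份代码、靠环境变量确定身份」的启动骨架大致如下。示意草稿:环境变量名沿用上文提到的 `LWS_LEADER_ADDRESS` / `LWS_WORKER_INDEX`,分布式 group 与 ICI 的初始化由 PJRT / 框架完成,这里用注释代替:

```python
import os

def detect_role():
    leader_addr = os.environ.get("LWS_LEADER_ADDRESS", "localhost")
    worker_index = int(os.environ.get("LWS_WORKER_INDEX", "0"))
    return leader_addr, worker_index, worker_index == 0

def main():
    leader_addr, idx, is_leader = detect_role()
    # 所有 Pod 都执行到这里:初始化分布式 group、建立 ICI 通信(由 PJRT / XLA runtime 负责)
    print(f"worker {idx}, leader={is_leader}, leader_addr={leader_addr}")
    if is_leader:
        pass   # 只有 Leader 起 HTTP/gRPC 服务接客;Worker 进入等待广播的循环

if __name__ == "__main__":
    main()
```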
编程模型链:从单卡到多机的指令链 + +**一句话**:GPU 是「单卡 CUDA → 多卡 NCCL → 多机 IB/RDMA」三段;TPU 是「SPMD → ICI(VLIW 第五槽)」一段,编译器统管。 + +#### 17.1 GPU 的三段式 + +**单卡(Tensor Core)** + +``` +Python (PyTorch) → ATen → cuBLAS / Triton → SASS / PTX + ├ LDG.E (HBM → 寄存器) + ├ STS / LDS (Shared Memory 中转) + └ HMMA.1688.F16 (Tensor Core 触发 16×8×16 半精度乘加) +``` + +数据流:**HBM → 寄存器 → Shared Memory → 寄存器 → Tensor Core**,反复横跳。 + +**节点内多卡(NVLink + Copy Engine)** + +``` +NCCL → CUDA Kernel + Copy Engine +DMA: GPU0_HBM → NVLink 总线 → GPU1_HBM +Reduction: GPU1 SM 做 LDG / FADD / STG(数据落地 HBM 后才加) +``` + +NVLink 提供统一虚拟寻址(UVA),但**数据必须先落地到接收方 HBM**,然后接收方 SM 从 HBM 读出来加,再写回 HBM。吃掉大量 HBM 带宽。 + +**节点间多机(IB + GPUDirect RDMA)** + +``` +MMIO 写 NIC Doorbell → NIC DMA 读 GPU HBM → IB 包封装 → 经 Spine-Leaf 交换机 → 对端 NIC 解包 → DMA 写对端 HBM +同步: 接收端 GPU CUDA Kernel 在 HBM 同步 flag 上 spin-wait(LDG.CG 绕过 cache) +``` + +控制面有 CPU 中断、协议封装、路由查询;数据面有交换机拥塞控制和排队。**异步事件驱动**。 + +#### 17.2 TPU 的一段式 + +``` +JAX / PyTorch (via PyTorch/XLA) → HLO → XLA → VLIW 五槽指令流 + ├ DMA (HBM → UB) + ├ MXU (脉动阵列乘加) + ├ VPU (向量运算) + ├ ICI (跨芯片通信,跟 DMA 平级) + └ SPU (控制流) + +跨芯片通信: + TX_UB_TO_NEIGHBOR (Src: UB_local, Dest_Node: Neighbor_ID) + WAIT_ICI_RX_SEMAPHORE + ADD_VECTOR (Src1: UB_local, Src2: UB_remote, Dest: UB_result) +``` + +跨芯片传数据和片内搬砖**没有逻辑区别**,都是 VLIW 指令字上的一个开关。XLA 可以让 MXU 算乘法的同时 ICI 传上一层梯度,时钟周期级 overlap。 + +#### 17.3 对比 + +| 维度 | GPU | TPU | +|---|---|---| +| 跨节点通信本质 | 异步 IO(CPU 中断、协议、交换机) | 同步指令(VLIW 一槽) | +| 接收端处理 | spin-wait HBM flag | 硬件信号量瞬间唤醒 | +| 数据落地点 | HBM(必经收费站) | UB(直接送 VPU) | +| 编译器视角 | 看不到跨节点 | 看到所有跳数和延迟 | +| 故障容忍 | 节点级隔离 | Slice 级紧耦合 | + +#### 17.4 三种比喻 + +- **单机 GPU 计算**:极度繁忙的立交桥路口(庞大缓存 + 调度器)。拥堵但靠精妙红绿灯(Warp Scheduler)保证吞吐 +- **跨机 GPU RDMA**:跨省高速公路物流。打包封装 → 上高速 → 下高速。路桥费(协议开销)+ 不可预知塞车(拥塞) +- **TPU Pod (VLIW + ICI + 3D Torus)**:极其庞大的全自动流水线工厂。所有传送带(光纤)和机械臂(MXU/VPU)都是硬连线。XLA 是排班表,确保每个零件在精准的周期到达指定工位 + +--- + +### 18. 成本 / 能效 + +**一句话**:MFU 和 Tokens/$ 这两个指标是衡量真实账面差异的杠杆,不是芯片峰值算力。 + +#### 18.1 算力单价:NVIDIA Tax + +| 维度 | NVIDIA H100 | Google TPU v5p | +|---|---|---| +| 等效算力硬件成本(业界拆解估算) | **~$21,000** 以上 | **~$6,900** | +| 中间 ~3× 差价 | 俗称 **NVIDIA Tax(英伟达税)** | | + +云端按需价: + +| 形态 | 单价 | +|---|---| +| 8 卡 H100 VM(Azure / GCP) | **$100 - $120 / 小时**(单芯 ~$12-$15) | +| TPU v5e(推理专用) | **~$1.20 / 小时**(单芯) | +| 8 芯 v5e 节点 | **$10 - $11 / 小时** | + +#### 18.2 MFU(Model FLOPs Utilization) + +实际跑出的 TFLOPs ÷ 硬件理论峰值: + +| 芯片 | LLM 训练典型 MFU | +|---|---| +| H100 | 50% - 52%(线程调度、缓存竞争、复杂控制流损耗) | +| TPU v5p | 58% - 60% 甚至更高(XLA 静态编排 + ICI 确定性延迟) | + +#### 18.3 能效比(Performance / Watt) + +- **H100 TDP 700W**:要给 L1/L2 cache 和乱序调度器供电 +- **TPU 砍掉了这些模块**,全靠极简 MXU 脉动阵列。统计显示对特定大规模矩阵 workload,TPU v5e/v5p 能耗比 GPU 集群低 **60% 到 65%**(某些场景能效优势达 H100 的 2-5 倍) + +省电费 + 降低数据中心散热 / 基建成本。 + +#### 18.4 真实业务的 Tokens / Dollar + +| 案例 | 改善 | +|---|---| +| **大模型预训练**(H100 → TPU v5p) | Tokens / $ 高 **15% - 25%** | +| **Midjourney**(图像生成,迁到 TPU v6e) | 推理账单从 **$2.1M / 月** 骤降至 **<$0.7M / 月**,降本 **3 倍**,吞吐不变 | +| **Character.AI**(高并发对话) | 转 TPU 后成本改善 **3.8 倍** | +| **Waymark**(视频扩散模型) | 成本比 H100 低 **4 倍** | + +#### 18.5 一句话总结行业格局 + +- 需要快速迭代、频繁改算子、依赖 PyTorch/CUDA 复杂生态的研究团队 → GPU +- 超大规模、结构稳定的预训练 / 亿级用户高并发 LLM 推理 → TPU 集群 + +> **[补充 — Claude 加]** 上述案例(Midjourney $2.1M→$700K、Character.AI 3.8×、Waymark 4×)的具体出处和年份在原对话里没给,仅说"公开报告显示"。我没有别的可信来源验证具体数字,请你核实。 + +--- + +### 19. 
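> **[补充 — Claude 加]** 18.2 / 18.4 的两个口径用几行算术就能表达。价格取上文区间的中值;吞吐量是编造的演示值,只为展示「怎么算」,不代表任何实测结果(两种硬件的真实吞吐并不相同):

```python
def mfu(achieved_tflops, peak_tflops):
    """Model FLOPs Utilization = 实际吞吐出的 TFLOPs / 硬件理论峰值。"""
    return achieved_tflops / peak_tflops

def tokens_per_dollar(tokens_per_second, hourly_price_usd):
    """Tokens / $ 的一种口径:稳定吞吐 × 3600 秒 ÷ 按需小时价。"""
    return tokens_per_second * 3600 / hourly_price_usd

demo_throughput = 10_000   # tokens/s,假设值
print(f"8xH100 VM : {tokens_per_dollar(demo_throughput, 110.0):,.0f} tokens/$")
print(f"8xv5e 节点 : {tokens_per_dollar(demo_throughput, 10.5):,.0f} tokens/$")
```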
TPU 的硬件劣势与权衡 + +**一句话**:每个静态调度的优势都对应一个不擅长的工作负载——MXU 粒度大、SPU 弱、3D Torus All-to-All 拥塞、HBM 带宽 vs 算力失衡。 + +#### 19.1 计算粒度:MXU 128×128 的"碎片化"惩罚 + +GPU Tensor Core 16×8×16,TPU MXU 128×128。当 batch size 是 5 这种不能对齐 128 的真实需求: + +- GPU:Warp Scheduler 把碎片紧凑塞进 SM,硬件利用率依然不低 +- TPU:硅片上 MXU 大量物理 ALU **真的在算 0×W=0**,白白消耗 cycle + +在高频、小并发的低延迟推理请求下,TPU 的物理算力被严重浪费。 + +#### 19.2 标量与分支控制羸弱:采样与投机解码的痛点 + +LLM 不只有矩阵乘——生成 token 的最后一步是**采样**(Top-K、Top-P、温度、惩罚因子),涉及大量排序、条件分支、标量运算。 + +- GPU:海量 CUDA Core + SIMT 分支预测,能并发万级线程做带逻辑判断的数组操作 +- TPU:SPU 算力极弱、VPU 只擅长规整向量。面对 if-else 采样效率极低 + +更致命的场景是**投机解码原始版**——硬件要在极短时间动态判断哪些 token 猜对、随时丢弃部分计算图。这种"走一步看一步"的动态计算图是 TPU VLIW 静态指令集的克星。所以 Gemini 用 Tree Attention 张量化(第 14.2 节)把这类计算硬掰成矩阵乘。 + +#### 19.3 动态网络路由:MoE All-to-All 的拥塞 + +MoE 的核心是**动态路由**——每个 token 在运行时被路由到不同 expert: + +| 集群类型 | All-to-All 表现 | +|---|---| +| GPU(NVSwitch + IB fat-tree) | 任意 N:N 通信都能提供无阻塞全交叉带宽,对 MoE 这种乱序、动态的数据包分发友好 | +| TPU(3D Torus 环网) | 静态 All-Reduce 天下无敌;但 All-to-All 时 token 要跨 X/Y/Z 多个中继芯片找 expert,**部分链路被挤爆,其他闲置**,端到端延迟拖慢 | + +#### 19.4 Decode 大 Batch 救不了 MoE:两个死局 + +直觉:增大 batch size 让 MoE 通信密度上升、计算可以掩盖通信。但有两个物理死局: + +**死局一:Token 被打散,矩阵依然小** + +并发 512 个请求 → MoE 层 → 路由给 8 个 expert → 平均每个 expert 只分到 64 个 token。`[64, D] × [D, 4D]` 对 MXU 来说是「塞牙缝」,远没到能掩盖通信延迟的体量。 + +**死局二:KV Cache 撑爆 HBM** + +不可能无限增大 batch。如果开到能让 MoE 算力饱和的程度(几千),HBM 早就 OOM 了。 + +**结论**:Decode 阶段 MoE 通信开销只能缓解,不能完全掩盖。 + +#### 19.5 Multi-hop 影响延迟还是带宽 + +在 3D Torus 上跨多机架找 expert: + +- **小 batch 时**:主要影响**延迟**。光信号中继 + 收发的物理延迟逃不掉 +- **大 batch 时**:致命的是**双向对分带宽(Bisection Bandwidth)**。token 散乱地往各方向挤,部分光纤瞬间过载,实际可用带宽暴跌 + +#### 19.6 计算掩盖通信的数学条件 + +为什么 Prefill 能掩盖通信、Decode 不能?看维度: + +| 阶段 | 计算量 | 数据传输量 | 比值 | +|---|---|---|---| +| **Prefill**(GEMM) | $O(N^3)$ | $O(N^2)$ | 计算时间 ≫ 网络时间,DMA 后台搬数 MXU 完全无感 | +| **Decode**(GEMV) | $O(N^2)$ | $O(N^2)$ | 比值约为 1,MXU 瞬间算完后只能停机干等 | + +#### 19.7 HBM 带宽 vs 算力失衡(呼应 6.4) + +第 6 章的物理定律:算力 $O(N^2)$ 涨,传统带宽 $O(N)$ 涨。v4 时代严重失衡,Decode MFU 跌到个位数。 + +Google 的硬件级补救:**v5e 故意缩小 MXU**,降低峰值 FLOPs,让 compute / bandwidth 比例回到健康区间。 + +算法级补救:**MQA / GQA**(多查询 / 分组查询注意力)——大幅缩小 KV Cache 体积,减少每次 Decode 从 HBM 捞数据的压力。这纯粹是为了迁就可怜的内存带宽对模型架构做的让步。 + +#### 19.8 每条劣势换来了什么 + +| 劣势 | 换来 | +|---|---| +| MXU 粒度大 | 同硅面积塞下更多算力单元,能效高 | +| SPU 弱 | 省下的晶体管堆给 MXU | +| 3D Torus 不擅长 All-to-All | 极致的规整集合通信效率,无外部交换机开销 | +| HBM 带宽跟不上 | 算力密度爆表,跑 Prefill / 训练时 MFU 高 | + +每条短板都对应一个 trade-off。理解了这些权衡才能判断什么 workload 该上 TPU、什么该留 GPU。 + +--- + +## 附录 + +### A. Trade-off 速查表 + +按设计决策维度横切,每条 trade-off 链接回原章节。 + +| 维度 | TPU 选择 | 代价 | 收益 | 章节 | +|---|---|---|---|---| +| **静态 vs 动态调度** | 静态 VLIW + XLA | 编译开销大、shape 变化要重编 | 硬件极简、能效高 | Ch 1, 7, 8 | +| **缓存 vs 直传** | 取消硬件 Cache,用 UB + DMA | 软件管理复杂 | 节省硅面积给 MXU | Ch 1, 6 | +| **粒度大 vs 小** | MXU 128×128 | 小矩阵浪费 | 大矩阵密度高 | Ch 1, 19 | +| **环 vs 树** | 3D Torus | All-to-All 拥塞、长边大圈延迟 | All-Reduce 极致、无外部交换机 | Ch 2, 4, 19 | +| **物理 vs 光路** | OCS 重配 | 切分粒度受机架限制 | 拓扑动态可重配、故障自愈 | Ch 3, 15 | +| **集中 vs 分布编排** | Multi-host SPMD | 故障半径 = slice | 编排简洁、SPMD 透明 | Ch 5, 16 | +| **算力 vs 带宽** | 算力堆面积,带宽吃封装 | 内存墙、Decode MFU 低 | Prefill / 训练 MFU 高 | Ch 6, 19 | +| **专用 vs 通用** | Pallas 写 PagedAttention | 工程门槛高,每个动态算子都要手写 | 突破 XLA 静态限制 | Ch 11, 14 | +| **Padding vs 重编译** | 保 batch 桶,不够用 dummy | 浪费小部分算力 | 避开重编译灾难 | Ch 8, 12 | +| **算法迁就硬件** | Capacity Factor、Tree Attention | 模型架构有 token dropping、增加 mask 复杂度 | TPU 上能跑通 MoE 和投机解码 | Ch 14 | + +### B. 
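> **[补充 — Claude 加]** 19.1 的「batch=5 也要 padding 到 128」的浪费可以量化一下。纯算术示意:只计 padding 带来的无效乘加,不含访存等其他损耗:

```python
def mxu_utilization(m, k, n, tile=128):
    """把 [m,k]×[k,n] 的矩阵乘 padding 到 tile 整数倍后,有效 MAC 占比。"""
    pad = lambda x: ((x + tile - 1) // tile) * tile
    return (m * k * n) / (pad(m) * pad(k) * pad(n))

for bs in (5, 64, 128, 256):
    print(f"batch={bs:>3}: 有效利用率 ≈ {mxu_utilization(bs, 4096, 4096):.1%}")
# batch=5 时只有约 3.9% 的乘加是"真"的,其余在算 0;到 128 的整数倍才接近 100%
```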
数字 / 参数清单 + +所有数字都标"源自原对话"。 + +| 项 | 值 | 说明 | +|---|---|---| +| TPU v4 Pod 规模 | **4096 颗芯片** | 64 机架 × 64 芯 | +| TPU v5p Pod 规模 | **8960 颗芯片** | 满规模 | +| 单机架芯片数 | **64 颗**(v4 水冷) | 4 颗/板 × 16 板 | +| MXU 尺寸 | **128 × 128** MAC 单元 | v4 / v5p | +| 4×4×4 机架对外光纤数 | **96 根** | 8 角×3 + 24 棱×2 + 24 面心×1 | +| 表面 TPU 数 | **56 颗** | 64 - 内部 8 颗 | +| 等效 H100 算力硬件成本 | **~$21,000** | NVIDIA 售价 | +| 等效 TPU v5p 硬件成本 | **~$6,900** | Google 内部 | +| 8×H100 VM 按需价 | **$100-120 / 小时** | Azure / GCP | +| 8×TPU v5e 节点按需价 | **$10-11 / 小时** | Google Cloud | +| H100 TDP | **700W** | | +| H100 训练 MFU | **50% - 52%** | 大型 LLM 集群 | +| TPU v5p 训练 MFU | **58% - 60%** 或更高 | 同等任务 | +| TPU 能效优势 | 比 GPU 集群低 **60-65% 能耗** | 特定大规模矩阵 workload | +| Tokens / $ 优势(预训练) | TPU 高 **15% - 25%** | H100 → v5p | +| Midjourney 推理账单 | **$2.1M → <$0.7M / 月** | 迁 TPU v6e,3× | +| Character.AI 成本改善 | **3.8×** | 转 TPU | +| Waymark 视频生成 | **4×** | 比 H100 低 | +| CPU:TPU 比例 | **1:4 或 1:8** | 物理焊死 | +| ICI 单链路带宽量级 | (原对话未给数字) | 公开资料 v4 ≈ 4.5 TB/s 多向合计,待你确认 | +| H100 L2 Cache | **50 MB** | 上下文中提到 | +| MoE Capacity Factor 例 | 每 expert 64 个 token | 原对话举例 | +| v5p 长边可能尺寸 | **35** | 16×16×35 ≈ 8960 估算 | +| 跨机架 Z 轴 8 跳大圈 | **6 跳铜缆 + 2 跳光路** | 4×4×8 切片 | +| 光路 vs 铜缆延迟 | 光路 **几百 ns**,铜缆 **个位到十位 ns** | NUCA 异构 | + +### C. 术语 ↔ GPU 等价物对照 + +| TPU 术语 | GPU 等价物 | 备注 | +|---|---|---| +| **MXU** | Tensor Core | TPU 单个大阵列 vs GPU 多个小阵列 | +| **VPU** | CUDA Core(部分) | TPU 偏向规整向量 | +| **SPU** | 标量调度 + 寄存器 | TPU 控制流弱 | +| **Unified Buffer (UB)** | Shared Memory + L1/L2 | TPU 软件管理;GPU 硬件管理 | +| **HBM** | HBM | 一样 | +| **ICI** | NVLink + IB(合并对应) | TPU 节点内外同一套,GPU 分两层 | +| **OCS** | (无对应) | 唯一 | +| **3D Torus** | Fat-Tree(不同思想) | 拓扑哲学不同 | +| **SPMD** | NCCL collective + MPI rank | TPU 编译器统管 | +| **VLIW 五槽** | (无完全对应) | GPU 是 SIMT | +| **XLA** | TorchInductor / TVM / TensorRT | TPU 的核心;GPU 是可选 | +| **HLO / StableHLO** | FX Graph / TorchScript | XLA 的 IR | +| **Pallas** | Triton | 自定义 kernel 语言 | +| **JetStream** | TensorRT-LLM | 厂商特化推理引擎 | +| **Saxml** | DeepSpeed Inference(部分) | JAX 生态历史 | +| **TPU Provisioner** | (无对应) | 最接近的是 Slurm topology + 手工 NCCL | +| **TPU Device Plugin** | NVIDIA Device Plugin | K8s 概念一样 | +| **LWS / JobSet** | MPI Operator / Training Operator | Multi-host 编排 | +| **DCN(数据中心网络)** | 控制面 IB / Ethernet | TPU 用以太网,GPU 也常用 IB | +| **Capacity Factor**(MoE) | (无 TPU 特有的对应;GPU 上 Capacity Factor 概念存在但不强制) | 静态化 trick | +| **Tree Attention**(投机解码) | 同名 | 算法层;现在 GPU 也用 | +| **Bucketing + AOT** | torch.compile + persistent cache | TPU 必须;GPU 可选 | +| **NUCA 拓扑感知映射** | NCCL 拓扑发现 | TPU 由 XLA 自动;GPU 半手工 | + +--- + +## 写作日志(让作者验收用) + +### 主动取舍清单(原对话里有但没进笔记) + +下面这些细节我看到了但没收进正文,请你决定要不要加回。 + +1. **vLLM `model.generate()` 的 token 生成调用栈**(原对话第 850-870 行附近):详细描述了 `Lazy Tensor` 机制下 PyTorch eager 一行一行调 ATen 的过程。我觉得是对 Eager 模式的 GPU 端解释,不是 TPU 笔记重点,简化成了 Ch 8.1 表格里的"PyTorch Lazy Tensor"一行 +2. **GPU 单卡 SASS 指令的具体例子**(原对话第 700-720 行):`LDG.E`、`STS [Shared_Addr_A]`、`HMMA.1688.F16` 这些指令名我只在 Ch 17.1 用了一次。如果你想保留更详细的 GPU 指令对照可以扩 +3. **SPMD 启动方式的 Pod ID 读取细节**(原对话第 604 行):「代码内部读取硬件 Device_ID 决定加载哪块数据」。我在 Ch 16.2 简化成"读环境变量确定身份",没区分 TPU device ID 和 K8s pod 环境变量两个层级。要不要展开? +4. **TPU v6e(Trillium)**:原对话 1707 行提到这是「第六代」纯推理芯片,但没具体规格。Midjourney 案例里出现过它的名字。我没单开一节。如果你想要 v6e 专门的设计取舍可以扩 +5. **OCS 内部的 MEMS 阵列双层结构**(原对话 1521-1538 行):第一面镜子瞄准方向,第二面镜子做光束矫正、防止光纤入口角度偏差衰减。我在 Ch 3.1 合并成了「输入 → 反射 → 反射 → 输出」一段,没强调"双层 MEMS 必要性"。要不要细化? +6. **OCS 对外的物理布线(MPO/MTP 高密度光缆)**:原对话 2278 行讲了机房里 96 根光纤是用 MPO/MTP 一束 16/32 芯的高密度光缆汇聚。我提了一句但没展开物理形态。 +7. 
**Capacity Factor 的代价细节**:原对话 1628-1631 行提到 token dropping 或者「强制走兜底通用网络」。我只写了 dropping,没写兜底网络这条。 +8. **采样的具体细节**(Top-K / Top-P / 温度 / 惩罚因子):Ch 19.2 提了这是 SPU 弱点,但没展开各采样算法的硬件代价 +9. **Cerebras / Groq 的设计**(原对话 1500-1510 行):作为「别的厂在干什么」的旁支,提到了 Wafer-Scale 和 LPU。我没收进正文,因为不是 TPU 主题。要不要加个附录 D「行业横向对照」? +10. **"硬件级稀疏支持 / 时钟门控断电"(原对话 1810 行)**:作为下一代芯片的演进方向。我在 Ch 14.4 一笔带过了,没专门小节 + +### 外部补充清单(Claude 加,原文未提及) + +以下是我在正文里用 `> **[补充 — Claude 加]**` 显式标注的内容: + +- Ch 2.1:ICI 单链路 4.5 TB/s 量级(公开资料,原文未给数字) +- Ch 3.6:微软 Azure 部分集群试用 Lumen OCS(行业旁注,未进正文) +- Ch 6.4:v4 时代 MFU 跌到个位数无别的来源验证 +- Ch 13.5:Mooncake 类 KV pool 的提及(未进正文) +- Ch 18 末尾:Midjourney / Character.AI / Waymark 的具体数字源对话只说"公开报告显示",没给年份和出处,请核实 + +如有遗漏的取舍点你审 PR 时随时提出来,我在 PR 里直接补。 + From 73253dfdca3d2101f0d2dcd64302e1eb70aa7a66 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 6 May 2026 03:22:47 +0000 Subject: [PATCH 5/8] docs(notes): Add English mirror Part I (hardware layer) 1:1 structural translation of notes/tpu-deep-dive.cn.md Part I per GEMINI.md section 1: idiomatic + concise English, every chapter, section, bullet, and bolded term mirrored from the Chinese source. Six chapters covering single chip (MXU/VPU/SPU systolic array, walk-through, VLIW packs), inter-chip ICI + 3D Torus, OCS optical switching with 96-fiber geometry, 3D-decomposed Ring All-Reduce with NUCA, host integration including PCIe/NUMA/multi-host slice, and advanced packaging from Wire Bonding to TSV. External-knowledge callouts ported verbatim from Chinese for review consistency. Parts II-V and appendices to follow. Co-authored-by: Claude --- notes/tpu-deep-dive.md | 434 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 434 insertions(+) create mode 100644 notes/tpu-deep-dive.md diff --git a/notes/tpu-deep-dive.md b/notes/tpu-deep-dive.md new file mode 100644 index 0000000..03831d9 --- /dev/null +++ b/notes/tpu-deep-dive.md @@ -0,0 +1,434 @@ +# TPU Deep Dive: From Single Chip to Production Cluster + +> **Purpose**: A reference for myself spanning hardware principles all the way to inference and cluster-layer adaptation, focused on **why each thing is designed the way it is** rather than terminology explanation. +> **Comparative target**: Each chapter ends with a `↔ GPU` subsection for nearby comparison. +> **About "additions"**: Every block tagged `> **[Note — added by Claude]** ...` is content not present in the source conversation. Please review and decide whether to keep. + +--- + +## Part I — Hardware: Chip, Interconnect, Packaging + +### 1. The Single Chip: MXU, VPU, SPU, and the Systolic Array + +**One sentence**: TPU bakes matrix multiplication into a dedicated circuit (MXU); the systolic array "passes data" inside the chip rather than "passing addresses", which removes the cache hierarchy. + +#### 1.1 The Systolic Array: Stepping Around the von Neumann Bottleneck + +In a general-purpose architecture, each ALU op fetches operands from a register or L1 and writes back. LLM inference demands hundreds of GB/s of matmul throughput, and this "compute once, hit memory once" pattern gets pinned by bandwidth. + +TPU's break: arrange compute units into a 2D grid (the v4 MXU is 128×128 MAC cells) and let data flow through like a pumping heart: + +- **Weight stationary**: Before compute starts, weights are locked into each MAC cell. +- **Data systole**: Activations stream in from the left and from above, advancing right and down each cycle. 
+- **Hard-wired neighbors**: After each cell does its multiply-accumulate, the result is **passed directly through physical wires** to the adjacent cell — no register write.
+
+Across the array, data is reused hundreds to thousands of times without ever touching SRAM. That's why, on the same silicon area, TPU packs far more pure compute units than a GPU.
+
+#### 1.2 A Concrete 2×2 Walk-Through
+
+Let $Y = X \times W$, with both activation and weight as 2×2, and the MXU as a 2×2 grid in **weight-stationary** mode:
+
+```
+Weight layout:    PE(1,1)=W11   PE(1,2)=W12
+                  PE(2,1)=W21   PE(2,2)=W22
+
+Activation in:    Each row of X enters from the left, one element per PE row
+                  (X_m1 feeds PE row 1, X_m2 feeds PE row 2); PE row 2's input
+                  lags PE row 1's by one cycle ("skewed" entry)
+Partial sums:     Flow downward
+```
+
+Cycle by cycle:
+
+- **Cycle 1**: $X_{11}$ enters PE(1,1), computes $X_{11} \times W_{11}$, passes the partial sum down and $X_{11}$ right.
+- **Cycle 2**: $X_{11}$ flows to PE(1,2); $X_{12}$ enters PE(2,1), adds the incoming $X_{11} \times W_{11}$ partial sum to its own $X_{12} \times W_{21}$ — at this moment, what PE(2,1) outputs downward is the final $Y_{11} = X_{11}W_{11} + X_{12}W_{21}$.
+- **Cycle 3**: Bottom-right PE(2,2) collects both inputs and outputs $Y_{12}$.
+- **Cycle 4+**: Remaining results "drip" out from the array's bottom in sequence; no register lookup at any step.
+
+#### 1.3 Macro Architecture: Who Else Lives Beside the MXU
+
+The MXU alone can't run a full network. A TPU core also has:
+
+| Component | Role |
+|---|---|
+| **MXU** | Dense matrix multiply-accumulate; the algebraic muscle |
+| **VPU** | Vector ops: activations (GeLU/ReLU/Softmax), LayerNorm — anything that doesn't compress to a matmul |
+| **SPU** | Scalar and control flow: loops, branches, address computation. Weak in raw ALU, but the conductor |
+| **Unified Buffer** | Tens of MB of on-chip SRAM, holds activations staged from HBM and intermediate results |
+
+Instructions follow a **CISC** style: a single `MatrixMultiply` triggers thousands of MACs in the MXU, keeping decode/dispatch overhead extremely low.
+
+#### 1.4 SPU and VLIW: With Hardware Static, Who Runs Control Flow?
+
+The MXU on its own doesn't know what to compute or where data comes from. That coordination falls to SPU + **VLIW (Very Long Instruction Word)**.
+
+XLA packs multiple non-conflicting operations into one long instruction. A single instruction encodes four things:
+
+```
+[ DMA ] | [ MXU ] | [ VPU ] | [ SPU control flow ]
+```
+
+Example: `[DMA stage next activation block] | [MXU compute current block] | [VPU apply ReLU on previous result] | [SPU update address]`.
+
+How the SPU works:
+
+- **Fires before results land**: Pushes control signals into the MXU's FIFO; as long as the queue isn't full, SPU runs ahead.
+- **Synchronization via barriers**: When you must wait for the MXU before continuing, the compiler inserts `WAIT_MXU_DONE` — a hardware-level stall.
+- **Doesn't compute**: SPU's compute power is intentionally tiny; its only duties are loops, addresses, and barriers, so its overhead is negligible.
+
+The entire timing is determined 100% before the program runs.
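+
+To make the §1.2 dataflow concrete, here is a toy cycle-accurate model of a weight-stationary systolic array in Python/NumPy. It illustrates the mechanism only, not real TPU microarchitecture: the array size, the skewed feed, and names like `a_reg` / `p_reg` are assumptions made for this sketch. Activations enter row `k` of the grid skewed by `k` cycles and move right, partial sums move down, and finished results drip out of the bottom row; the final assert checks the collected outputs against `A @ W`.
+
+```python
+import numpy as np
+
+def systolic_matmul(A, W):
+    """Toy model of a weight-stationary systolic array computing C = A @ W.
+
+    A: (M, K) activations; W: (K, N) weights. PE(k, n) permanently holds
+    W[k, n]; activations move right, partial sums move down, and outputs
+    appear at the bottom row of each column.
+    """
+    M, K = A.shape
+    _, N = W.shape
+    a_reg = np.zeros((K, N))      # activation currently held by each PE
+    p_reg = np.zeros((K, N))      # partial sum produced by each PE this cycle
+    C = np.zeros((M, N))
+    for t in range(M + K + N - 2):                # enough cycles to drain the array
+        a_prev, p_prev = a_reg.copy(), p_reg.copy()
+        for k in range(K):
+            for n in range(N):
+                if n == 0:
+                    m = t - k                     # skewed feed: A[m, k] enters row k at cycle m + k
+                    a_in = A[m, k] if 0 <= m < M else 0.0
+                else:
+                    a_in = a_prev[k, n - 1]       # handed over by the left neighbour
+                p_in = p_prev[k - 1, n] if k > 0 else 0.0   # handed down from above
+                a_reg[k, n] = a_in
+                p_reg[k, n] = p_in + a_in * W[k, n]         # multiply-accumulate, no SRAM touch
+        for n in range(N):                        # bottom row: C[m, n] drips out at cycle m + (K-1) + n
+            m = t - (K - 1) - n
+            if 0 <= m < M:
+                C[m, n] = p_reg[K - 1, n]
+    return C
+
+A = np.arange(6.0).reshape(3, 2)                  # M=3, K=2
+W = np.arange(8.0).reshape(2, 4)                  # K=2, N=4
+assert np.allclose(systolic_matmul(A, W), A @ W)
+```
+
+Notice that the inner loop only ever reads the two neighbouring registers; there is no cache or SRAM lookup anywhere, which is exactly the property §1.1 credits for TPU's compute density.
+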
+ +#### 1.5 ↔ GPU + +| Dimension | TPU | GPU | +|---|---|---| +| Matrix unit | One giant 128×128 MXU | Many small Tensor Cores (16×8×16) scattered across SMs | +| Scalar/vector unit | Weak SPU, mid-tier VPU | Many CUDA Cores, strong branch prediction | +| Data buffer | Unified Buffer (software-managed via DMA) | L1/L2 Cache (hardware prefetch) | +| Instruction model | VLIW, scheduled at compile time | SIMT, switched at runtime by Warp Scheduler | +| Hiding memory latency | Static pipelining | Massive thread-level concurrency | + +**Trade-off**: TPU hands flexibility to the compiler and lets hardware execute literally; GPU keeps flexibility in hardware and accepts that a chunk of silicon goes to scheduling. + +--- + +### 2. Inter-Chip Interconnect: ICI and the 3D Torus + +**One sentence**: A single chip is fast, but what makes a TPU a TPU is the inter-chip fabric — ICI is the physical layer, 3D Torus is the logical topology. + +#### 2.1 ICI: Networking Built Into the Silicon Edge + +Each TPU integrates a set of high-speed interfaces, **ICI (Inter-Core Interconnect)**, right at the silicon edge: + +- **No external switches**: Inside the Pod, data movement does not pass any traditional networking gear; it hops directly from one chip's optical module to the next. +- **Low and deterministic latency**: With switch buffering, route lookup, and congestion control all skipped, XLA can compute, in nanoseconds, how long data takes to traverse the Pod and use that for global scheduling. + +> **[Note — added by Claude]** The source doesn't give a concrete ICI link bandwidth. Public material puts v4 at the order of 4.5 TB/s aggregated across all 6 directions per chip. Decide whether to keep this number in the note. + +#### 2.2 3D Torus: Six Neighbors, Wrapped at the Edges + +Each chip has a coordinate $(X, Y, Z)$ in 3D space and connects via ICI to neighbors in $\pm X$, $\pm Y$, $\pm Z$. **The wrap-around is key**: at any dimension's boundary, the link loops back to the start to form a ring, so the Pod has no "edge node". + +Why a Torus over a fat-tree: + +- **Redundant paths**: A broken link can be detoured via another dimension. +- **Collective-communication friendly**: Operations like All-Reduce / All-Gather map directly onto rings — no route table computation. + +The cost: + +- **Network diameter is linear**: As $N$ scales, max hops grow as $O(N^{1/3})$, vs $O(\log N)$ for fat-tree. +- **Irregular point-to-point causes congestion**: MoE All-to-All on a Torus produces hot links — see Chapter 19. + +#### 2.3 ↔ GPU + +| Dimension | TPU | GPU | +|---|---|---| +| Within node | ICI (directly on chip edge) | NVLink / NVSwitch (separate switch chip) | +| Across node | ICI + OCS (one fabric, all optical) | InfiniBand + NIC + switches | +| Topology | 3D Torus | Fat-tree / multi-tier Spine-Leaf | +| Latency | Statically predictable | Dynamic with jitter | + +**Trade-off**: A Torus is dominant on regular collective comms but bleeds on dynamic sparse traffic (MoE); fat-tree is the inverse. + +--- + +### 3. OCS: Optical Switching Makes the Topology Reconfigurable + +**One sentence**: A TPU Pod is a star at the physical layer (each rack runs fibers into the OCS), but a Torus at the logical layer; MEMS micromirrors switch optical paths in microseconds. 
+ +#### 3.1 MEMS Mirrors: Pure Reflection, No Packet Inspection + +A regular Ethernet/InfiniBand switch is **O-E-O** (optical → electrical → optical): fiber comes in → converted to electrical → ASIC reads packet headers, looks up routes, queues into buffers → converted back to optical → out. This adds latency, jitter, and the switch chip itself burns serious power. + +OCS (Google internal codename Palomar), introduced in v4, is purely **O-O-O**: + +- Inside is a sealed cavity filled with inert gas. +- The laser from the input fiber hits a MEMS mirror (the size of a hair, deflected by static electricity). +- It reflects to a second MEMS mirror on the opposite side. +- The second mirror "corrects" the beam to a flat angle and shoots it into the fiber leading to the destination TPU. +- End-to-end: **no digital chip, no buffer, no packet inspection, no bandwidth ceiling**. + +The mirrors stay **completely still during transmission** — they only nudge briefly when switching "tracks". OCS is **Data Agnostic**: it doesn't care whether the laser blinks at 100 Gbps or 800 Gbps, it only reflects. Upgrade the optics and the OCS doesn't change. + +#### 3.2 Slicing Granularity: Rack-Level, Not Chip-Level + +**OCS cannot pick individual chips at will.** Within a rack, the 64 TPUs (a 4×4×4 base block) are wired together with cheap, short DAC copper cables. Only the rack's outward-facing interfaces get optical modules and feed into the OCS. + +So OCS's "Lego brick" minimum unit is a 4×4×4 rack. Want 256 chips? Pin together 4 racks. Want 1024? 16 racks. And so on. + +#### 3.3 96 Fibers: The Geometry of a 4×4×4 Rack + +This is a clean little geometry exercise. How many outward-facing interfaces does the surface (56 chips) expose? Break down by position: + +| Position | Count | Fibers per chip | Subtotal | +|---|---|---|---| +| 8 corners | 8 | 3 (exposes X, Y, Z) | 24 | +| 12 edges (each minus its 2 endpoints, leaves 2) | 24 | 2 | 48 | +| 6 faces (each minus edges/corners, leaves 2×2=4) | 24 | 1 | 24 | + +Total: **96 fibers**. Cross-check via faces: each face has 4×4 = 16 outward interfaces, 6 faces × 16 = **96**. Matches. + +In the data center, those 96 fibers aren't pulled one by one; they go through high-density parallel optical cables (MPO/MTP, 16 or 32 fibers per cable), bundled "waterfall-style" from the top of the rack to a central network rack containing the OCS. + +#### 3.4 OCS and 3D Torus: How They Relate + +**3D Torus is the logical topology shape; OCS is the joint of the transformer.** + +- v2/v3 era: topology was hard-wired physical cabling, rack A → B → C → A welded in place. One bad chip and the whole region was offline. +- v4/v5 era: physical cabling becomes a **star** (every rack's fibers fan into the OCS); the OCS internally "folds" a 3D Torus loop using mirror angles. + +Slicing scenarios: + +| You ask for | What OCS does | +|---|---| +| Single rack, 64-chip closed loop | Reflects the 96 fibers among themselves (left 16 ↔ right 16, front ↔ back, up ↔ down) | +| 4×4×8 (128 chips, 2 racks) | X- and Y-axes wrap inside each rack; on the Z-axis, rack A's top face (16 fibers) cross-connects to rack B's bottom face, and vice versa | +| Bypass a bad chip | Borrows a fiber from a neighboring rack to maintain the 3D-torus logical integrity | + +#### 3.5 The Physical Boundary: One Pod + +Lasers attenuate over fibers and free space. The OCS network covers at most one Pod. A v5p Pod is 8960 chips. 
**The Pod is the absolute boundary of optical interconnect.** + +Cross-Pod "communication" must go over standard datacenter ethernet (DCN), with latency and jitter that no longer match ICI and that would shred XLA's static clock. So in practice: + +- **Inside one computation (Model Parallelism)**: never crosses Pods. +- **Service-level scheduling (Load Balancing)**: cross-Pod is fine — user A is fully routed to Pod 1, user B to Pod 2, with only an HTTP/gRPC load balancer between them. + +#### 3.6 ↔ GPU + +GPU clusters **don't have anything like this**. NVIDIA's NVSwitch and InfiniBand are both packet-switched electrical interconnects; OCS is a circuit-switched optical interconnect — a fundamentally different philosophy. This is one of the most distinct designs in the TPU lineage. + +> **[Note — added by Claude]** Microsoft Azure has begun trialing Lumen-supplied OCS in some AI clusters, but the volume and tier is nowhere near Google TPU's. Not in the main text — for context only. + +--- + +### 4. Collective Communication: Dimension-Partitioned Ring All-Reduce on a 3D Torus + +**One sentence**: The trick to All-Reduce on a 3D Torus is splitting it into independent X / Y / Z 1D rings and running them sequentially — geometry traded for algorithm. + +#### 4.1 1D Ring All-Reduce: Reduce-Scatter + All-Gather + +Suppose 4 TPUs form a ring head-to-tail, each producing a 400 MB tensor; we need to element-wise-add the 4 of them and let every TPU end up with the full sum. + +Naively forwarding everything to TPU 0 instantly chokes its network. XLA's approach: **chunk + pipeline relay**. Cut 400 MB into four 100 MB blocks (A, B, C, D). + +**Phase 1: Reduce-Scatter (each ends up holding the full sum of one block)** + +| Cycle | Action | +|---|---| +| 1 | Every chip ships one of its blocks rightward: TPU0→A→TPU1, TPU1→B→TPU2, TPU2→C→TPU3, TPU3→D→TPU0. **All 4 wires (including wrap-around) at full load** | +| 2 | TPU 1 adds incoming A to its own A via VPU → "partial sum A", forwards right | +| 3 | One more hop. TPU 3 receives the now-three-way partial-sum A, adds its own A → full-network sum of A | + +End of phase: TPU 0 holds full B, TPU 1 holds full C, TPU 2 holds full D, TPU 3 holds full A. + +**Phase 2: All-Gather (each broadcasts its own 1/4 of the answer)** + +3 hops of relay broadcast; the four full blocks fly around the ring in parallel. Once each chip's Unified Buffer holds A/B/C/D, DMA flushes the final 400 MB to HBM. + +#### 4.2 Hardware Detail: The Network Is an Extension of the Cache + +TPU turns the network into a direct extension of SRAM: + +- Data flows out of the sender's Unified Buffer → optical signal across ICI → the receiver's Unified Buffer. +- VPU pulls operands directly from the Unified Buffer to add — **HBM is never touched along the way**. +- A tiny **Sync Token** is appended to the data on send. + +GPU's contrasting flow: HBM → PCIe → NIC → switch → NIC → PCIe → HBM → compute core. Every hop touches memory, and HBM bandwidth is consumed by the comm path itself. + +#### 4.3 No Global Clock — Hardware Semaphores Instead + +At datacenter scale, keeping hundreds of chips synchronized to a single nanosecond-level physical clock is physically impossible (light-speed delay, clock drift). TPU uses **XLA static schedule + hardware-semaphore async handshake**. + +Receiver-side chain: + +1. After ICI hardware finishes physical receipt, it **automatically** increments the corresponding hardware semaphore (SPU isn't involved at all). +2. 
SPU, reading XLA's instructions, sees `WAIT Semaphore_X >= 1` and stalls the VPU. +3. The instant the semaphore flips, WAIT releases — VPU springs like a coiled spring and starts the addition. +4. After computing, it kicks DMA to ship results forward with a fresh Sync Token, and resets its own semaphore to zero. + +At the macro scale it looks like thousands of chips locked to the same gear, but actually each chip only watches a few hardware lights right in front of it. XLA pre-aligns the compute timing perfectly, so no deadlocks and minimal idle waiting. + +#### 4.4 3D Decomposition: Splitting Into X, Y, and Z 1D Rings + +If you string all 64 chips of a 4×4×4 rack into one giant 64-hop loop, disaster: + +- Data needs 63 hops to circumnavigate. +- Each chip has 6 wires; only 1 receives + 1 sends, **the other 4 idle**. + +The right play is **multi-dimensional orthogonal ring synchronization**: split a 3D task into three parallel 1D tasks. + +| Phase | Work | Parallelism | +|---|---|---| +| X-axis sync | 16 parallel X-axis rings (length 4) running Reduce-Scatter + All-Gather | 16 | +| Y-axis sync | The synced data is re-chunked, 16 Y-axis rings run | 16 | +| Z-axis sync | 16 Z-axis rings run | 16 | + +Total hops: 4 + 4 + 4 = **12**, vs 64. + +Each phase looks like only 1/3 of the wires are active, but XLA uses **pipelining**: while matrix shard 1 is on the Y-axis, shard 2 is already on the X-axis. Macro-view, all 6 wires light up at full load. + +#### 4.5 Big Rings Are Inevitable: How Cross-Rack Z-Axis Stitches + +In a 4×4×8 (128-chip, 2-rack) slice, X and Y stay as length-4 small rings, but Z becomes a **length-8 big ring**. Physical form: + +``` +[ Rack A's Z axis ] [ Rack B's Z axis ] +TPU(Z=0) — TPU(Z=1) — TPU(Z=2) — TPU(Z=3) TPU(Z=4) — TPU(Z=5) — TPU(Z=6) — TPU(Z=7) + ^ | | + | ← OCS cross-rack splice ← +— OCS — TPU(Z=4) | + +———————————————————————————————————————— OCS wrap ——————————————————————————————+ +``` + +At full v5p Pod scale (8960 chips, possibly arranged 16×16×35), the longest Z-axis edge becomes a **35-hop big ring**. Pure physical transmission delay alone is enough to starve the upstream MXU. + +#### 4.6 NUCA: Heterogeneous Latency Between Copper and Optical + +The 8-hop Z-axis big ring contains 2 optical hops (cross-rack) and 6 copper hops (intra-rack). Two media: + +- **Bandwidth must be strictly equal**: pipeline throughput is bottlenecked by the thinnest pipe segment. So TPU designs match the optical module modulation rate to the copper SerDes rate. +- **Latency is necessarily unequal**: copper is in the nanoseconds (single to low double digits); optical must do E-O conversion → tens of meters of fiber → OCS reflection → O-E conversion, into the hundreds of nanoseconds. + +Bandwidth homogeneous + latency heterogeneous → the ring isn't perfectly symmetric in physics. This phenomenon is **NUCA (Non-Uniform Communication Architecture)**. + +The business layer dissolves this with two tools: + +1. **Steady-state pipeline masking**: At ring start, those 2 optical hops cause a small pipeline bubble, but in steady-state throughput is bandwidth-limited, and the few-hundred-nanosecond startup penalty disappears in the high-throughput data flow. +2. **XLA topology-aware mapping**: see Section 9.1. 
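+
+As a minimal sketch of the two phases in §4.1, here is a functional Python model of reduce-scatter plus all-gather on a 1D ring. It models the data movement only (the function name and bookkeeping are made up; real hardware runs this over ICI, with the VPU doing the adds out of the Unified Buffer and the hops overlapped with compute). Every one of the P-1 hops in each phase keeps all ring links busy at once.
+
+```python
+import numpy as np
+
+def ring_all_reduce(shards):
+    """Functional model: reduce-scatter then all-gather on a ring of P chips."""
+    P = len(shards)
+    blocks = [list(np.array_split(s, P)) for s in shards]    # blocks[chip][block]
+
+    # Phase 1 (reduce-scatter): after P-1 hops, chip i holds the fully
+    # reduced copy of block (i + 1) % P.
+    for step in range(P - 1):
+        sends = [(i, (i - step) % P, blocks[i][(i - step) % P].copy()) for i in range(P)]
+        for i, b, payload in sends:                           # every link carries one block per hop
+            dst = (i + 1) % P
+            blocks[dst][b] = blocks[dst][b] + payload
+
+    # Phase 2 (all-gather): each chip relays its finished block around the ring.
+    for step in range(P - 1):
+        sends = [(i, (i + 1 - step) % P, blocks[i][(i + 1 - step) % P].copy()) for i in range(P)]
+        for i, b, payload in sends:
+            blocks[(i + 1) % P][b] = payload
+
+    return [np.concatenate(blocks[i]) for i in range(P)]
+
+chips = [np.random.rand(400) for _ in range(4)]               # 4 chips on one ring
+result = ring_all_reduce(chips)
+assert all(np.allclose(r, sum(chips)) for r in result)
+```
+
+On the 3D Torus the same routine simply runs once per axis (§4.4): 16 X-axis rings, then 16 Y-axis rings, then 16 Z-axis rings, with the data re-chunked between axes.
+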
+ +#### 4.7 ↔ GPU + +| Dimension | TPU 3D Torus | GPU NCCL on NVLink+IB | +|---|---|---| +| Algorithm | Multi-dim Ring All-Reduce | Ring or Tree (selectable) | +| Network diameter | $O(N^{1/3})$ | $O(\log N)$ | +| Synchronization | Hardware semaphore + static schedule | Software spin-wait + flag in HBM | +| Data path | UB → ICI → UB → VPU | HBM → PCIe → NIC → IB → NIC → PCIe → HBM → SM | +| Compute resource cost | VPU does the addition for free, MXU never stalls | NCCL kernel competes for SM resources | + +**Trade-off**: Rings + static scheduling are unbeatable on regular collective comms but sacrifice arbitrary point-to-point flexibility. + +--- + +### 5. Host ↔ TPU: PCIe, NUMA, and Multi-Host Slice + +**One sentence**: TPU isn't a standalone machine — it's a PCIe device hanging next to a CPU host; once a slice spans hosts, you're running a distributed system. + +#### 5.1 Physical Form: CPU as Foreman, TPU as Worker + +Each rack runs standard x86 server boards (Intel/AMD) with ordinary DDR memory. TPUs attach via PCIe next to the CPU. **The typical ratio is 1 CPU host managing 4 or 8 TPUs**, hard-wired physically. + +Division of labor: + +- **CPU does**: Linux, Kubelet, accepting HTTP/gRPC, Python/PyTorch, vLLM scheduler (Radix Tree, PagedAttention page tables), XLA compilation (CPU-bound) +- **TPU does**: Pure execution of CPU-compiled machine code, doing matrix multiply-accumulate + +A single Decode step from end to end: + +1. CPU prepares metadata in system memory (page table pointer arrays, etc.). +2. PCIe DMA copies data + new token embedding into TPU HBM. +3. CPU sends the TPU an "execute compiled graph #5" instruction. +4. TPU goes into seclusion and computes. +5. PCIe pulls logits back to CPU memory. +6. CPU samples (Argmax, Top-P, etc.) on the result. + +#### 5.2 Two-Plane Isolation: DCN vs ICI + +Inside a Pod live two **completely independent** physical networks: + +| Network | Used by | Medium | Purpose | +|---|---|---|---| +| **DCN (Datacenter Network)** | Host CPUs | Standard ethernet switches | CPU-to-CPU coordination, the control plane | +| **ICI (Inter-Chip Interconnect)** | TPUs | OCS + 3D Torus dedicated fibers | TPU-to-TPU data flood, the data plane | + +Control plane (K8s coordination, Pod start/stop) goes over DCN with gRPC between CPUs. Data plane (All-Reduce, KV sync) bypasses CPU entirely. + +#### 5.3 Multi-Host Slice: N:N Mapping + +Many people assume a 64-chip v4-64 slice request lands on one giant VM with a super-CPU. **It doesn't.** + +Physically it's **16 Host VMs** (4 TPUs each), 16×4=64: + +- 16 VMs are connected by DCN. +- 64 TPUs are connected via OCS into a 3D Torus. +- Each VM runs one Kubelet. +- K8s schedules 16 Pods, one per VM. +- The inference code (vLLM/JetStream) starts in 16 CPUs simultaneously running the same Python code (**SPMD**). +- Typically expressed via **LeaderWorkerSet (LWS)** or **JobSet**: 1 Leader Pod exposes the API, 15 Worker Pods coordinate. +- The Leader broadcasts each request to Workers over DCN; each CPU drives the 4 TPUs underneath it. +- After the 64 TPUs compute the result, the Leader aggregates and returns it. + +There is no "super CPU managing 64 TPUs". A large slice is a federation of many "small CPU + small TPU" nodes, gang-scheduled together by K8s. + +#### 5.4 NUMA: PCIe Lanes Split Between Two Sockets + +Modern server motherboards typically have two CPU sockets (CPU 0 and CPU 1). Half the PCIe lanes go to CPU 0, half to CPU 1. On an 8-TPU host, TPU 0~3 hang off CPU 0; TPU 4~7 off CPU 1. 
+ +The cross-NUMA disaster: CPU 0 wants to write into TPU 4's HBM — data must first cross the UPI bus to CPU 1, then PCIe to TPU 4. Latency spikes; usable bandwidth halves. + +Where it bites in practice: + +- **Input pipeline**: Tokenize on CPU 0, but task issued to TPU 4 — every in-feed crosses NUMA. +- **KV Cache offload**: TPU 0's KV swapped to a DDR slot owned by CPU 1. +- **Weight loading**: hundreds of GB of weights DMA-ing across NUMA — slow cold start. + +Google Cloud's mitigations: + +- **Single-NUMA VM partitioning**: When you request a 4-TPU instance, the hypervisor splits the physical machine in half. The VM you get **only contains** CPU 0 + CPU 0's memory + the 4 TPUs hanging off CPU 0. CPU 1's half is invisible. +- **XLA auto-pinning**: For 8-TPU instances that can't be split, XLA Runtime (PJRT) reads the PCIe-tree topology and automatically pins threads feeding TPU 0~3 to CPU 0 cores, threads feeding TPU 4~7 to CPU 1 cores. + +#### 5.5 ↔ GPU + +| Dimension | TPU | GPU | +|---|---|---| +| Physical attach | PCIe next to host CPU | Same | +| Ratio | 1:4 or 1:8, hard-wired | HGX usually 1:8 | +| Multi-host orchestration | LWS / JobSet + K8s gang | MPI Operator / Training Operator | +| NUMA handling | Single-NUMA VM split + XLA PJRT auto-pin | Often manual `numactl` + NCCL topology-aware | +| Two planes | DCN + ICI strictly separated | Usually both control and data planes share IB | + +**Trade-off**: TPU's N:N orchestration shrinks the per-host failure blast radius, but operationally the "single inference service" you see is actually a coordinated dance of multiple K8s resources. + +--- + +### 6. Advanced Packaging: Compute as Area, Bandwidth as Perimeter + +**One sentence**: Die area governs FLOPs; perimeter governs bandwidth (HBM interfaces); 2.5D / 3D packaging is reconciling that fundamental tension. + +#### 6.1 The Physical Origin of the Memory Wall + +- **Compute ∝ area**: An MXU is 2D — a slightly larger MXU has $O(N^2)$ MAC cells (64×64 → 128×128 doubles area, quadruples compute). +- **Traditional bandwidth ∝ perimeter**: HBM bandwidth depends on the count of edge pins; growth is $O(N)$ linear. + +Area is square-law, perimeter is linear — **bandwidth permanently lags compute**. That's the physical origin of the "memory wall". + +#### 6.2 Three Generations of Packaging Evolution + +| Generation | Name | What it solved | +|---|---|---| +| Gen 1 | Wire Bonding | Chip face-up, edge gold wires to the substrate; capped by perimeter | +| Gen 2 | Flip-Chip | Flip the chip, plant C4 micro-bumps across the whole face — moves from "edge" to "area"; but PCB trace precision is too coarse (line widths in tens of micrometers) | +| Gen 3 | **2.5D Silicon Interposer / CoWoS** | Lay a silicon slice between the chip and the substrate; use EUV lithography to draw **nanometer-scale** traces on it | +| Gen 4 | **3D Packaging + TSV (Through-Silicon Via)** | Drill tens of thousands of micrometer-scale holes vertically through the silicon, fill with copper; stack chip layers like floors of a building | + +#### 6.3 Why the Silicon Interposer Helps + +A regular PCB packs about 10 traces per millimeter; a silicon interposer packs over 1000 per millimeter. The connection between GPU/TPU and HBM expands from "bidirectional 4-lane" to "bidirectional 4096-lane". That's where HBM3's bus width comes from. + +HBM itself is internally 3D too — multiple layers of memory dies stacked with TSV vertical interconnect, so a single HBM stack can deliver Tbps-class bandwidth. 
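+
+A quick back-of-the-envelope sketch of the §6.1 tension. The numbers are purely illustrative (they are not real TPU or HBM specs): compute units fill the array's area, pin-limited I/O fills its perimeter, so the arithmetic intensity a workload must sustain just to keep the array fed grows linearly with the side length.
+
+```python
+# Illustrative only: area-vs-perimeter scaling from §6.1, not real chip specs.
+for side in (64, 128, 256, 512):
+    macs_per_cycle = side * side      # compute grows with area: O(N^2)
+    io_lanes = 4 * side               # edge pins/bumps grow with perimeter: O(N)
+    required_intensity = macs_per_cycle / io_lanes   # MACs each loaded operand must amortize
+    print(f"side={side:4d}  MACs/cycle={macs_per_cycle:7d}  "
+          f"I/O lanes={io_lanes:5d}  MACs per lane={required_intensity:6.1f}")
+```
+
+Doubling the side length quadruples the compute but only doubles the edge I/O, which is the square-vs-linear gap that 2.5D/3D packaging (and, on the algorithm side, §6.4's compromises) exists to close.
+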
+ +#### 6.4 Compute-Surplus Problem and Hardware Compromise + +In the v4 era, compute scaled too quickly and HBM bandwidth couldn't keep up; Decode MFU dropped to single digits, MXUs starved waiting for data. Google's hardware mitigation: + +- **The v5e (inference chip) intentionally shrinks the MXU**, restoring a healthier compute/bandwidth ratio. +- Sacrificing peak FLOPs for cost-effectiveness — accepting that Decode is memory-bound, not compute-bound. + +> **[Note — added by Claude]** The "single-digit MFU in v4 era" claim is direct from the source but lacks a specific data point or context. I have no other reliable source. Decide whether to keep it. + +#### 6.5 ↔ GPU + +NVIDIA H100/B100 also use CoWoS (same TSMC line) — same tech path. Difference is in die budget: H100 has 50 MB L2 Cache, 5 TB/s HBM; TPU skips hardware cache and gives that area to the MXU. + +**Trade-off**: H100 uses big caches to tolerate random access; TPU uses big MXUs to crunch dense ops. Two paths matched to different workload assumptions. + +--- From b1bd658596e3c0d50dd87d7b41b34fac8e3e4506 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 6 May 2026 19:30:53 +0000 Subject: [PATCH 6/8] docs(notes): Add English mirror Parts II and III 1:1 structural mirror of Chinese Parts II (XLA: chapters 7-9) and III (inference adaptation: chapters 10-14). Covers compilation model and VLIW packs, JIT/AOT/bucketing/persistent-cache pipeline, topology-aware mapping, vLLM/JetStream/Saxml stack split, PagedAttention adaptation via control-plane/data-plane split, Prefill/Decode coordination including chunked prefill and 1D static flatten, KV/memory hierarchy with ICI/GDS/KV-offload comparison, and Gemini's MoE Capacity Factor + Tree Attention. External-knowledge callouts mirror the Chinese version verbatim for review consistency. Parts IV, V, and appendices still to come. Co-authored-by: Claude --- notes/tpu-deep-dive.md | 508 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 508 insertions(+) diff --git a/notes/tpu-deep-dive.md b/notes/tpu-deep-dive.md index 03831d9..06aa848 100644 --- a/notes/tpu-deep-dive.md +++ b/notes/tpu-deep-dive.md @@ -432,3 +432,511 @@ NVIDIA H100/B100 also use CoWoS (same TSMC line) — same tech path. Difference **Trade-off**: H100 uses big caches to tolerate random access; TPU uses big MXUs to crunch dense ops. Two paths matched to different workload assumptions. --- + +## Part II — Compiler and Runtime: XLA + +### 7. The XLA Compilation Model: Erase Uncertainty at Compile Time + +**One sentence**: XLA's core play is to "erase all uncertainty at compile time" — operator fusion, static padding, software pipelining, and VLIW packs are all facets of that idea. + +#### 7.1 Operator Fusion + +Classic example: a layer of `MatMul + Bias + ReLU`. + +Naive execution (GPU eager mode) does: +1. Compute MatMul → write back to HBM. +2. Read it back, add Bias → write back to HBM. +3. Read it back, do ReLU → write back to HBM. + +XLA fuses the three steps into one compute block: + +``` +HBM → UB (MatMul input) → MXU does MatMul → result flows out → VPU + does Bias + ReLU → HBM +``` + +**Saves 2/3 of HBM read/write traffic.** This kind of fusion is everywhere in LLMs — for example dropout/scale/mask after Attention is typically fused into the same block. + +#### 7.2 Static Padding + +The MXU is a hard-wired 128×128 array. If your matmul is 100×100, XLA does not let hardware handle the boundary — the hardware doesn't support that. XLA pads to 128×128 at compile time, filling the rest with zeros. 
+ +Cost: roughly 50% wasted compute on the boundary 28 rows × 28 columns. +Benefit: the array stays at full speed without stopping for boundary checks. + +In a systolic array, **letting hardware run full-speed through blank cells beats stopping mid-stream**. + +#### 7.3 Software Pipelining + +XLA pre-computes how many cycles each DMA takes and statically generates instructions: while the MXU is computing block N, DMA is already staging block N+1. Compute and memory transfer overlap perfectly. + +#### 7.4 VLIW 5-Slot: Same Shape, Single-Chip and Cross-Chip + +Single-chip variant: + +``` +[ DMA ] | [ MXU ] | [ VPU ] | [ SPU control flow ] +``` + +Multi-chip cooperation adds one slot — the **ICI network slot**: + +``` +[ DMA ] | [ MXU ] | [ VPU ] | [ ICI network ] | [ SPU control flow ] +``` + +Key insight: **From the TPU's perspective, sending data across chips (ICI) and moving it within a chip (DMA) sit at the same level** — both are switches in a VLIW instruction word. XLA can have the MXU multiplying while ICI ships the previous layer's gradient to a neighbor, achieving **compute-and-cross-node-communication overlap at the cycle level**. + +This is something GPU systems can't do. GPU cross-node comm fires CPU interrupts, builds packets, traverses IB switches — it's a separate subsystem. + +#### 7.5 A Concrete Pseudo-Assembly Example + +Compute $C = A \times B$, $A$ is 256×128, $B$ is 128×256. MXU is 128×128. + +XLA at compile time splits $A$ into upper/lower halves ($A_0, A_1$), $B$ into left/right halves ($B_0, B_1$), decomposing into 4 sub-tasks of 128×128. All HBM addresses are hardcoded at compile (no runtime pointer math). + +VLIW stream to compute $C_{00} = A_0 \times B_0$: + +``` +Instruction 1 (warm up: load weights): + [DMA] LOAD_HBM_TO_UB (Src: HBM_B0, Dst: UB_B) + [MXU] NOP + [VPU] NOP + [SPU] WAIT_DMA_DONE + +Instruction 2 (weight stationary): + [DMA] NOP + [MXU] LOAD_UB_TO_WEIGHT_REG (Src: UB_B) + [VPU] NOP + [SPU] WAIT_MXU_DONE + +Instruction 3 (core compute + pipeline prefetch): + [DMA] LOAD_HBM_TO_UB (Src: HBM_A0, Dst: UB_A) + [DMA_ASYNC] LOAD_HBM_TO_UB (Src: HBM_B1, Dst: UB_B_Next) ← prefetch next block + [MXU] MATMUL_STREAM_ACT (Src: UB_A, Dest: Accumulator_C00) + [VPU] NOP + [SPU] WAIT_MXU_DONE + +Instruction 4 (fused ReLU): + [DMA] NOP + [MXU] NOP + [VPU] READ_ACCUM_AND_RELU_AND_STORE (Src: Accumulator_C00, Dst: UB_C00) + [SPU] WAIT_VPU_DONE + +Instruction 5 (write back to HBM): + [DMA] STORE_UB_TO_HBM (Src: UB_C00, Dst: HBM_C00) + [MXU] NOP + [VPU] NOP + [SPU] JUMP_TO_NEXT_BLOCK ← prep for C01 +``` + +Note Instruction 3: **DMA stages the next block while the MXU is computing**. If XLA mis-times any cycle, either UB overflows or the MXU stalls. Only static compilation gives that precision. + +#### 7.6 ↔ GPU + +| Dimension | TPU XLA | GPU | +|---|---|---| +| Scheduling | Static at compile time | Dynamic at runtime (Warp Scheduler) | +| Memory-latency hiding | Static pipelining | Massive thread-level concurrency | +| Instruction format | VLIW with parallel slots | SIMT | +| Operator fusion | Automatic in XLA | Semi-automatic via TorchInductor / TVM / Triton | + +**Trade-off**: XLA's "global view" is unbeatable on regular workloads, but any dynamic shape forces a recompile. + +--- + +### 8. Compilation Timing: JIT, AOT, Bucketing, Persistent Cache + +**One sentence**: The cost of static compilation is a slow first run; in production, bucketing + AOT + cache amortize the cost away. 
+ +#### 8.1 The Real JIT Timeline on PyTorch/XLA + +Real flow when vLLM and similar frameworks start on TPU: + +| Phase | Action | XLA state | +|---|---|---| +| 1. Init | Load weights to HBM | **Not compiled**. Just nn.Module objects and weight tensors | +| 2. Tracing | First request triggers `model.forward()`; PyTorch/XLA uses **Lazy Tensor** — no real compute, just records a DAG | Building the graph, generating HLO IR | +| 3. Trigger | Reading logits for sampling hits a sync barrier (e.g. `xm.mark_step()`) | **Compile now**: operator fusion + static address allocation + instruction layout → TPU Executable. Takes seconds to tens of seconds | +| 4. Execute + cache | TPU runs (in milliseconds), the compiled artifact is stored in an in-memory Compiler Cache | Cache key includes graph structure and all input shapes | +| 5. Subsequent requests | Same shape hits the cache, skips compile | Reuse | + +**Key clarification**: Weight values do NOT enter the compiled artifact. XLA only cares about the weight tensor's shape and dtype, treating weights as a "static device-memory pointer". Swapping in a fine-tuned LoRA, or another model of the same architecture, doesn't trigger recompile. + +#### 8.2 Production Approach: Bucketing + AOT + Persistent Cache + +In production you absolutely cannot let the first user wait 30 seconds for compilation. The flow: + +**Step 1: Restrict and discretize buckets** + +After profiling, the team picks a discrete set of buckets: + +``` +BS_Buckets = [1, 2, 4, 8, 16, 32, 64] +SeqLen_Buckets = [128, 512, 1024, 2048, 4096, 8192] +``` + +A request with BS=5 at runtime is padded to the BS=8 bucket. + +**Step 2: AOT warmup at CI/CD** + +Add a warmup stage when building the Docker image or release artifact: spin up a CI node with real TPU topology, iterate over every `BS × SeqLen` combination, and bombard the model with dummy requests to trigger compilation. + +**Step 3: Bake the persistent cache into the image** + +Use `XLA_FLAGS="--xla_dump_to=/path/to/cache"` to dump compiled artifacts. At the end of the pipeline, **bake those few hundred MB to a few GB of cache files directly into the release image** or attach them via shared storage. + +When production vLLM/JetStream instances boot, they read the cache; bucket hits become **millisecond hardware dispatches**. + +#### 8.3 Why There's No "One-Size-Fits-All" Pre-Compiled Library + +An XLA Executable is bound not just to model shapes but to these other **fatal variables** — change any one and the cache is invalid: + +| Variable | Impact | +|---|---| +| **Physical hardware topology** | A v5e-8 (1D ring) executable can't run on v5p-32 (3D Torus); XLA bakes in which fiber to use and the latency in nanoseconds | +| **Parallelism strategy** | Does TP cut Attention or FFN? How does PP cut? These SPMD annotations must be set before compilation | +| **Compiler version** | XLA / LLVM backend updates frequently; old caches usually fail validation | +| **Model structure tweaks** | Add a small adapter, tweak the RoPE base — constant-folding output changes → HLO hash changes → cache invalidated | + +Every team has to maintain its own **model distribution + cache pre-warming pipeline**. A major infrastructure release is always paired with large-scale automated recompilation. + +#### 8.4 ↔ GPU + +GPU has the same problem space (PyTorch 2.x's `torch.compile` / Inductor / TensorRT), but to a far smaller degree: GPU hardware tolerates runtime shape changes (dynamic dispatch), and compile failures fall back to eager. 
TPU has no such fallback — compile failure equals service failure. + +--- + +### 9. XLA Topology-Aware Mapping + +**One sentence**: The compiler knows the physical topology of the 3D Torus + OCS, so it can map high-density communication onto short edges and low-frequency synchronization onto long rings. + +Section 4.6 covered NUCA: the cross-rack 8-hop ring has 6 copper hops + 2 optical hops, bandwidth homogeneous, latency heterogeneous. XLA, at compile time, places different parallelism strategies onto different qualities of topology. + +#### 9.1 TP on Small Rings, DP on Big Rings + +| Parallelism | Comm characteristics | Mapped to | +|---|---|---| +| **Tensor Parallelism (TP)** | Step-by-step; every linear layer requires syncing activations. **Latency-sensitive** | The 4 / 8 short copper rings on X or Y | +| **Data Parallelism (DP)** | "After the dust settles"; each step (or after a few accumulated) syncs gradients. The matrix is large, demanding bandwidth, but tolerates one-shot latency | The 35-node large optical ring on Z | +| **Pipeline Parallelism (PP)** | Inter-stage activation transfer, medium frequency | Usually mid-length edges | +| **Expert Parallelism (EP)** | Dynamic All-to-All (MoE) | Suffers on Torus, see Chapter 19 | + +When DP runs on the big ring, the 35-hop physical latency is masked by steady-state pipelining, and underlying compute units use the sync wait window to do the next step's forward (Compute-Communication Overlap). + +#### 9.2 How Topology Information Reaches XLA + +K8s labels on each Node (e.g., `cloud.google.com/gke-tpu-topology: 4x4x4`) carry the slice's geometry. XLA Runtime at startup reads those plus PCIe sysfs information to construct a topology graph. Then, per the user's SPMD partitioning annotations, it maps comm groups to specific ICI links. + +**Conclusion**: Pure K8s scheduling only sees node liveness; high-performance AI scheduling sees microsecond-grained optical-electrical physical edges. + +#### 9.3 ↔ GPU + +GPU systems handle similar concerns with manual NCCL group configs + `torch.distributed`'s topology-aware APIs. NCCL knows NVLink/IB hierarchies, but fat-tree itself is roughly symmetric, so the optimization headroom isn't as large as on a Torus. + +--- + +## Part III — Inference-Layer Adaptation (Goal C) + +### 10. The Software Stack Forks: vLLM, JetStream, Saxml, GKE + +**One sentence**: TPU has more than one inference framework — three players with three positions; GKE is the glue that puts them all into a cluster. + +#### 10.1 Why GKE Is Hell-Bent on vLLM on TPU + +vLLM is the de-facto "Linux" of open-source inference — most customers writing on GPU + PyTorch + vLLM have already built their business code (API wrappers, schedulers, custom prompting). For GKE selling TPU (v5e/v5p, sharply cost-effective), the biggest blocker is migration cost: + +> If customers must rewrite code to use TPU, they leave. + +So Google's strategy is **Lift and Shift**: let customers keep their `vllm serve` invocation, swap the base image, and PyTorch calls are auto-routed to PyTorch/XLA. + +Foundation: vLLM's official repo already includes a TPU backend; the impedance-mismatch between PagedAttention and TPU's static graphs is patched with custom kernels written by Google engineers in Pallas. 
+ +#### 10.2 Three Players' Positions + +| Framework | Position | Target users | +|---|---|---| +| **vLLM** | Ecosystem-compatibility king ("don't make me change code") | Startups, multi-cloud customers, GPU migrants | +| **JetStream** | TPU performance squeezer | Big tech, high-concurrency inference, willing to adapt framework for performance | +| **Saxml** | JAX legacy heavy artillery | Customers deeply tied to JAX, special large-scale partitioning | + +#### 10.3 Why JetStream Beats vLLM by 20%-50% + +JetStream is led jointly by Google Cloud + the XLA team, custom-built for the v5 series. It **doesn't try to bolt dynamic paging onto static graphs**; instead it embraces TPU's static-orchestration philosophy: + +- Deeply optimized continuous batching +- Lots of XLA operator fusion +- Co-designed with the compiler, no PyTorch indirection + +The cost: APIs aren't as plug-and-play as vLLM, and PyTorch ecosystem support requires dedicated work. + +#### 10.4 Why Saxml Has Slipped to the Background + +Born alongside Pax / Seqio, deeply tied to JAX. Carries the flavor of Google internal infrastructure, with a high external-developer learning curve and slower PyTorch support. In public-cloud advocacy it's been deprioritized. + +#### 10.5 ↔ GPU + +| TPU | GPU | +|---|---| +| vLLM-TPU (with Pallas) | Native vLLM | +| JetStream | TensorRT-LLM (NVIDIA's own + specialized) | +| Saxml | DeepSpeed Inference (partial) | +| Pallas for kernels | Triton / CUDA | + +--- + +### 11. Adapting PagedAttention and Continuous Batching to TPU + +**One sentence**: GPU-style dynamic memory management (PagedAttention, Continuous Batching, Radix Tree) is fundamentally hostile to static compilation; TPU adapts via Pallas custom kernels + pushing dynamism into the tensor layer. + +#### 11.1 Before Pallas: The Inefficiency of Static Contiguous Allocation + +Early TPU inference (T5, early Pax) ran on a "OCD" route: + +- XLA at compile time pre-allocated a contiguous KV Cache pool sized `[Max_BS, Max_SeqLen, Hidden_Dim]`. +- With max_seq=4096, every request locked 4096 tokens of HBM. +- Real request only 100 tokens? The remaining 3996 slots were wasted (97% memory wasted). +- HBM saturated with useless padding → batch size couldn't grow → MXU had compute headroom, but the pool was full → **compute pinned by the memory wall**. + +Google early on threw money at this — large total Pod HBM, controllable task lengths (translation/search), algorithm teams cutting buckets very finely — and survived. But long context and multi-turn dialogue popularized this approach to its limits. + +#### 11.2 Modern Split: XLA Builds the Pool, vLLM Keeps the Books, Pallas Reads the Map + +vLLM on TPU's modern architecture splits **control plane (CPU) and data plane (TPU)**: + +| Role | Where | What it does | +|---|---|---| +| **XLA** | TPU | Allocates one giant 1D-flattened block tensor in HBM, shaped `[Num_Total_Blocks, Block_Size, Head_Dim]` (e.g. 100K physical blocks, each holding 16 tokens). **XLA doesn't know whose data is in there** | +| **vLLM** | CPU | Maintains a Radix Tree and all per-request page tables (Block Tables). Each step it bundles the active requests' page tables into an integer Tensor and feeds it to the XLA graph | +| **Pallas Kernel** | TPU | A Custom Call node in the XLA graph. 
After receiving the page table, executes low-level indirect addressing (Gather), pulling fragments into the Unified Buffer for the Attention compute | + +**Physical HBM is still XLA's global tensor, but XLA no longer manages content.** vLLM acts as dispatcher each cycle, sending an "addressing map" over; the Pallas kernel on TPU follows the map to fetch data. + +#### 11.3 Each of the Three Knives in Detail + +**PagedAttention** + +- TPU pain: XLA hates dynamic pointer addressing; querying the page table on each Attention call wrecks the DMA script. +- Solution: a Pallas kernel hand-written at the register level for page-table lookup + scatter/gather Gather. +- Result: Fully supported; HBM fragmentation problem solved, batch size goes up. + +**Continuous Batching (Inflight Batching)** + +- GPU's play: 1D flatten; the scheduler can boot a finished request and admit a new one at any time, fully dynamic. +- TPU's play (**static-bus mode**): + - XLA pre-compiles a fixed-`Batch_Size = 256` graph — think of it as a 256-seat bus permanently looping the route. + - When some request hits EOS → vLLM marks the seat empty → next step it slots a new request's first Decode token into that index. + - TPU sees only a perfect `[256, 1, D]` tensor and doesn't know that index 5 was user A a millisecond ago and is user B now. + - When fewer than 256 real requests, fill empty slots with Dummy Tokens (zeros). + +**Radix Tree (Prefix Cache)** + +- **A perfect fit on TPU**: The mechanism is essentially a CPU-side scheduling algorithm. +- On prefix hit, vLLM only needs to point logical block pointers in the dispatched Block Table to existing physical blocks. +- The TPU-side Pallas kernel doesn't know it's reuse — it follows the map and fetches HBM data to UB normally. + +#### 11.4 Google's Own Frameworks Use the Same Approach + +JetStream / Saxml implement the equivalent (called **Blocked Attention** internally, or built into the underlying **FlashAttention-TPU** kernel), all written in Pallas. So whether you run vLLM or JetStream on GKE, **the memory-management thinking has converged**: maintain a discrete Block Pool in HBM + page table on CPU + pass the page table to the underlying kernel for Gather at compute time. + +#### 11.5 A Core Philosophy: FLOPs to Replace Control Flow + +**Use very cheap FLOPs (compute) to eliminate very expensive Control Flow.** + +This runs through all of TPU's inference adaptation. Tree Attention (Chapter 14) is the same idea — encode if-else into a Mask matrix, prefer to compute and discard rather than let hardware stop for branch decisions. + +#### 11.6 ↔ GPU + +| Dimension | TPU | GPU | +|---|---|---| +| KV paging | XLA pool + Pallas Custom Call | vLLM native PagedAttention | +| Scheduling flexibility | Fixed batch buckets + Dummy padding | Dynamic 1D flatten | +| Custom kernel tooling | Pallas | Triton / CUDA | + +--- + +### 12. Prefill/Decode Coordination and Chunked Prefill + +**One sentence**: TPU is strong on Prefill, weak on Decode (HBM bandwidth bottleneck); mixed execution + chunked prefill is algorithm patching for hardware. + +#### 12.1 Arithmetic Intensity Determines TPU Feel + +To judge whether hardware suits a workload, look at **Arithmetic Intensity = FLOPs / Byte** (how many floating-point ops you do per byte of memory access). 
+ +| Phase | Math form | Arithmetic intensity | Bottleneck | TPU feel | +|---|---|---|---|---| +| **Prefill** | GEMM (matrix × matrix; weights reused across thousands of tokens) | Very high | Compute-Bound | MXU runs blissfully | +| **Decode** | GEMV (matrix × vector; weights pulled out for one token then discarded) | Very low | Memory-Bound | MXU starves frequently | + +So TPU is intrinsically a **Prefill beast**; Decode performance is bolted on after the fact. + +#### 12.2 Continuous Batching on TPU Decode + +**Decode does not bucket on token length** — every request's new-token length is fixed at 1. It buckets on **Batch Size**. + +Pre-compile a BS=256 Decode graph: + +- Static input matrix `[256, 1, Hidden_Dim]`. +- 200 real requests → first 200 slots hold real tokens, last 56 slots hold Dummy Tokens (zeros). +- The MXU computes 256 results. +- The CPU scheduler picks up only the first 200 real results to send to users; the 56 dummies are discarded. + +**The hard part: the 256 requests have wildly different histories.** + +The CPU also sends two static-sized integer arrays: + +``` +context_lengths : shape [256], real history length per request, e.g. [105, 3042, 12, ...] + (Dummy slots filled with 0) +block_tables : shape [256, Max_Blocks], each request's KV page table +``` + +The Pallas kernel uses `context_lengths` as loop bounds (or mask) and `block_tables` to gather historical KV from HBM, doing Attention. + +#### 12.3 The Prefill Difficulty: Wildly Different Prompt Lengths + +Decode neatly stacks into `[256, 1]`, but Prefill can't: one prompt is 100, another is 3000. How to fit a static graph? + +Two approaches: + +- **Bucketing**: 100 → padded to bucket 128; 3000 → cut to bucket 4096. +- **Chunked Prefill**: Compile a fixed-length Prefill graph (e.g. chunk_size=512). A 1000-length prompt is split into two chunks of 512, fed twice into the same `[1, 512]` slot. + +#### 12.4 Mixed Prefill/Decode: Static 1D Flatten + +The most up-to-date approach: combine Prefill long sequences and Decode single tokens **in the same step**. + +XLA at compile time sets two static upper bounds: + +``` +Max_Total_Tokens = 1024 # max tokens swallowed in one step +Max_Seqs = 256 # max concurrent sequences +``` + +The input tensor is flattened from 3D `[Batch, Seq, D]` to 2D `[1024, D]`. + +CPU-side composition: + +``` +Request A (Prefill, chunk=512) → first 512 of array +Requests B–Z (Decode, 200 tokens) → next 200 + 712 used +Dummy Tokens × 312 → pad to 1024 +``` + +**MXU phase**: To the systolic array, identities don't matter. One giant `[1024, D] × [D, 4D]` matmul flies through, computing Q/K/V for all 1024 tokens in one shot. + +**Attention phase**: Now identity matters: + +| Token type | What Attention they need | +|---|---| +| The 512 Prefill tokens | Attend **to each other**, generating new KV that gets written into the page table | +| The 200 Decode tokens | Each uses its 1 token's Q to attend to its own historical KV Cache | +| The 312 Dummy tokens | Skip | + +The CPU also passes metadata: `seq_lens = [512, 1, 1, ..., 0, 0]`. The Pallas kernel parses metadata at the register level — Prefill blocks go down a FlashAttention-style path (Q vectors mutually dot in UB + write new KV); Decode blocks go down PagedAttention (gather historical KV); Dummy blocks skip. 
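+
+As a concrete illustration of the static 1D flatten above, here is a toy CPU-side batch composer. All names (`MAX_TOTAL_TOKENS`, `compose_step`, `seq_lens`) are assumptions for this sketch, not real vLLM/JetStream identifiers. It packs one 512-token prefill chunk and 200 decode tokens into the fixed 1024-token slot, leaves the remaining 312 positions as dummy zeros, and emits the per-sequence metadata a kernel would use to route each segment.
+
+```python
+import numpy as np
+
+MAX_TOTAL_TOKENS = 1024      # static upper bound baked into the compiled graph
+MAX_SEQS = 256
+
+def compose_step(prefill_chunks, decode_tokens):
+    """prefill_chunks: list of (request_id, token_id_array) already cut to chunk size.
+       decode_tokens:  list of (request_id, next_token_id), one token per running request."""
+    token_ids = np.zeros(MAX_TOTAL_TOKENS, dtype=np.int32)   # positions left at 0 are dummy tokens
+    seq_lens = np.zeros(MAX_SEQS, dtype=np.int32)            # 0 marks an unused sequence slot
+    cursor, seq_idx = 0, 0
+    for _, toks in prefill_chunks:                           # prefill chunks occupy the front
+        token_ids[cursor:cursor + len(toks)] = toks
+        seq_lens[seq_idx] = len(toks)
+        cursor += len(toks)
+        seq_idx += 1
+    for _, tok in decode_tokens:                             # then one slot per decode request
+        token_ids[cursor] = tok
+        seq_lens[seq_idx] = 1
+        cursor += 1
+        seq_idx += 1
+    # Positions cursor .. MAX_TOTAL_TOKENS stay zero: the dummy tail the kernel skips.
+    return token_ids, seq_lens
+
+prefill = [("A", np.arange(1, 513, dtype=np.int32))]         # one 512-token prefill chunk
+decodes = [(f"req{i}", 7) for i in range(200)]               # 200 decode requests, 1 token each
+token_ids, seq_lens = compose_step(prefill, decodes)
+assert int((token_ids != 0).sum()) == 712                    # 512 + 200 real tokens
+assert int((seq_lens > 0).sum()) == 201                      # 1 prefill + 200 decode sequences
+```
+
+The compiled TPU graph always sees the same `[1024, D]` shape; only this metadata changes from step to step, which is what keeps recompilation off the hot path.
+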
+ +#### 12.5 GPU vs TPU Mixed-Batching Differences + +| | GPU | TPU | +|---|---|---| +| Composition | Dynamic: 712 → kernel takes 712; next round 850 → 850 | Static: 712 → forced add 312 dummies → 1024 | +| Hardware cost | Scheduler overhead | Padded MXU cycles | +| Software cost | High kernel flexibility | Pallas metadata routing | + +**Trade-off**: Wasting a bit of MXU cycles on padding is controllable, but it dodges the recompilation disaster and keeps tail latency tight. The lesser of two evils. + +#### 12.6 ↔ GPU + +GPU is the "ride-share" model: 712 passengers → a 712-seater dispatches, no empty seats. +TPU is the "high-speed direct express": departs on schedule whether full or not — fill empty seats with dummies if needed. + +--- + +### 13. KV / Memory Hierarchy + +**One sentence**: GPU's RDMA / GDS / KV offload counterparts on TPU are: some are native, some are not supported, some only fall back to PCIe. + +#### 13.1 Three Things Compared + +| Optimization | GPU | TPU | +|---|---|---| +| **Cross-node comm bypassing CPU** (GPU: RDMA over IB) | GPUDirect RDMA + IB NIC | **Native**: ICI network controllers are integrated directly into TPU silicon — no external NIC needed. The CPU is not in the path; zero CPU cycles consumed | +| **Direct read of storage into accelerator** (GPU: GPUDirect Storage) | NVMe → PCIe → GPU VRAM, bypassing CPU memory | **Not supported**. TPU has no direct storage / external network interface; CPU must mediate: GCS / PD → Host VM DDR → PCIe DMA → TPU HBM | +| **KV Cache offload to host DRAM** (GPU: vLLM CPU swap) | HBM → PCIe → CPU DDR | **Fully applicable**. vLLM's Block Manager runs on host CPU; when HBM fills, PCIe DMA copies KV blocks to host DDR | + +#### 13.2 ICI Is More Thorough Than RDMA + +GPU RDMA: data → PCIe → NIC → IB network → NIC → PCIe → VRAM. CPU memory bypassed, but data still leaves the GPU chip via an external NIC. + +TPU ICI: data → chip-edge optical module → fiber → peer optical module → chip. **Never leaves the silicon-fabric world.** With hundreds of GB/s of network traffic flowing, the host CPU has zero awareness. + +#### 13.3 Why Missing Direct Storage Is Tolerable + +Cold-start weight loading needs hundreds of GB pulled from GCS to TPU HBM, requiring CPU as porter. But because deployment is multi-host (e.g. 16 VMs), 16 CPUs **download different weight shards in parallel** from GCS, total network bandwidth is large, and the workflow is acceptable in practice. + +#### 13.4 KV Offload Performance Profile + +PCIe bandwidth is **narrow** relative to ICI (v4 PCIe Gen4 x16 ≈ 64 GB/s bidirectional, while a single ICI link can do hundreds of GB/s). So frequent swap noticeably degrades performance. This is OOM insurance, not a primary mechanism. + +#### 13.5 ↔ GPU + +GPU advantage: GDS (direct storage). TPU advantage: ICI (more thorough comm bypass). They split the spoils. + +> **[Note — added by Claude]** Industry has begun exploring **Mooncake-style "disaggregated KV pools"** (KV Cache as a separate service shared across nodes) — currently mostly on GPU systems. No public TPU-side equivalent that I've found. Not in main text — for cluster-TL perspective only. + +--- + +### 14. Gemini's Practical Compromises on TPU + +**One sentence**: Both MoE and speculative decoding have to be re-written algorithmically to suit TPU hardware. + +#### 14.1 Static MoE: Capacity Factor + +MoE's natural problem is **dynamic routing** — you don't know which expert the next token will pick. XLA forbids dynamic shapes. 
+ +The Gemini team's solution: + +- Each expert is given a strict **Capacity Factor** (a static slot size). Say each expert can take up to 64 tokens per step. +- Fewer than 64 routed tokens → pad with Dummy Padding to fill the slot. +- More than 64 routed tokens → the surplus is **dropped directly** (Token Dropping), or routed through a fallback general network. + +Through this brutal truncation + padding, the dynamic MoE network is forced into a static graph that XLA likes. + +Cost: occasional token drops, with possible model-quality impact. Google tunes the Capacity Factor during Gemini training to balance drop rate against compute cost. + +#### 14.2 Tensorized Speculative Decoding: Tree Attention + +Traditional speculative decoding: a small model proposes K tokens → if-else checks acceptance by the large model → on rejection, roll back. That if-else flow trashes TPU's VLIW pipeline. + +Gemini's solution (**parallel verification**): + +- After the small model produces K candidate tokens, the large model packs them into a 1D vector. +- A specially designed **Tree Attention Mask** ensures unrelated nodes don't see each other (multiplied by 0). +- The large model **verifies all K tokens' probabilities in one matmul during a single forward pass**. +- The accepted path is selected; the others are discarded. + +This converts "branch code" into "small Prefill matmul". The MXU rejoices. + +#### 14.3 Why These Compromises Aren't Required on GPU + +- GPU's SIMT scheduler handles if-else (branch divergence wastes lanes, but is much better than TPU). +- GPU's dynamic memory tolerates expert-capacity variability. +- So GPU runs vanilla MoE routing + vanilla speculative decoding fine. + +GPU roadmaps are also moving toward Tree Attention and similar tensorization, but it's about extra performance — not survival. + +#### 14.4 An Interesting Trend + +Because hardware hates branches (TPU outright; GPU dislikes host-device sync at high frequency), algorithm engineers are **rewriting control flow as data flow** wherever possible. Tree Attention, Masked Attention, predicated execution — all variations of the same idea. + +Core philosophy: **compute everything and discard is cheaper than if-else**. For shallow, bounded branches (speculative decoding 3-5 steps, Causal Mask half a matrix), this trick is wildly worth it. For deep nesting (10+ levels of conditional trees), it's $O(2^N)$ explosion. So next-gen chips are pursuing **native sparsity**: when a mask region is all zeros, the circuit physically skips MAC ops (clock-gated power off) — neither writes if-else nor consumes power. + +#### 14.5 ↔ GPU + +| Optimization | TPU must adapt | GPU need not | +|---|---|---| +| MoE | Static via Capacity Factor | Dynamic routing works | +| Speculative decoding | Tensorized via Tree Attention | if-else works (with perf hit) | + +--- + From a9473acff411003faa9d59e97eee28825bf20720 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 6 May 2026 19:34:45 +0000 Subject: [PATCH 7/8] docs(notes): Complete English mirror with Parts IV, V, and appendices Final part of the 1:1 English mirror. Part IV covers K8s TPU abstractions (device plugin, slice CRD, Kueue gang scheduling, TPU Provisioner -> OCS, failure self-healing) and multi-host slice orchestration with LWS/JobSet and SPMD. Part V covers programming model chains (GPU three-segment vs TPU one-segment), cost and efficiency (NVIDIA Tax, MFU, Performance/Watt, real-world TCO), and TPU hardware weaknesses with paired trade-offs. 
Appendices A-C mirror the Chinese versions exactly: trade-off cheat sheet, numbers list, and TPU-to-GPU terminology mapping. The writing log is also mirrored, listing 10 intentional cuts and 5 external additions for author review. This completes the bilingual content. Next step is to update the mkdocs.yml navigation to surface the new note and prepare for review/merge. Co-authored-by: Claude --- notes/tpu-deep-dive.md | 456 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 456 insertions(+) diff --git a/notes/tpu-deep-dive.md b/notes/tpu-deep-dive.md index 06aa848..f9ed539 100644 --- a/notes/tpu-deep-dive.md +++ b/notes/tpu-deep-dive.md @@ -940,3 +940,459 @@ Core philosophy: **compute everything and discard is cheaper than if-else**. For --- +## Part IV — Cluster-Layer Adaptation (Goal D) + +### 15. K8s Abstractions for TPU + +**One sentence**: K8s can't see light, so OCS slicing must be handled by an independent component; the full chain is TPU device plugin + topology labels + Kueue + TPU Provisioner. + +#### 15.1 Resource Exposure: From Chip to Node + +Physically, TPU chips are not directly Nodes. Each chip (or 4/8-chip board) hangs off a Host VM running Kubelet. + +- **TPU Device Plugin**: Loaded by Kubelet, advertises an extended resource `google.com/tpu: 4` to the API Server. +- **Topology labels**: Knowing the count of TPUs is not enough. The TPU Controller stamps detailed labels: + +```yaml +cloud.google.com/gke-tpu-topology: 2x2x4 # this Node belongs to a 16-chip slice +cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice +``` + +In etcd, these labeled Nodes form logical resource pools. + +#### 15.2 User Interface: Slice CRD + +You don't write a plain Deployment — you submit a wrapped Job or a dedicated `TPUSlice` CRD: + +```yaml +nodeSelector: + cloud.google.com/gke-tpu-topology: 4x4x4 +resources: + requests: + google.com/tpu: 64 +``` + +Meaning: "I need 64 TPUs, physically wired into a 4×4×4 closed Torus." + +#### 15.3 Gang Scheduling: Kueue + TPU Provisioner + +Native `kube-scheduler` cannot do "all-or-nothing across 64 specific-topology nodes". This is **Gang Scheduling** territory; you need a batch scheduler like Kueue. + +End-to-end chain to bring up a 64-chip slice: + +| Step | Who | What | +|---|---|---| +| 1 | Kueue | Intercepts the Job, sees 4×4×4 needed. Resource pool can't satisfy → Pending | +| 2 | Kueue / Cluster Autoscaler | When ready to schedule, triggers TPU Provisioner | +| 3 | **TPU Provisioner** | **Bypasses K8s control plane**, calls the data center's OCS hardware API: "rotate mirrors, give me a 4×4×4 Torus" | +| 4 | OCS | In seconds, mirror angles are set, optical paths locked | +| 5 | Kubelet | Detects ICI link is up, updates Node labels | +| 6 | Kueue | Confirms hardware ready, binds 64 Pods to 64 Nodes in one shot | + +**Key**: TPU Provisioner is an independent component outside K8s, connected to the data center's physical layer. This is the necessary compromise of "K8s can't see light." + +#### 15.4 Failure Self-Healing: OCS Routes Around Bad Chips + +Across months of pre-training or high-availability inference, TPU hardware failures (HBM degradation, blown optical modules) inevitably happen. 
+ +| Step | Action | +|---|---| +| 1 | Kubelet health checks detect hardware error, report Node failure to API Server | +| 2 | K8s tears down all 64 Pods of the Job | +| 3 | TPU Controller marks the bad chip as Unhealthy, removes it from the available pool | +| 4 | Controller calls OCS again: "route around that bad chip, pull a new chip from the pool, reconfigure mirrors, restitch a 4×4×4 ring" | +| 5 | Optical paths restitch; Job restarts, loads previous checkpoint | + +The whole hardware-failure-to-rescheduled cycle typically closes in **a few minutes**, fully automated. + +#### 15.5 ↔ GPU + +| Dimension | TPU on GKE | GPU on K8s | +|---|---|---| +| Device Plugin | TPU Device Plugin | NVIDIA Device Plugin | +| Topology awareness | Labels + TPU Provisioner calling OCS | NVIDIA Topology Aware Scheduling, `gpu-feature-discovery` | +| Gang Scheduling | Kueue | Volcano, Kueue, KubeFlow | +| Physical topology reconfig | OCS dynamic Torus assembly | (no equivalent; NVLink is fixed, IB fat-tree doesn't need reconfig) | + +**Trade-off**: TPU's hard topology constraint makes K8s abstractions more complex (you have to add a Provisioner) but yields dynamic slicing. GPU is simpler because there's no light to manage. + +--- + +### 16. Multi-Host Slice Orchestration + +**One sentence**: When a slice spans multiple hosts, K8s sees N Pods coordinating their startup, while each Pod's TPUs are a 4 / 8-chip local group — N:N at two layers. + +#### 16.1 LWS and JobSet + +To express this multi-host compute, Kubernetes provides two APIs: + +- **LeaderWorkerSet (LWS)**: One Leader Pod + several Worker Pods, all sharing the same lifecycle. The Leader typically exposes the inference API. +- **JobSet**: More general Job-group coordination. + +A 64-chip v4-64 slice corresponds to: + +``` +LWS: + size: 16 # 16 Host VMs + leader: + replicas: 1 + containers: # vLLM API + scheduler + worker: + replicas: 15 + containers: # same Python code (SPMD) +``` + +#### 16.2 SPMD Startup Mode + +The 16 Pods all run **the same Python code** (SPMD = Single Program Multiple Data). The code reads environment variables (`LWS_LEADER_ADDRESS`, `LWS_WORKER_INDEX`) to determine identity, then: + +- Leader: starts an HTTP/gRPC service for users. +- All Pods: via PyTorch/XLA or JAX, init the distributed group, ICI comm self-establishes. +- All Pods' CPUs simultaneously dispatch compute to the 4 TPUs underneath them. +- The 64 TPUs run All-Reduce and similar collectives via ICI. +- Leader collects results, returns to user. + +#### 16.3 Scheduling Coupling: Where a Single Failure Drags the Whole Slice + +| Failure layer | Outcome | +|---|---| +| Single TPU physical failure | Whole ICI ring breaks → 64 Pods' collectives hang → Job fails → 15.4 self-heal flow | +| Single Host VM NIC or Kubelet failure | The Pod's 4 TPUs lose connectivity → same as above | +| DCN (CPU-side ethernet) jitter | Leader↔Worker control-plane sync delay → service tail latency, but data plane (ICI) is unaffected | +| Single VM CPU overload | That Pod feeds data slowly → drags the whole step (slowest VM determines throughput) | + +The blast radius is essentially **slice-level** — any host failure prevents 64 TPUs from continuing. So LWS / JobSet uses gang scheduling: all alive or all dead. 
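
Before the GPU comparison, a minimal JAX sketch of the SPMD entrypoint pattern from 16.2. `LWS_LEADER_ADDRESS` / `LWS_WORKER_INDEX` are the env vars named above; the host-count variable, port, and explicit `initialize` arguments are simplifying assumptions, not the exact GKE/LWS wiring:

```python
import os
import jax

# Identity comes only from the environment — every Pod runs this same file.
coordinator = os.environ.get("LWS_LEADER_ADDRESS", "localhost")  # Leader Pod address (injected by LWS)
process_id  = int(os.environ.get("LWS_WORKER_INDEX", "0"))       # 0 for the Leader
num_hosts   = int(os.environ.get("NUM_HOSTS", "16"))             # hypothetical var; 16 Host VMs for a 64-chip slice

# Join the distributed group; after this, ICI collectives work across the whole slice.
jax.distributed.initialize(
    coordinator_address=f"{coordinator}:8476",   # port is an arbitrary choice for this sketch
    num_processes=num_hosts,
    process_id=process_id,
)

print(f"host {process_id}: {jax.local_device_count()} local TPUs, "
      f"{jax.device_count()} TPUs visible in the slice")

if process_id == 0:
    # Leader only: expose the HTTP/gRPC inference endpoint on top of the same SPMD program.
    pass
```

Every Pod ships the identical image and file; only the environment differs — which is exactly why the gang-scheduled, all-or-nothing start above is required.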
+ +#### 16.4 ↔ GPU + +| Dimension | TPU multi-host | GPU multi-host | +|---|---|---| +| Orchestration API | LWS / JobSet | MPI Operator, Training Operator, Ray on K8s | +| Startup mode | SPMD | MPI / NCCL group | +| Blast radius | Slice level (optical ring must be complete) | Usually Job level (fat-tree tolerates single-NIC failure) | +| Leader role | Usually API gateway | MPI rank 0, or PS-Worker's PS | + +**Trade-off**: TPU's SPMD + LWS model is clean, but with a large blast radius; GPU's MPI model is flexible but configuration-heavy. + +--- + +## Part V — System Comparison and Trade-offs (Goal B Concentration) + +### 17. Programming Model Chains: From Single Card to Multi-Machine + +**One sentence**: GPU is "single-card CUDA → multi-card NCCL → multi-machine IB/RDMA" — three segments; TPU is "SPMD → ICI (VLIW slot 5)" — one segment, run by the compiler. + +#### 17.1 GPU's Three Segments + +**Single card (Tensor Core)** + +``` +Python (PyTorch) → ATen → cuBLAS / Triton → SASS / PTX + ├ LDG.E (HBM → register) + ├ STS / LDS (Shared Memory hop) + └ HMMA.1688.F16 (Tensor Core triggers 16×8×16 half-precision MAC) +``` + +Data path: **HBM → register → Shared Memory → register → Tensor Core**, ping-ponging. + +**Within node, multi-card (NVLink + Copy Engine)** + +``` +NCCL → CUDA Kernel + Copy Engine +DMA: GPU0_HBM → NVLink bus → GPU1_HBM +Reduction: GPU1's SM does LDG / FADD / STG (data lands in HBM before being added) +``` + +NVLink provides Unified Virtual Addressing (UVA), but **the data must first land in the receiving GPU's HBM**, then the receiver's SM reads from HBM, adds locally, writes back to HBM. Eats a lot of HBM bandwidth. + +**Cross-machine (IB + GPUDirect RDMA)** + +``` +MMIO writes NIC Doorbell → NIC DMAs from GPU HBM → IB packets encapsulate → through Spine-Leaf switches → peer NIC decapsulates → DMA into peer GPU HBM +Sync: Receiver GPU CUDA Kernel spin-waits on an HBM sync flag (LDG.CG bypasses cache) +``` + +Control plane has CPU interrupts, protocol encapsulation, route lookup; data plane has switch congestion control and queueing. **Async event-driven.** + +#### 17.2 TPU's Single Segment + +``` +JAX / PyTorch (via PyTorch/XLA) → HLO → XLA → VLIW 5-slot instruction stream + ├ DMA (HBM → UB) + ├ MXU (systolic array MAC) + ├ VPU (vector ops) + ├ ICI (cross-chip communication, peer to DMA) + └ SPU (control flow) + +Cross-chip communication: + TX_UB_TO_NEIGHBOR (Src: UB_local, Dest_Node: Neighbor_ID) + WAIT_ICI_RX_SEMAPHORE + ADD_VECTOR (Src1: UB_local, Src2: UB_remote, Dest: UB_result) +``` + +Cross-chip data transfer and intra-chip moves are **logically equivalent** — both are switches in the VLIW word. XLA can have the MXU multiplying while ICI transmits the previous layer's gradient — cycle-level overlap. + +#### 17.3 The Comparison + +| Dimension | GPU | TPU | +|---|---|---| +| Cross-node comm essence | Async IO (CPU interrupts, protocols, switches) | Synchronous instruction (one VLIW slot) | +| Receiver-side handling | spin-wait HBM flag | Hardware semaphore, instant wake | +| Data landing | HBM (mandatory toll) | UB (direct to VPU) | +| Compiler view | Can't see cross-node | Knows every hop and latency | +| Failure tolerance | Node-level isolation | Slice-level coupling | + +#### 17.4 Three Analogies + +- **Single-machine GPU compute**: A massively busy interchange (huge cache + scheduler). Congested, but precise traffic lights (Warp Scheduler) keep throughput high. +- **Cross-machine GPU RDMA**: Cross-province highway logistics. 
Pack and seal → onto highway → off highway. Toll fees (protocol overhead) + unpredictable jams (congestion).
- **TPU Pod (VLIW + ICI + 3D Torus)**: A massive fully-automated production line. Every conveyor (fiber) and arm (MXU/VPU) is hard-wired. XLA is the master schedule, ensuring every part arrives at the right station on a precise cycle.

---

### 18. Cost / Efficiency

**One sentence**: MFU and Tokens/$ are the levers that measure real ledger differences — not chip peak FLOPs.

#### 18.1 Compute Unit Cost: NVIDIA Tax

| Dimension | NVIDIA H100 | Google TPU v5p |
|---|---|---|
| Equivalent compute hardware cost (industry estimate) | **~$21,000+** | **~$6,900** |

Roughly a ~3× spread — the so-called **NVIDIA Tax**.

Cloud on-demand prices:

| Form | Price |
|---|---|
| 8×H100 VM (Azure / GCP) | **$100 - $120 / hour** (per chip ~$12-$15) |
| TPU v5e (inference) | **~$1.20 / hour** (single chip) |
| 8-chip v5e node | **$10 - $11 / hour** |

#### 18.2 MFU (Model FLOPs Utilization)

MFU = actually sustained TFLOPs ÷ hardware peak TFLOPs:

| Chip | Typical LLM training MFU |
|---|---|
| H100 | 50% - 52% (thread scheduling, cache contention, complex control flow overhead) |
| TPU v5p | 58% - 60% or higher (XLA static orchestration + ICI deterministic latency) |

#### 18.3 Performance / Watt

- **H100 TDP 700W**: must power L1/L2 cache and out-of-order scheduler.
- **TPU drops those modules**, relying on a minimal MXU systolic array. Reported figures put TPU v5e/v5p energy consumption **60-65% lower** than GPU clusters on certain large matrix workloads (in some scenarios efficiency is 2-5× H100).

Saves on power + lowers data center cooling/infrastructure costs.

#### 18.4 Real-World Tokens / Dollar

| Case | Improvement |
|---|---|
| **Large model pre-training** (H100 → TPU v5p) | Tokens / $ higher by **15% - 25%** |
| **Midjourney** (image gen, migrated to TPU v6e) | Inference bill from **$2.1M / month** down to **<$0.7M / month**, a **3×** cost reduction with throughput maintained |
| **Character.AI** (high-concurrency dialog) | After migration to TPU, cost improvement of **3.8×** |
| **Waymark** (video diffusion) | Cost **4×** lower than H100 |

#### 18.5 One-Sentence Industry Picture

- For research teams that iterate fast, modify operators frequently, and depend on the deep PyTorch/CUDA ecosystem → GPU.
- For super-scale, structurally stable pre-training, or hundred-million-user high-concurrency LLM inference → TPU clusters.

> **[Note — added by Claude]** The above cases (Midjourney $2.1M→$700K, Character.AI 3.8×, Waymark 4×) lack precise sources and dates in the original conversation; only "public reports show". I have no other reliable confirmation of the specific numbers — please verify.

---

### 19. TPU's Hardware Disadvantages and Trade-offs

**One sentence**: Each advantage of static scheduling has a workload it can't handle well — large MXU granularity, weak SPU, 3D Torus All-to-All congestion, and HBM bandwidth/compute mismatch.

#### 19.1 Compute Granularity: The Fragmentation Penalty of MXU 128×128

GPU Tensor Core is 16×8×16; TPU MXU is 128×128. On a real inference need that can't align to multiples of 128 (e.g., batch size 5):

- GPU: Warp Scheduler tightly packs the fragments into SMs, hardware utilization stays decent.
- TPU: A lot of physical ALUs in the MXU **literally compute 0×W=0**, wasting cycles.

For high-frequency, low-concurrency, low-latency inference requests, TPU's physical compute is severely wasted.
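
A back-of-envelope model of this fragmentation penalty (it assumes pure padding up to 128-aligned tiles and ignores anything the compiler might fuse or batch — illustrative only):

```python
import math

def mxu_tile_utilization(m: int, n: int, k: int, tile: int = 128) -> float:
    """Fraction of MAC slots doing real work when an (m,k) x (k,n) matmul
    is padded up to multiples of the MXU tile size."""
    pm, pn, pk = (math.ceil(d / tile) * tile for d in (m, n, k))
    return (m * n * k) / (pm * pn * pk)

# Decode-style matmul with batch size 5 against a 4096x4096 weight:
print(f"{mxu_tile_utilization(5, 4096, 4096):.1%}")     # ~3.9% — most of the array multiplies zeros
# Prefill-style matmul with 2048 tokens:
print(f"{mxu_tile_utilization(2048, 4096, 4096):.1%}")  # 100.0% — every dim already a multiple of 128
```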
+ +#### 19.2 Weak Scalar / Branch Control: The Pain of Sampling and Speculative Decoding + +LLMs aren't all matmul — the final step in token generation is **sampling** (Top-K, Top-P, temperature, penalty factors), involving lots of sorting, conditional branches, and scalar ops. + +- GPU: Massive CUDA Cores + SIMT branch prediction; can dispatch tens of thousands of threads concurrently to handle array ops with logic. +- TPU: SPU compute is extremely weak; VPU only excels at regular vectors. Facing if-else heavy sampling, efficiency tanks. + +Even worse is **vanilla speculative decoding** — hardware needs to rapidly determine which tokens are accepted and dynamically discard parts of the compute graph. This "step then reassess" dynamic graph is the antithesis of TPU VLIW static instructions. Hence Gemini's Tree Attention tensorization (Section 14.2) forces this kind of computation into matmul. + +#### 19.3 Dynamic Network Routing: MoE All-to-All Congestion + +MoE's core is **dynamic routing** — each token at runtime is sent to a different expert: + +| Cluster | All-to-All behavior | +|---|---| +| GPU (NVSwitch + IB fat-tree) | Any N:N comm gets non-blocking full-cross bandwidth — friendly to MoE's chaotic, dynamic packet dispatch | +| TPU (3D Torus ring) | Static All-Reduce is unmatched; but in All-to-All, tokens must traverse multiple intermediaries across X/Y/Z axes to find their expert. **Some links get jammed, others idle**, dragging end-to-end latency | + +#### 19.4 Big Decode Batches Don't Save MoE: Two Dead-Ends + +Intuition: large batch raises MoE comm density, so compute could mask comm. Two physical dead-ends: + +**Dead-end 1: Tokens disperse, matrices stay small** + +512 concurrent requests → MoE layer → routed to 8 experts → average 64 tokens per expert. `[64, D] × [D, 4D]` is "between the teeth" for the MXU — far short of saturating compute to mask comm latency. + +**Dead-end 2: KV Cache breaks HBM** + +You can't grow batch unboundedly. To saturate MoE compute (thousands), HBM is long since OOM. + +**Conclusion**: In Decode, MoE comm overhead can only be mitigated, not fully masked. + +#### 19.5 Multi-Hop Affects Latency or Bandwidth + +Across racks on a 3D Torus to find experts: + +- **Small batch**: Dominated by **latency**. Optical relay + transceiver physical delay can't be skipped. +- **Big batch**: The killer is **bisection bandwidth**. With tokens scattering everywhere, some fibers get instantly overloaded; effective bandwidth collapses. + +#### 19.6 The Math Behind Compute-Masks-Comm + +Why does Prefill mask comm but Decode can't? Look at the dimensions: + +| Phase | Compute | Data transfer | Ratio | +|---|---|---|---| +| **Prefill** (GEMM) | $O(N^3)$ | $O(N^2)$ | Compute time ≫ network time, DMA stages in the background, MXU completely unaware | +| **Decode** (GEMV) | $O(N^2)$ | $O(N^2)$ | Roughly equal, MXU finishes computing instantly then has to stop and wait | + +#### 19.7 HBM Bandwidth vs Compute Imbalance (Echoing 6.4) + +Chapter 6's physical law: compute scales $O(N^2)$, traditional bandwidth scales $O(N)$. The v4 era was severely imbalanced; Decode MFU dropped to single digits. + +Google's hardware-side mitigation: **the v5e shrinks the MXU on purpose**, lowering peak FLOPs to bring the compute/bandwidth ratio into a healthy zone. + +Algorithmic mitigation: **MQA / GQA** (Multi-Query / Grouped-Query Attention) — drastically shrinks KV Cache, reducing each Decode step's HBM pull. This is a model-architecture concession purely to placate poor memory bandwidth. 
+ +#### 19.8 What Each Weakness Buys + +| Weakness | Buys | +|---|---| +| MXU large granularity | Higher compute density per silicon area, better energy efficiency | +| Weak SPU | Saved transistors fed back into the MXU | +| 3D Torus weak at All-to-All | Excellent regular collective comm, no external switch overhead | +| HBM bandwidth lags | Compute density off-the-charts; high MFU on Prefill / training | + +Each weakness corresponds to a trade-off. Understanding these is what lets you decide which workloads belong on TPU and which belong on GPU. + +--- + +## Appendix + +### A. Trade-off Cheat Sheet + +Cross-cut by design dimension; each trade-off links back to its chapter. + +| Dimension | TPU choice | Cost | Benefit | Chapters | +|---|---|---|---|---| +| **Static vs dynamic scheduling** | Static VLIW + XLA | High compile cost, shape changes recompile | Minimal hardware, high efficiency | Ch 1, 7, 8 | +| **Cache vs direct passthrough** | No hardware cache, only UB + DMA | Software complexity | Save silicon area for the MXU | Ch 1, 6 | +| **Granularity large vs small** | MXU 128×128 | Wastes on small matrices | High density on large matrices | Ch 1, 19 | +| **Ring vs tree** | 3D Torus | All-to-All congestion, long-edge ring latency | Excellent All-Reduce, no external switch | Ch 2, 4, 19 | +| **Physical vs optical** | OCS reconfigurable | Slicing limited by rack granularity | Dynamic topology + failure self-heal | Ch 3, 15 | +| **Centralized vs distributed orchestration** | Multi-host SPMD | Blast radius = slice | Clean orchestration, transparent SPMD | Ch 5, 16 | +| **Compute vs bandwidth** | Compute scales by area, bandwidth via packaging | Memory wall, low Decode MFU | High Prefill / training MFU | Ch 6, 19 | +| **Specialized vs general** | Pallas for PagedAttention | High engineering bar; every dynamic op needs a hand-write | Breaks XLA's static limit | Ch 11, 14 | +| **Padding vs recompile** | Keep batch buckets, fill with dummies | Wastes a small fraction of compute | Avoid recompile disaster | Ch 8, 12 | +| **Algorithm yields to hardware** | Capacity Factor, Tree Attention | Token dropping, increased mask complexity | TPU can run MoE and speculative decoding | Ch 14 | + +### B. Numbers / Parameters List + +All numbers labeled "from source conversation". 
+ +| Item | Value | Note | +|---|---|---| +| TPU v4 Pod scale | **4096 chips** | 64 racks × 64 chips | +| TPU v5p Pod scale | **8960 chips** | Full scale | +| Single-rack chip count | **64** (v4 water-cooled) | 4 chips/board × 16 boards | +| MXU size | **128 × 128** MAC cells | v4 / v5p | +| 4×4×4 rack outgoing fibers | **96** | 8 corners×3 + 24 edges×2 + 24 face-centers×1 | +| Surface TPUs | **56** | 64 - inner 8 | +| Equivalent H100 hardware cost | **~$21,000** | NVIDIA selling price | +| Equivalent TPU v5p hardware cost | **~$6,900** | Google internal | +| 8×H100 VM on-demand | **$100-120 / hour** | Azure / GCP | +| 8×TPU v5e node on-demand | **$10-11 / hour** | Google Cloud | +| H100 TDP | **700W** | | +| H100 training MFU | **50% - 52%** | Large LLM clusters | +| TPU v5p training MFU | **58% - 60%** or higher | Equivalent task | +| TPU energy advantage | **60-65% lower** than GPU clusters | Specific large matrix workloads | +| Tokens / $ advantage (pre-training) | TPU higher by **15% - 25%** | H100 → v5p | +| Midjourney inference bill | **$2.1M → <$0.7M / month** | Migrated to TPU v6e, 3× | +| Character.AI cost improvement | **3.8×** | After TPU migration | +| Waymark video gen | **4×** | Lower than H100 | +| CPU:TPU ratio | **1:4 or 1:8** | Hard-wired physically | +| ICI single-link bandwidth order | (not in source) | Public ~4.5 TB/s aggregated multi-direction on v4, please verify | +| H100 L2 Cache | **50 MB** | Mentioned in source | +| MoE Capacity Factor example | 64 tokens per expert | Source example | +| v5p possible long-edge | **35** | 16×16×35 ≈ 8960 estimate | +| Cross-rack Z-axis 8-hop ring | **6 copper hops + 2 optical hops** | 4×4×8 slice | +| Optical vs copper latency | Optical **hundreds of ns**, copper **single to low-double-digit ns** | NUCA heterogeneity | + +### C. 
Terminology ↔ GPU Equivalent Mapping + +| TPU term | GPU equivalent | Note | +|---|---|---| +| **MXU** | Tensor Core | TPU one big array vs GPU many small arrays | +| **VPU** | CUDA Core (partial) | TPU biased toward regular vectors | +| **SPU** | Scalar dispatch + registers | TPU control flow is weak | +| **Unified Buffer (UB)** | Shared Memory + L1/L2 | TPU software-managed; GPU hardware-managed | +| **HBM** | HBM | Same | +| **ICI** | NVLink + IB (combined) | TPU one fabric inside and outside; GPU two layers | +| **OCS** | (no equivalent) | Unique | +| **3D Torus** | Fat-Tree (different idea) | Different topology philosophy | +| **SPMD** | NCCL collective + MPI rank | TPU compiler-managed | +| **VLIW 5-slot** | (no full equivalent) | GPU is SIMT | +| **XLA** | TorchInductor / TVM / TensorRT | TPU's core; GPU's optional | +| **HLO / StableHLO** | FX Graph / TorchScript | XLA's IR | +| **Pallas** | Triton | Custom kernel language | +| **JetStream** | TensorRT-LLM | Vendor-specialized inference engine | +| **Saxml** | DeepSpeed Inference (partial) | JAX legacy | +| **TPU Provisioner** | (no equivalent) | Closest is Slurm topology + manual NCCL | +| **TPU Device Plugin** | NVIDIA Device Plugin | Same K8s concept | +| **LWS / JobSet** | MPI Operator / Training Operator | Multi-host orchestration | +| **DCN (datacenter network)** | Control-plane IB / Ethernet | TPU uses ethernet; GPU often uses IB | +| **Capacity Factor** (MoE) | (no TPU-specific equivalent; concept exists on GPU but isn't enforced) | Static-ization trick | +| **Tree Attention** (speculative decoding) | Same name | At the algorithm level; now used on GPU too | +| **Bucketing + AOT** | torch.compile + persistent cache | Mandatory on TPU; optional on GPU | +| **NUCA topology-aware mapping** | NCCL topology discovery | Auto in XLA on TPU; semi-manual on GPU | + +--- + +## Writing Log (For Author Verification) + +### Intentional Cuts (Items in source not carried into the note) + +The following details were noticed but not carried forward — please decide whether to add them back. + +1. **vLLM `model.generate()` token-generation call stack** (around source lines 850-870): A detailed walk through PyTorch eager invoking ATen line-by-line under Lazy Tensor. I judged this to be a GPU-side eager-mode explanation, not the focus of TPU notes; simplified to a single "PyTorch Lazy Tensor" line in Ch 8.1's table. +2. **Specific GPU single-card SASS instruction example** (around source lines 700-720): `LDG.E`, `STS [Shared_Addr_A]`, `HMMA.1688.F16` instruction names are used only once in Ch 17.1. If you'd like more detailed GPU instruction comparison, this can be expanded. +3. **SPMD startup detail of Pod ID reading** (source line 604): "code internally reads the hardware Device_ID to decide which data block to load". I simplified to "reads environment variables to determine identity" in Ch 16.2, without distinguishing TPU device ID and K8s pod environment variables. Worth expanding? +4. **TPU v6e (Trillium)**: Source line 1707 mentions this is the "sixth generation" pure-inference chip, but no specific specs. The Midjourney case mentions its name. I didn't open a section. If you want a v6e-specific design trade-off section, this can be expanded. +5. **OCS internal MEMS array two-layer structure** (source lines 1521-1538): The first mirror aims direction, the second mirror corrects beam, preventing fiber-input angle deviation attenuation. 
I merged this into Ch 3.1 as "input → reflect → reflect → output", without emphasizing the necessity of two-layer MEMS. Refine? +6. **OCS external physical wiring (MPO/MTP high-density optical cables)**: Source line 2278 describes how 96 fibers in the data center are bundled via MPO/MTP at 16/32 fibers each. I mentioned this once but didn't elaborate on physical form. +7. **Capacity Factor cost detail**: Source lines 1628-1631 mention token dropping or "forced through a fallback general network". I only wrote dropping, didn't write the fallback network branch. +8. **Sampling specifics** (Top-K / Top-P / temperature / penalty factors): Ch 19.2 mentions this is an SPU weakness but doesn't expand on each sampling algorithm's hardware cost. +9. **Cerebras / Groq designs** (source lines 1500-1510): As a side note on "what other vendors are doing", mentions Wafer-Scale and LPU. I didn't include them in the main text since they aren't TPU-themed. Worth adding an Appendix D "industry cross-reference"? +10. **"Native sparsity / clock-gated power off"** (source line 1810): As a next-gen chip evolution direction. I touched it briefly in Ch 14.4 without a dedicated subsection. + +### External Additions (Claude added, not in source) + +The following are explicitly tagged with `> **[Note — added by Claude]**` in the body: + +- Ch 2.1: ICI single-link 4.5 TB/s order (public source, not in original) +- Ch 3.6: Microsoft Azure trialing Lumen OCS in some clusters (industry note, not in main text) +- Ch 6.4: "Single-digit MFU in v4 era" lacks other sources to verify +- Ch 13.5: Mooncake-style KV pool reference (not in main text) +- End of Ch 18: Midjourney / Character.AI / Waymark specific numbers — source only says "public reports show", no year or attribution. Please verify. + +If there are missed cuts, raise them on the PR — I'll address them inline. + + From a2ffce4a2bd9044623d1ede7d055e00b6d006ba1 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 6 May 2026 19:35:11 +0000 Subject: [PATCH 8/8] nav: Surface TPU deep dive in mkdocs nav MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the new bilingual TPU note to both the 中文版 -> 技术笔记 and English -> Tech Notes sections of mkdocs.yml so it appears in the published site sidebar after merge to main. Co-authored-by: Claude --- mkdocs.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mkdocs.yml b/mkdocs.yml index 796426d..33e13c1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,7 @@ nav: - "体系化专题": - "LLM 从原理到生产级推理": https://random-liu.github.io/llm-inference-principle-to-production/ - "技术笔记": + - "TPU 深入笔记:从单芯片到生产集群": notes/tpu-deep-dive.cn.md - "示例笔记": notes/example_note.cn.md - "敬请期待": notes/coming_soon.cn.md - "English": @@ -71,6 +72,7 @@ nav: - "Structured Projects": - "LLM Inference": https://random-liu.github.io/llm-inference-principle-to-production/ - "Tech Notes": + - "TPU Deep Dive: From Single Chip to Production Cluster": notes/tpu-deep-dive.md - "Example Note": notes/example_note.md - "Coming Soon": notes/coming_soon.md