diff --git a/.gitignore b/.gitignore index 896b38a12..0084a84f1 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,7 @@ save* .log *.pid *.ipynb* +model/ +output_* +HiFloat4/ +datasets/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..6babffdef --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,5 @@ +我需要给wan2.2(https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)进行hifp4的模拟量化,使用的方法是AWQ +但是比赛给的要求允许2个transformer-block的权重不进行量化,我觉得transformer和transformer_2的第0个block的权重不进行量化是收益最大的,请在当前branch基础上新建一个branch进行相应修改,并在修改完上传github +请参考配置文件configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml和和运行脚本scripts/run_llmc.sh,帮我进行修改 +请注意,我现在的电脑是本地主机而不是服务器,所以需要你从代码本身的逻辑去修改而不能真的运行 +你有权限修改本文件夹下所有文件 \ No newline at end of file diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml new file mode 100644 index 000000000..07b05b239 --- /dev/null +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -0,0 +1,66 @@ +base: + seed: &seed 42 +model: + type: Wan2T2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/model/Wan2.2-T2V-A14B + # 若未 `pip install -e /path/to/Wan2.2`,可显式指定官方仓库代码路径: + # wan2_repo_path: /path/to/Wan2.2 + # 默认严格走官方 Wan2.2 原生后端;官方代码不可用时会直接报错,不再静默回退到 Diffusers。 + # 若确实需要回退可开启: + # allow_diffusers_fallback: True + torch_dtype: auto + # 显存不足时开启:校准阶段捕获的激活存到 CPU,量化时再按 block 搬到 GPU + use_cpu_to_save_cuda_mem_for_catcher: True +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 # OOM 时可减小,如 8 或 10 + bs: 1 + target_height: 480 # OOM 时可减小,如 320 + target_width: 832 # OOM 时可减小,如 576 + num_frames: 81 # OOM 时可减小,如 49 或 33 + # 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise) + guidance_scale: 4.0 # high_noise + guidance_scale_2: 3.0 # low_noise + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + 
num_frames: 81 + # 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise) + guidance_scale: 4.0 # high_noise + guidance_scale_2: 3.0 # low_noise + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + # quant_type: int-quant + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + # quant_type: int-quant + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + # save_lightx2v: True + # save_path: ./save_for_lightx2v/wan2_2_t2v/awq_w_a/original/ + save_fake: True + save_path: ./save_for_fake/wan2_2_t2v/awq_w_a/original/ diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a_skip_first.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a_skip_first.yaml new file mode 100644 index 000000000..ef8e671da --- /dev/null +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a_skip_first.yaml @@ -0,0 +1,73 @@ +base: + seed: &seed 42 +model: + type: Wan2T2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/model/Wan2.2-T2V-A14B + # 若未 `pip install -e /path/to/Wan2.2`,可显式指定官方仓库代码路径: + # wan2_repo_path: /path/to/Wan2.2 + # 默认严格走官方 Wan2.2 原生后端;官方代码不可用时会直接报错,不再静默回退到 Diffusers。 + # 若确实需要回退可开启: + # allow_diffusers_fallback: True + torch_dtype: auto + # 显存不足时开启:校准阶段捕获的激活存到 CPU,量化时再按 block 搬到 GPU + use_cpu_to_save_cuda_mem_for_catcher: True +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 # OOM 时可减小,如 8 或 10 + bs: 1 + target_height: 480 # OOM 时可减小,如 320 + target_width: 832 # OOM 时可减小,如 576 + num_frames: 81 # OOM 时可减小,如 49 或 33 + # 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise) + guidance_scale: 4.0 # high_noise + guidance_scale_2: 3.0 # low_noise + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + 
target_width: 832 + num_frames: 81 + # 对齐官方 Wan2.2 默认 sample_guide_scale=(3.0, 4.0) (low_noise, high_noise) + guidance_scale: 4.0 # high_noise + guidance_scale_2: 3.0 # low_noise + output_video_path: ./output_videos_awq_skip_first/ +quant: + video_gen: + method: Awq + weight: + # quant_type: int-quant + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + # quant_type: int-quant + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +# Skip AWQ transformation and fake-quant deployment for: +# block 0 → transformer expert (high-noise), first block +# block 40 → transformer_2 expert (low-noise), first block +# (transformer has 40 blocks, so transformer_2 starts at index 40) +# Leaving layer_names empty means ALL linear layers in those blocks are skipped. +ignored_layers: + block_ids: [0, 40] +save: + # save_lightx2v: True + # save_path: ./save_for_lightx2v/wan2_2_t2v/awq_w_a/skip_first/ + save_fake: True + save_path: ./save_for_fake/wan2_2_t2v/awq_w_a/skip_first/ diff --git a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml index 680fab43b..1b1097ad7 100755 --- a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml @@ -2,7 +2,7 @@ base: seed: &seed 42 model: type: WanI2V - path: /path/to/model + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B/ torch_dtype: auto calib: name: i2v @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 8 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 8 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_i2v_awq_w_a/x2v/ diff --git 
a/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml new file mode 100644 index 000000000..adba728d0 --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml @@ -0,0 +1,57 @@ +# Wan2.1 I2V FP8 量化配置示例 +# 这是一个快速开始的配置文件,请根据实际情况修改路径 + +base: + seed: &seed 42 + +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的 Wan2.1 I2V 模型路径 + torch_dtype: auto + +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为你的校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed + +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: /path/to/eval/data # 修改为你的评估数据路径 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_fp8/ + +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数,范围 0.5-1.0 + +save: + save_lightx2v: True # 保存为 lightx2v 兼容格式 + save_path: /path/to/save/quantized/model # 修改为你的保存路径 diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml index 14d05479d..ec6d8714e 100755 --- a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-14B-Diffusers torch_dtype: auto calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 
@@ -20,7 +20,7 @@ eval: type: video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml new file mode 100755 index 000000000..f140839e3 --- /dev/null +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml @@ -0,0 +1,49 @@ +base: + seed: &seed 42 +model: + type: WanT2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-1.3B-Diffusers + torch_dtype: auto +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: ../lightx2v/wan_t2v_awq_w_a_s/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml deleted file mode 100755 index b6a53b0e0..000000000 --- a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml +++ /dev/null @@ -1,32 +0,0 @@ -base: - seed: &seed 42 
-model: - type: WanT2V - path: /path/to/wan_t2v - torch_dtype: auto -eval: - eval_pos: [transformed, fake_quant] - type: video_gen - name: t2v - download: False - path: ../assets/wan_t2v/eval/ - bs: 1 - target_height: 480 - target_width: 832 - num_frames: 81 - guidance_scale: 5.0 - output_video_path: ./output_videos_rtn/ -quant: - video_gen: - method: RTN - weight: - bit: 6 - symmetric: True - granularity: per_channel - act: - bit: 6 - symmetric: True - granularity: per_token -save: - save_lightx2v: True - save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml index 7d65f31fc..f76edd294 100755 --- a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-14B-Diffusers torch_dtype: auto calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 @@ -20,26 +20,30 @@ eval: type: video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 num_frames: 81 guidance_scale: 5.0 - output_video_path: ./output_videos_sq/ + output_video_path: ./output_videos_awq/ quant: video_gen: - method: SmoothQuant + method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel + group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: - alpha: 0.7 + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/docs/wan2.1_quantization_guide.md b/docs/wan2.1_quantization_guide.md new file mode 100644 index 000000000..eeef5ac63 --- /dev/null +++ 
b/docs/wan2.1_quantization_guide.md @@ -0,0 +1,288 @@ +# Wan2.1 视频生成模型量化指南 + +## 概述 + +llmc 框架现已全面支持 Wan2.1 系列视频生成模型的量化,并提供真正量化的 INT8/FP8 权重导出,与 lightx2v 推理框架兼容。 + +## 支持的模型类型 + +- **WanI2V**: Image-to-Video (图像到视频) +- **WanT2V**: Text-to-Video (文本到视频) + +## 支持的量化方法 + +### FP8 量化 (推荐) + +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml` + +**特点**: +- 使用 E4M3 FP8 格式 (8-bit 浮点数,4位指数,3位尾数) +- SmoothQuant 算法,平衡权重和激活的量化难度 +- 适合 GPU 推理,性能损失小 + +**量化配置**: +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数 +``` + +### INT8 量化 + +#### 1. RTN (Round-to-Nearest) +**配置文件**: `configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml` + +**特点**: +- 最简单的量化方法 +- 直接四舍五入到最近的量化级别 +- 速度快,精度略低 + +#### 2. AWQ (Activation-aware Weight Quantization) +**配置文件**: `configs/quantization/video_gen/wan_i2v/awq_w_a.yaml` + +**特点**: +- 基于激活分布优化权重量化 +- 保护重要通道,减少精度损失 +- 需要校准数据 + +#### 3. SmoothQuant +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml` + +**特点**: +- 平衡权重和激活的量化难度 +- 数学上等价于平滑激活异常值 +- 通常提供最佳精度 + +### LoRA 模型量化 + +支持对 LoRA 适配器模型的量化: +- `smoothquant_w_a_int8_lora.yaml` +- `rtn_w_a_lora.yaml` + +## 运行步骤 + +### 1. 准备环境 + +```bash +# 设置 llmc 路径 +export llmc=/path/to/llmc +export PYTHONPATH=$llmc:$PYTHONPATH + +# 设置 GPU +export CUDA_VISIBLE_DEVICES=0 +``` + +### 2. 准备校准数据 + +为 I2V 模型准备校准数据: +``` +assets/wan_i2v/calib/ +├── image_1.jpg +├── image_2.jpg +└── ... +``` + +为 T2V 模型准备校准数据: +``` +assets/wan_t2v/calib/ +├── prompt_1.txt +├── prompt_2.txt +└── ... +``` + +### 3. 
修改配置文件 + +编辑对应的 YAML 配置文件,设置: +- `model.path`: Wan2.1 模型路径 +- `calib.path`: 校准数据路径 +- `save.save_path`: 量化模型保存路径 + +**示例 (FP8 量化)**: +```yaml +base: + seed: 42 +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的模型路径 + torch_dtype: auto +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 +save: + save_lightx2v: True + save_path: /path/to/save/quantized/model # 修改为保存路径 +``` + +### 4. 运行量化 + +#### 使用脚本运行 (推荐) + +```bash +# 运行 FP8 量化 (I2V) +./run_llmc.sh wan_i2v_fp8 + +# 运行 INT8 RTN 量化 (I2V) +./run_llmc.sh wan_i2v_int8_rtn + +# 运行 INT8 AWQ 量化 (I2V) +./run_llmc.sh wan_i2v_int8_awq + +# 运行 INT8 SmoothQuant 量化 (I2V) +./run_llmc.sh wan_i2v_int8_smoothquant + +# 运行 T2V 模型量化 +./run_llmc.sh wan_t2v_int8_rtn +./run_llmc.sh wan_t2v_int8_awq +./run_llmc.sh wan_t2v_int8_smoothquant +``` + +#### 直接运行命令 + +```bash +torchrun \ +--nnodes 1 \ +--nproc_per_node 1 \ +--rdzv_id $RANDOM \ +--rdzv_backend c10d \ +--rdzv_endpoint 127.0.0.1:29500 \ +${llmc}/llmc/__main__.py \ +--config configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml \ +--task_id my_quant_task +``` + +### 5. 监控进度 + +```bash +# 查看日志 +tail -f wan_i2v_fp8.log + +# 查看进程 +ps aux | grep __main__.py +``` + +### 6. 
停止任务 + +```bash +# 使用保存的 PID 文件 +xargs kill -9 < wan_i2v_fp8.pid +``` + +## 配置参数说明 + +### 模型配置 +- `type`: 模型类型 (`WanI2V` 或 `WanT2V`) +- `path`: 模型权重路径 +- `torch_dtype`: 数据类型 (`auto`, `bfloat16`, `float32`) + +### 校准配置 +- `sample_steps`: 采样步数 (通常 20-40) +- `bs`: 批大小 (通常 1,视频生成显存占用大) +- `target_height`: 目标视频高度 (默认 480) +- `target_width`: 目标视频宽度 (默认 832) +- `num_frames`: 视频帧数 (默认 81) +- `guidance_scale`: CFG 引导强度 (默认 5.0) + +### 量化配置 +- `method`: 量化方法 (`RTN`, `Awq`, `SmoothQuant`) +- `weight.bit`: 权重位宽 (8, e4m3) +- `act.bit`: 激活位宽 (8, e4m3) +- `granularity`: 量化粒度 (`per_channel`, `per_token`) +- `special.alpha`: SmoothQuant 平衡参数 (0.5-1.0) + +## 在 lightx2v 中使用量化模型 + +### 1. 配置 lightx2v + +编辑 `lightx2v/configs/quantization/wan_i2v.json`: +```json +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "dit_quantized_ckpt": "/path/to/quantized/model", + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm" +} +``` + +对于 FP8 模型,设置 `"dit_quant_scheme": "fp8"`。 + +### 2. 运行推理 + +```bash +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path /path/to/original/model \ +--config_json configs/quantization/wan_i2v.json \ +--prompt "Your prompt here" \ +--image_path /path/to/input/image.jpg \ +--save_result_path output.mp4 +``` + +## 性能建议 + +1. **FP8 vs INT8**: + - FP8: 精度更高,适合对质量要求高的场景 + - INT8: 压缩率更高,适合对速度要求高的场景 + +2. **量化方法选择**: + - 快速原型: RTN + - 平衡精度和速度: SmoothQuant + - 最高精度: AWQ + +3. **校准数据**: + - 使用 10-50 个样本 + - 覆盖典型使用场景 + - I2V: 使用多样化图像 + - T2V: 使用多样化文本描述 + +4. 
**资源需求**: + - GPU: 建议 24GB+ 显存 + - 校准时间: 30分钟 - 2小时 (取决于数据量) + - 存储空间: 量化后模型约原模型 25-50% 大小 + +## 故障排除 + +### 显存不足 +- 减小 `bs` 到 1 +- 减小 `num_frames` +- 减小 `target_height` 和 `target_width` + +### 量化精度损失过大 +- 尝试 SmoothQuant 方法 +- 增加校准数据数量 +- 调整 `alpha` 参数 (0.5-1.0) + +### lightx2v 兼容性问题 +- 确保使用 `save_lightx2v: True` +- 检查 `dit_quant_scheme` 设置 +- 确认量化模型路径正确 + +## 参考 + +- lightx2v 文档: [lightx2v 项目地址] +- llmc 框架: [llmc 项目地址] +- Wan2.1 模型: [模型地址] diff --git a/llmc/__main__.py b/llmc/__main__.py index ec60c1492..abf4911cb 100755 --- a/llmc/__main__.py +++ b/llmc/__main__.py @@ -32,7 +32,7 @@ def main(config): logger.info(f'tokenizer: {model.get_tokenizer()}') eval_list = get_eval_list(model, config) - eval_model(model, None, eval_list, eval_pos='pretrain') + # eval_model(model, None, eval_list, eval_pos='pretrain') blockwise_opts = [] modalities, modality_configs = get_modality(config) @@ -70,7 +70,7 @@ def main(config): blockwise_opts.append(blockwise_opt) dist.barrier() - eval_model(model, blockwise_opts, eval_list, eval_pos='transformed') + # eval_model(model, blockwise_opts, eval_list, eval_pos='transformed') if int(os.environ['RANK']) == 0: if 'save' in config and config.save.get('save_trans', False): blockwise_opt.save_model(save_trans_path) @@ -85,8 +85,8 @@ def main(config): config.save.get('trtllm_cfg'), ) - eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant') - eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant_wo_kv') + # eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant') + # eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant_wo_kv') if 'save' in config and config.save.get('save_fake', False): deploy_all_modality(blockwise_opts, 'fake_quant') diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py index 72823d1bd..380e8f42c 100644 --- a/llmc/compression/blockwise_optimization.py +++ b/llmc/compression/blockwise_optimization.py @@ -31,11 +31,15 @@ def 
__init__(self, model, compress_config, input, padding_mask, config): def run_block_loop(self): for i in range(len(self.blocks)): self.block_idx = i + if self.input and hasattr(self.model, 'get_blockwise_input'): + self.input = self.model.get_blockwise_input(self.block_idx, self.input) logger.info( f'\nblock index: {self.block_idx}/{len(self.blocks)} ' f'\nblock: {self.blocks[self.block_idx]}' ) self.block_opt(self.blocks[self.block_idx]) + if self.input and hasattr(self.model, 'set_blockwise_input'): + self.model.set_blockwise_input(self.block_idx, self.input) if hasattr(self, 'save_scale') and self.save_scale: os.makedirs(self.scale_path, exist_ok=True) diff --git a/llmc/compression/quantization/__init__.py b/llmc/compression/quantization/__init__.py index 2c08343e2..07b4f5967 100644 --- a/llmc/compression/quantization/__init__.py +++ b/llmc/compression/quantization/__init__.py @@ -10,7 +10,7 @@ from .ntweak import NormTweaking from .omniq import OmniQuant from .osplus import OsPlus -from .quant import FloatQuantizer, IntegerQuantizer +from .quant import FloatQuantizer, HiFloat4Quantizer, IntegerQuantizer from .quarot import Quarot from .quik import QUIK from .rtn import RTN diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 5a2232699..645fd4e17 100755 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -4,6 +4,7 @@ import json import os import re +import shutil from collections import defaultdict from functools import partial @@ -35,7 +36,12 @@ _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear, FakeQuantLinear, LlmcActFn, OriginFloatLinear, RotateLinear) -from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer +from .quant import ( + FloatQuantizer, + HiFloat4Quantizer, + IntegerQuantizer, + Weight48IntegerQuantizer, +) class BaseBlockwiseQuantization(BlockwiseOpt): @@ 
-157,6 +163,8 @@ def set_quant_config(self): self.weight_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.weight_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.weight_quant_module = HiFloat4Quantizer logger.info(f'The used Weight Quant Module is {self.weight_quant_module}') self.wquantizer = self.weight_quant_module(**self.quant_config['weight']) @@ -175,6 +183,13 @@ def set_quant_config(self): self.act_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.act_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.act_quant_module = HiFloat4Quantizer + else: + raise ValueError( + f"Unsupported act quant_type: {quant_type}. " + "Supported: int-quant, float-quant, hif4." + ) self.quant_config['act']['tp'] = self.tp self.aquantizer = self.act_quant_module(**self.quant_config['act']) self.act_static = self.quant_config['act'].get('static', False) @@ -444,9 +459,21 @@ def run(self, block, input_feat, handles): h.remove() torch.cuda.empty_cache() - self.block_transform(block, input_feat, self.input['kwargs']) + if not self._is_ignored_block(self.block_idx): + self.block_transform(block, input_feat, self.input['kwargs']) + else: + logger.info( + f'Block {self.block_idx} is in ignored_block_ids, ' + f'skipping block_transform.' + ) else: - self.block_transform(block) + if not self._is_ignored_block(self.block_idx): + self.block_transform(block) + else: + logger.info( + f'Block {self.block_idx} is in ignored_block_ids, ' + f'skipping block_transform.' 
+ ) if not self.data_free and self.quant_out: self.model.replace_module_block( @@ -907,27 +934,45 @@ def set_non_linear_mode(self, quant_format, module, mode): if getattr(m, 'calib', None) is not None: m.calib = mode + def _get_ignored_block_ids_set(self): + if not hasattr(self, '_ignored_block_ids_set_cache'): + expanded = [] + for item in self.ignored_block_ids: + match = re.match(r'(\d+)-(\d+)', str(item)) + if match: + start, end = int(match.group(1)), int(match.group(2)) + expanded.extend(range(start, end + 1)) + else: + expanded.append(int(item)) + self._ignored_block_ids_set_cache = set(expanded) + return self._ignored_block_ids_set_cache + + def _is_ignored_block(self, block_idx): + if not self.mixed_precision or not self.ignored_block_ids: + return False + return block_idx in self._get_ignored_block_ids_set() + def set_no_quant_layer(self): if self.ignored_speical_names: assert hasattr(self.model, 'block_name_prefix'), \ 'block_name_prefix missing in model' - ignored_block_ids = [] - for item in self.ignored_block_ids: - match = re.match(r'(\d+)-(\d+)', str(item)) - if match: - start, end = int(match.group(1)), int(match.group(2)) - ignored_block_ids.extend(range(start, end + 1)) - else: - ignored_block_ids.append(int(item)) + ignored_block_ids = self._get_ignored_block_ids_set() + # If no layer_names specified, skip all linear layers in the ignored blocks + skip_all_linears = not self.ignored_layer_names for idx, block in enumerate(self.blocks): for n, m in block.named_modules(): - if idx in ignored_block_ids and n in self.ignored_layer_names: - m.register_buffer('no_quant', torch.tensor(True)) - else: - layer_name = f'{self.model.block_name_prefix}.{idx}.{n}' - if layer_name in self.ignored_speical_names: + if idx in ignored_block_ids: + if skip_all_linears: + if isinstance(m, tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)): + m.register_buffer('no_quant', torch.tensor(True)) + elif n in self.ignored_layer_names: m.register_buffer('no_quant', 
torch.tensor(True)) + else: + if self.ignored_speical_names: + layer_name = f'{self.model.block_name_prefix}.{idx}.{n}' + if layer_name in self.ignored_speical_names: + m.register_buffer('no_quant', torch.tensor(True)) @torch.no_grad() def deploy(self, quant_format, keep_device=False): @@ -1003,6 +1048,70 @@ def contiguous_params(self): if not param.is_contiguous(): param.data = param.data.contiguous() + if ( + self.config.model.type in ['Wan2T2V'] + and hasattr(self.model.Pipeline, 'transformer_2') + and self.model.Pipeline.transformer_2 is not None + ): + for name, param in self.model.Pipeline.transformer_2.named_parameters(): + if not param.is_contiguous(): + param.data = param.data.contiguous() + for name, param in self.model.Pipeline.transformer_2.named_buffers(): + if not param.is_contiguous(): + param.data = param.data.contiguous() + + def _copy_wan22_native_checkpoint(self, src, dst): + if not isinstance(src, str) or not os.path.isdir(src): + raise RuntimeError( + 'Wan2.2 official save expects a local native checkpoint directory, ' + f'but got src={src!r}.' + ) + if os.path.abspath(src) == os.path.abspath(dst): + raise RuntimeError( + 'Wan2.2 official save path must differ from source checkpoint path ' + f'(src=dst={src}).' + ) + if os.path.exists(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + logger.info(f'Copied original Wan2.2 native checkpoint from {src} to {dst}') + + def _validate_wan22_native_save_structure(self, save_path, source_path=None): + if not os.path.isdir(save_path): + raise RuntimeError(f'Wan2.2 saved path is not a directory: {save_path}') + + required_entries = ['configuration.json', 'high_noise_model', 'low_noise_model'] + missing_required = [ + name for name in required_entries + if not os.path.exists(os.path.join(save_path, name)) + ] + if missing_required: + raise RuntimeError( + 'Wan2.2 saved structure is incomplete. Missing required entries: ' + f'{missing_required}. 
save_path={save_path}' + ) + + if isinstance(source_path, str) and os.path.isdir(source_path): + source_entries = set(os.listdir(source_path)) + source_non_expert_entries = sorted( + name for name in source_entries + if name not in {'high_noise_model', 'low_noise_model'} + ) + missing_non_expert = [ + name for name in source_non_expert_entries + if not os.path.exists(os.path.join(save_path, name)) + ] + if missing_non_expert: + raise RuntimeError( + 'Wan2.2 saved structure lost original non-expert files/directories: ' + f'{missing_non_expert}. source_path={source_path}, save_path={save_path}' + ) + + logger.info( + f'Wan2.2 native save structure verified. ' + f'top-level entries={sorted(os.listdir(save_path))}' + ) + @torch.no_grad() def save_model(self, path): if int(os.environ['RANK']) != 0: @@ -1023,6 +1132,58 @@ def save_model(self, path): self.model.avlm_model.save_pretrained(path) logger.info('save model done --') self.copy_tokenizer(path) + elif self.config.model.type in ['Wan2T2V']: + if getattr(self.model.Pipeline, '_is_wan_official', False): + src = getattr(self.model, 'pipeline_model_path', self.model.model_path) + self._copy_wan22_native_checkpoint(src, path) + + self.model.Pipeline.transformer.save_pretrained( + os.path.join(path, 'high_noise_model') + ) + logger.info('save Wan2.2 high_noise_model done --') + if ( + hasattr(self.model.Pipeline, 'transformer_2') + and self.model.Pipeline.transformer_2 is not None + ): + self.model.Pipeline.transformer_2.save_pretrained( + os.path.join(path, 'low_noise_model') + ) + logger.info('save Wan2.2 low_noise_model done --') + self._validate_wan22_native_save_structure(path, source_path=src) + return + + # Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.) + # so that non-quantized components are preserved. 
def _get_hif4_quant_cy():
    """Lazily import ``QType`` and ``quant_dequant_float`` from HiFloat4/hif4_gpu.

    The HiFloat4 kernel package lives at ``<repo_root>/HiFloat4/hif4_gpu`` and is
    built out-of-tree, so its directory is added to ``sys.path`` on first use
    instead of being an install-time dependency.

    Returns:
        tuple: ``(QType, quant_dequant_float)`` from the ``quant_cy`` extension.

    Raises:
        ImportError: if the HiFloat4 package is missing or not built.
    """
    repo_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    )
    hif4_gpu_dir = os.path.join(repo_root, 'HiFloat4', 'hif4_gpu')
    if hif4_gpu_dir not in sys.path:
        sys.path.insert(0, hif4_gpu_dir)
    try:
        from quant_cy import QType, quant_dequant_float
        return QType, quant_dequant_float
    except Exception as e:
        raise ImportError(
            'HiFloat4 4-bit quantization requires the HiFloat4/hif4_gpu package. '
            'Ensure HiFloat4 is available at repo_root/HiFloat4/hif4_gpu and built.'
        ) from e


class HiFloat4Quantizer(BaseQuantizer):
    """4-bit HiFloat (hif4) simulation quantizer using HiFloat4 quant_dequant_float.

    Uses the HiFloat4 library's quant_dequant_float for block-wise float 4-bit
    quantization. No scales/zeros; quantization is done per block along the last
    dim. Only supports fake (simulation) quantization; real weight packing is not
    implemented.
    """

    def __init__(self, bit=4, symmetric=None, granularity=None, **kwargs):
        super().__init__(bit, symmetric, granularity, **kwargs)
        self.quant_type = 'hif4'
        # Block dimension handed to quant_dequant_float (default: last dim).
        self.q_dim = kwargs.get('hif4_qdim', -1)
        self.force_py = kwargs.get('force_py', False)
        self.force_fp32 = kwargs.get('force_fp32', True)
        self._QType = None
        self._quant_dequant_float = None

    def _ensure_hif4(self):
        # Import on first use so configs that never select hif4 do not need
        # the HiFloat4 package at all.
        if self._quant_dequant_float is None:
            self._QType, self._quant_dequant_float = _get_hif4_quant_cy()

    def _quant_dequant(self, tensor):
        """Shared fake-quant path: hif4 quant-dequant round trip, dtype preserved."""
        self._ensure_hif4()
        org_dtype = tensor.dtype
        qtype = self._QType('hifx4').dim(self.q_dim)
        out = self._quant_dequant_float(
            tensor, qtype, force_py=self.force_py, force_fp32=self.force_fp32
        )
        return out.to(org_dtype)

    # hif4 has no per-tensor scales/zeros, so the static/dynamic and
    # weight/activation entry points all reduce to the same block-wise
    # quant-dequant. They previously carried four byte-identical bodies and
    # mutable `{}` default arguments; both issues are fixed here. `args` is
    # accepted for API compatibility but never read.
    def fake_quant_act_static(self, act, args=None):
        return self._quant_dequant(act)

    def fake_quant_act_dynamic(self, act, args=None):
        return self._quant_dequant(act)

    def fake_quant_weight_static(self, weight, args=None):
        return self._quant_dequant(weight)

    def fake_quant_weight_dynamic(self, weight, args=None):
        return self._quant_dequant(weight)

    def real_quant_weight_static(self, weight, args=None):
        raise NotImplementedError(
            'HiFloat4 quantizer is simulation-only (fake quant). '
            'real_quant_weight is not supported for hif4.'
        )

    def real_quant_weight_dynamic(self, weight, args=None):
        raise NotImplementedError(
            'HiFloat4 quantizer is simulation-only (fake quant). '
            'real_quant_weight is not supported for hif4.'
        )

    def __repr__(self):
        return (
            f'HiFloat4Quantizer(quant_type=hif4, q_dim={self.q_dim}, '
            f'force_py={self.force_py}, force_fp32={self.force_fp32})'
        )
'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, os.path.join(self.output_video_path, f'{eval_pos}_output_{i}.mp4'), @@ -77,15 +81,18 @@ def i2v_eval(self, model, testenc, bs, eval_pos): for i, data in enumerate(testenc): image, width, height = self.pre_process(model, data['image']) - output = model.Pipeline( - image=image, - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=height, - width=width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'image': image, + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': height, + 'width': width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, @@ -98,9 +105,9 @@ def i2v_eval(self, model, testenc, bs, eval_pos): @torch.no_grad() def eval_func(self, model, testenc, bs, eval_pos): assert bs == 1, 'Evaluation only supports batch size = 1.' - assert self.model_type in ['WanT2V', 'WanI2V'], ( + assert self.model_type in ['WanT2V', 'WanI2V', 'Wan2T2V'], ( f"Unsupported model type '{self.model_type}'.\n" - 'Only Wan2.1 video generation models (WanT2V, WanI2V) are supported.' + 'Only Wan video generation models (WanT2V, WanI2V, Wan2T2V) are supported.' 
class WanOfficialPipelineAdapter:
    """Adapter that exposes Wan-Video/Wan2.2 official t2v runtime as a Pipeline-like interface."""

    def __init__(
        self,
        runner,
        sample_solver='unipc',
        sampling_steps=40,
        sample_shift=12.0,
        offload_model=True,
    ):
        self.runner = runner
        # Keep the same expert naming semantics as existing LLMC Wan2.2 flow:
        # transformer -> high-noise expert, transformer_2 -> low-noise expert.
        self.transformer = runner.high_noise_model
        self.transformer_2 = runner.low_noise_model
        self.sample_solver = sample_solver
        self.sampling_steps = sampling_steps
        self.sample_shift = sample_shift
        self.offload_model = offload_model
        self._is_wan_official = True

    @staticmethod
    def _tensor_to_frames(video):
        """Convert an official-runtime video tensor into a list of HWC uint8 frames.

        ``None`` yields an empty list; non-tensor inputs pass through untouched.
        """
        if video is None:
            return []
        if not torch.is_tensor(video):
            return video

        clip = video.detach().cpu()
        if clip.dim() != 4:
            raise ValueError(f'Unexpected official Wan video shape: {tuple(clip.shape)}')

        # Accept [C, F, H, W] and convert to [F, C, H, W].
        if clip.shape[0] in (1, 3):
            clip = clip.permute(1, 0, 2, 3)

        if clip.dtype.is_floating_point:
            # [-1, 1] outputs are rescaled to [0, 1] before byte conversion.
            if clip.min().item() < 0:
                clip = (clip.clamp(-1, 1) + 1.0) / 2.0
            else:
                clip = clip.clamp(0, 1)
            clip = (clip * 255).round().to(torch.uint8)
        elif clip.dtype != torch.uint8:
            clip = clip.to(torch.uint8)

        return [frame.permute(1, 2, 0).contiguous().numpy() for frame in clip]

    def to(self, device):  # noqa: ARG002
        # Match the diffusers pipeline API; the official runner moves its own models.
        return self

    def __call__(
        self,
        prompt,
        negative_prompt='',
        height=480,
        width=832,
        num_frames=81,
        guidance_scale=5.0,
        guidance_scale_2=None,
        **kwargs,
    ):
        """Run official Wan2.2 t2v generation, wrapping frames like a diffusers output."""
        if isinstance(prompt, (list, tuple)):
            prompt = prompt[0]
        if isinstance(negative_prompt, (list, tuple)):
            negative_prompt = negative_prompt[0]

        # Official Wan2.2 guide_scale order: (low_noise, high_noise).
        low_scale = guidance_scale_2 if guidance_scale_2 is not None else guidance_scale
        high_scale = guidance_scale

        steps = kwargs.get(
            'num_inference_steps', kwargs.get('sampling_steps', self.sampling_steps)
        )
        shift = kwargs.get('sample_shift', self.sample_shift)
        solver = kwargs.get('sample_solver', self.sample_solver)
        seed = kwargs.get('seed', -1)
        offload = kwargs.get('offload_model', self.offload_model)

        video = self.runner.generate(
            input_prompt=prompt,
            size=(width, height),
            frame_num=num_frames,
            shift=shift,
            sample_solver=solver,
            sampling_steps=steps,
            guide_scale=(low_scale, high_scale),
            n_prompt=negative_prompt if negative_prompt is not None else '',
            seed=seed,
            offload_model=offload,
        )
        return SimpleNamespace(frames=[self._tensor_to_frames(video)])
def _has_diffusers_layout(model_path):
    """True iff ``model_path`` is a local diffusers pipeline directory."""
    if not isinstance(model_path, str):
        return False
    return (
        os.path.isdir(model_path)
        and os.path.isfile(os.path.join(model_path, 'model_index.json'))
        and os.path.isdir(os.path.join(model_path, 'transformer'))
        and os.path.isdir(os.path.join(model_path, 'vae'))
    )


def _has_wan22_native_layout(model_path):
    """True iff ``model_path`` looks like a native Wan2.2 checkpoint directory."""
    if not isinstance(model_path, str):
        return False
    return (
        os.path.isdir(model_path)
        and os.path.isfile(os.path.join(model_path, 'configuration.json'))
        and os.path.isdir(os.path.join(model_path, 'high_noise_model'))
        and os.path.isdir(os.path.join(model_path, 'low_noise_model'))
    )


def _is_wan22_native_repo_id(model_path):
    """True iff ``model_path`` is the official native Wan2.2 t2v repo id."""
    if not isinstance(model_path, str):
        return False
    return model_path.rstrip('/\\') == 'Wan-AI/Wan2.2-T2V-A14B'


def _should_require_official_backend(self, normalized_model_path):
    """Decide whether the official Wan runtime is mandatory (no diffusers fallback).

    Any explicit diffusers opt-in (force_diffusers, diffusers_path,
    allow_diffusers_fallback) disables the requirement; otherwise a native
    checkpoint layout or the native repo id makes the official backend mandatory.
    """
    model_cfg = self.config.model
    opted_out = (
        model_cfg.get('force_diffusers', False)
        or model_cfg.get('diffusers_path', None)
        or model_cfg.get('allow_diffusers_fallback', False)
    )
    if opted_out:
        return False
    return (
        self._has_wan22_native_layout(normalized_model_path)
        or self._is_wan22_native_repo_id(normalized_model_path)
    )
' + 'Diffusers fallback depends on model.allow_diffusers_fallback/model.force_diffusers. ' + f'import_error={e}' + ) + return None, None + + def _try_build_official_wan_pipeline(self): + normalized_model_path = self._normalize_hf_repo_path(self.model_path) + if not self._has_wan22_native_layout(normalized_model_path): + return False + if self.config.model.get('force_diffusers', False): + logger.info('force_diffusers=True, skip official Wan2.2 import backend.') + return False + + t2v_A14B, WanOfficialT2V = self._import_official_wan() + if t2v_A14B is None or WanOfficialT2V is None: + return False + + wan_config = copy.deepcopy(t2v_A14B) + # Keep official defaults unless explicitly overridden by llmc config. + if self.config.model.get('sample_steps', None) is not None: + wan_config.sample_steps = self.config.model.sample_steps + if self.config.model.get('sample_shift', None) is not None: + wan_config.sample_shift = self.config.model.sample_shift + if self.config.model.get('boundary', None) is not None: + wan_config.boundary = self.config.model.boundary + + runner = WanOfficialT2V( + config=wan_config, + checkpoint_dir=normalized_model_path, + device_id=int(os.environ.get('LOCAL_RANK', 0)), + rank=int(os.environ.get('RANK', 0)), + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=self.config.model.get('t5_cpu', False), + init_on_cpu=self.config.model.get('init_on_cpu', True), + convert_model_dtype=self.config.model.get('convert_model_dtype', False), + ) + self.Pipeline = WanOfficialPipelineAdapter( + runner=runner, + sample_solver=self.config.model.get('sample_solver', 'unipc'), + sampling_steps=self.config.model.get( + 'sampling_steps', getattr(wan_config, 'sample_steps', 40) + ), + sample_shift=self.config.model.get( + 'sample_shift', getattr(wan_config, 'sample_shift', 12.0) + ), + offload_model=self.config.model.get('offload_model', True), + ) + self.pipeline_model_path = normalized_model_path + self.pipeline_source = 'wan_official' + self.use_official_wan 
= True + logger.info( + f'Loaded Wan2.2 via official Wan runtime from native checkpoint: {normalized_model_path}' + ) + return True + + def _resolve_pipeline_model_path(self): + explicit_diffusers_path = self.config.model.get('diffusers_path', None) + if explicit_diffusers_path is not None: + resolved_path = self._normalize_hf_repo_path(explicit_diffusers_path) + logger.info(f'Use explicit Wan2.2 diffusers_path: {resolved_path}') + return resolved_path + + raw_model_path = self.model_path + normalized_path = self._normalize_hf_repo_path(raw_model_path) + + if normalized_path != raw_model_path: + logger.info( + f'Normalize Wan2.2 model path from URL to repo id: {normalized_path}' + ) + + if self._has_diffusers_layout(normalized_path): + return normalized_path + + if self._has_wan22_native_layout(normalized_path): + local_diffusers_candidate = normalized_path.rstrip('/\\') + '-Diffusers' + if self._has_diffusers_layout(local_diffusers_candidate): + logger.info( + 'Detected native Wan2.2 checkpoint. ' + f'Use local diffusers directory: {local_diffusers_candidate}' + ) + return local_diffusers_candidate + logger.warning( + 'Detected native Wan2.2 checkpoint layout ' + f'({normalized_path}) but no local diffusers export found. ' + 'Fallback to official diffusers repo: Wan-AI/Wan2.2-T2V-A14B-Diffusers. ' + 'You can set model.diffusers_path to override this behavior.' 
+ ) + return 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' + + if normalized_path.rstrip('/\\').endswith('Wan2.2-T2V-A14B'): + mapped_path = normalized_path.rstrip('/\\') + '-Diffusers' + logger.info( + f'Map Wan2.2 native repo/path to diffusers pipeline source: {mapped_path}' + ) + return mapped_path + + return normalized_path + + def build_model(self): + self.use_official_wan = False + normalized_model_path = self._normalize_hf_repo_path(self.model_path) + require_official_backend = self._should_require_official_backend(normalized_model_path) + + if self._try_build_official_wan_pipeline(): + self.find_llmc_model() + self.find_blocks() + logger.info( + 'Wan2.2 MoE official backend loaded: blocks=%s(+%s)', + len(self.Pipeline.transformer.blocks), + ( + len(self.Pipeline.transformer_2.blocks) + if hasattr(self.Pipeline, 'transformer_2') + and self.Pipeline.transformer_2 is not None + else 0 + ), + ) + logger.info('Model: %s', self.model) + return + + if require_official_backend: + raise RuntimeError( + 'Detected Wan2.2 native source ' + f'({normalized_model_path}) but official Wan runtime is unavailable. ' + 'Please install/prepare official Wan2.2 code (pip install -e /path/to/Wan2.2 ' + 'or set model.wan2_repo_path). ' + 'If you intentionally want Diffusers fallback, set ' + 'model.allow_diffusers_fallback=True or model.force_diffusers=True.' + ) + + self.pipeline_model_path = self._resolve_pipeline_model_path() + vae = AutoencoderKLWan.from_pretrained( + self.pipeline_model_path, + subfolder='vae', + torch_dtype=torch.float32, + use_safetensors=True, + ) + # Wan2.2: one pipeline, two transformer experts (transformer + transformer_2). + # Pipeline switches by SNR; both use WanTransformer3DModel with same block layout as Wan2.1. 
def find_llmc_model(self):
    """Expose the high-noise expert as the primary LLMC model handle."""
    self.model = self.Pipeline.transformer


def find_blocks(self):
    """Collect transformer blocks from both experts into ``self.blocks``.

    ``num_transformer_blocks`` counts only the first expert, so block indices
    >= that count belong to ``transformer_2``.
    """
    self.blocks = list(self.Pipeline.transformer.blocks)
    self.num_transformer_blocks = len(self.blocks)
    if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
        self.blocks += list(self.Pipeline.transformer_2.blocks)


def _expert_name_from_block_idx(self, block_idx):
    # Blocks [0, num_transformer_blocks) live in the high-noise expert.
    if block_idx < self.num_transformer_blocks:
        return 'transformer'
    return 'transformer_2'


def get_blockwise_input(self, block_idx, fallback_input):
    """Return captured first-block inputs for the expert owning ``block_idx``."""
    if not hasattr(self, 'blockwise_inputs'):
        return fallback_input
    return self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)]


def set_blockwise_input(self, block_idx, block_input):
    """Store ``block_input`` as the captured input of the owning expert."""
    if not hasattr(self, 'blockwise_inputs'):
        return
    self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)] = block_input


def get_catcher(self, first_block_input):
    """Build a Catcher class that records first-block args/kwargs during calib.

    The returned class wraps a block, appends each call's first positional arg
    and remaining args (folded into kwargs via the wrapped forward's signature)
    into ``first_block_input``, and aborts the pipeline with ``ValueError``
    once ``self.sample_steps`` calls have been captured.
    """
    sample_steps = self.sample_steps

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.signature = inspect.signature(module.forward)
            self.step = 0

        def forward(self, *args, **kwargs):
            params = list(self.signature.parameters.keys())
            capture_kwargs = dict(kwargs)
            # Fold extra positional args into kwargs by parameter name so the
            # capture is invariant to how the caller passed them.
            for i, arg in enumerate(args):
                if i > 0:
                    capture_kwargs[params[i]] = arg
            first_block_input['data'].append(args[0])
            first_block_input['kwargs'].append(capture_kwargs)
            self.step += 1
            if self.step == sample_steps:
                # Abort the (expensive) remainder of the pipeline run.
                raise ValueError
            else:
                return self.module(*args, **kwargs)

    return Catcher


@torch.no_grad()
def collect_first_block_input(self, calib_data, padding_mask=None):
    """Run calibration prompts and capture first-block inputs for both experts.

    Captured tensors are moved to CPU to save GPU memory; the pipeline run is
    aborted via ``ValueError`` once both experts have ``sample_steps`` samples.
    """
    first_block_input = {
        'transformer': defaultdict(list),
        'transformer_2': defaultdict(list),
    }
    sample_steps = self.sample_steps

    class Catcher(nn.Module):
        def __init__(self, module, expert_name):
            super().__init__()
            self.module = module
            self.signature = inspect.signature(module.forward)
            self.expert_name = expert_name

        def _to_cpu(self, x):
            if torch.is_tensor(x):
                return x.detach().cpu()
            if isinstance(x, tuple):
                return tuple(self._to_cpu(t) for t in x)
            return x

        def forward(self, *args, **kwargs):
            params = list(self.signature.parameters.keys())
            capture_kwargs = dict(kwargs)
            for i, arg in enumerate(args):
                if i > 0:
                    capture_kwargs[params[i]] = arg
            cur_num = len(first_block_input[self.expert_name]['data'])
            if cur_num < sample_steps:
                first_block_input[self.expert_name]['data'].append(
                    args[0].detach().cpu() if torch.is_tensor(args[0]) else args[0]
                )
                first_block_input[self.expert_name]['kwargs'].append(
                    {k: self._to_cpu(v) for k, v in capture_kwargs.items()}
                )
            if all(
                len(first_block_input[name]['data']) >= sample_steps
                for name in first_block_input
            ):
                raise ValueError
            return self.module(*args, **kwargs)

    first_block = self.Pipeline.transformer.blocks[0]
    self.Pipeline.transformer.blocks[0] = Catcher(first_block, 'transformer')
    first_block_2 = None
    if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
        first_block_2 = self.Pipeline.transformer_2.blocks[0]
        self.Pipeline.transformer_2.blocks[0] = Catcher(first_block_2, 'transformer_2')

    self.Pipeline.to('cuda')
    for data in calib_data:
        try:
            pipe_kw = {
                'prompt': data['prompt'],
                'negative_prompt': data['negative_prompt'],
                'height': self.target_height,
                'width': self.target_width,
                'num_frames': self.num_frames,
                'guidance_scale': self.guidance_scale,
            }
            if hasattr(self, 'guidance_scale_2'):
                pipe_kw['guidance_scale_2'] = self.guidance_scale_2
            self.Pipeline(**pipe_kw)
        except ValueError:
            # Expected: Catcher aborts once enough samples are captured.
            pass
        gc.collect()
        torch.cuda.empty_cache()

    # Unwrap the catchers and park the pipeline back on CPU.
    self.Pipeline.transformer.blocks[0] = self.Pipeline.transformer.blocks[0].module
    if first_block_2 is not None:
        self.Pipeline.transformer_2.blocks[0] = self.Pipeline.transformer_2.blocks[0].module
    self.Pipeline.to('cpu')

    assert len(first_block_input['transformer']['data']) > 0, 'Catch transformer input data failed.'
    if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
        assert len(first_block_input['transformer_2']['data']) > 0, \
            'Catch transformer_2 input data failed.'

    self.blockwise_inputs = first_block_input
    self.first_block_input = self.blockwise_inputs['transformer']
    self.n_samples = sum(len(v['data']) for v in self.blockwise_inputs.values())
    # FIX: loguru formats messages with str.format-style `{}` placeholders, not
    # printf-style `%s`; the previous call printed literal `%s` and dropped args.
    logger.info(
        'Retrieved Wan2.2 calibration samples: transformer={}, transformer_2={}.',
        len(self.blockwise_inputs['transformer']['data']),
        len(self.blockwise_inputs['transformer_2']['data']),
    )
def get_padding_mask(self):
    """Wan2.2 calibration does not use an attention padding mask."""
    return None


def has_bias(self):
    """Linear layers in Wan transformer blocks carry bias terms."""
    return True


def __str__(self):
    return '\nWan2.2 MoE Model:\n%s\nTotal params: ~27B (14B active per step)' % (
        str(self.model),
    )


def get_layernorms_in_block(self, block):
    """Map layernorm names to modules for diffusers-wrapped or native blocks."""
    if hasattr(block, 'affine_norm1'):
        # Diffusers-wrapped LlmcWanTransformerBlock layout.
        return {
            'affine_norm1': block.affine_norm1,
            'norm2': block.norm2,
            'affine_norm3': block.affine_norm3,
        }
    # Official native Wan2.2 block layout.
    return {
        'norm1': block.norm1,
        'norm3': block.norm3,
        'norm2': block.norm2,
    }


def get_subsets_in_block(self, block):
    """Describe quantizable linear subsets of ``block`` for AWQ-style passes."""
    if not hasattr(block, 'attn1'):
        # Official Wan2.2 native block layout:
        # self_attn/qkv/o, cross_attn/qkv/o, ffn[0|2], modulation.
        self_attn_subset = {
            'layers': {
                'self_attn.q': block.self_attn.q,
                'self_attn.k': block.self_attn.k,
                'self_attn.v': block.self_attn.v,
            },
            # Official Wan2.2 uses non-affine norm1/norm2 by default.
            # Skip trans-based scale folding to avoid invalid ln.weight operations.
            'prev_op': [None],
            'input': ['self_attn.q'],
            'inspect': block.self_attn,
            'has_kwargs': True,
            'do_trans': False,
            'sub_keys': {
                'seq_lens': 'seq_lens',
                'grid_sizes': 'grid_sizes',
                'freqs': 'freqs',
            },
        }
        cross_attn_subset = {
            'layers': {
                'cross_attn.q': block.cross_attn.q,
            },
            'prev_op': [None],
            'input': ['cross_attn.q'],
            'inspect': block.cross_attn,
            'has_kwargs': True,
            'do_trans': False,
            'sub_keys': {
                'context': 'context',
                'context_lens': 'context_lens',
            },
        }
        ffn_subset = {
            'layers': {
                'ffn.0': block.ffn[0],
            },
            'prev_op': [None],
            'input': ['ffn.0'],
            'inspect': block.ffn,
            'has_kwargs': False,
            'do_trans': False,
        }
        return [self_attn_subset, cross_attn_subset, ffn_subset]

    # Diffusers-wrapped block layout (attn1/attn2/ffn.net).
    attn1_subset = {
        'layers': {
            'attn1.to_q': block.attn1.to_q,
            'attn1.to_k': block.attn1.to_k,
            'attn1.to_v': block.attn1.to_v,
        },
        'prev_op': [block.affine_norm1],
        'input': ['attn1.to_q'],
        'inspect': block.attn1,
        'has_kwargs': True,
        'sub_keys': {'rotary_emb': 'rotary_emb'},
    }
    attn2_subset = {
        'layers': {
            'attn2.to_q': block.attn2.to_q,
        },
        'prev_op': [block.norm2],
        'input': ['attn2.to_q'],
        'inspect': block.attn2,
        'has_kwargs': True,
        'sub_keys': {'encoder_hidden_states': 'encoder_hidden_states'},
    }
    ffn_subset = {
        'layers': {
            'ffn.net.0.proj': block.ffn.net[0].proj,
        },
        'prev_op': [block.affine_norm3],
        'input': ['ffn.net.0.proj'],
        'inspect': block.ffn,
        'has_kwargs': True,
    }
    return [attn1_subset, attn2_subset, ffn_subset]


def find_embed_layers(self):
    pass


def get_embed_layers(self):
    pass


def get_layers_except_blocks(self):
    pass


def skip_layer_name(self):
    pass
self.Pipeline = WanPipeline.from_pretrained( + # self.model_path, vae=vae, torch_dtype=torch.bfloat16 + # ) self.Pipeline = WanPipeline.from_pretrained( - self.model_path, vae=vae, torch_dtype=torch.bfloat16 + self.model_path, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True ) self.find_llmc_model() self.find_blocks() @@ -61,16 +64,17 @@ def __init__(self, module): def forward(self, *args, **kwargs): params = list(self.signature.parameters.keys()) + capture_kwargs = dict(kwargs) for i, arg in enumerate(args): if i > 0: - kwargs[params[i]] = arg + capture_kwargs[params[i]] = arg first_block_input['data'].append(args[0]) - first_block_input['kwargs'].append(kwargs) + first_block_input['kwargs'].append(capture_kwargs) self.step += 1 if self.step == sample_steps: raise ValueError else: - return self.module(*args) + return self.module(*args, **kwargs) return Catcher diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 5869fa8d0..8fd082be7 100755 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -6,6 +6,7 @@ loguru transformers>=4.45.2 lmms-eval==0.3.0 huggingface-hub +safetensors sentencepiece protobuf accelerate>=0.26.0 diff --git a/scripts/run_llmc.sh b/scripts/run_llmc.sh index d90877f69..24d92689e 100755 --- a/scripts/run_llmc.sh +++ b/scripts/run_llmc.sh @@ -1,17 +1,26 @@ -#!/bin/bash - -# export CUDA_VISIBLE_DEVICES=0,1 - -llmc=/path/to/llmc +export PATH=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin:$PATH +export PYTHON=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/python +export PIP=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/pip +export HF_ENDPOINT=https://hf-mirror.com +cd /mnt/lm_data_afs/wangzining/charles/lab/llmc +# hif4 kernel +cd HiFloat4/hif4_gpu/ +bash build.sh +cd - + +# model_name=wan_t2v +model_name=wan2_2_t2v +task_name=awq_w_a_skip_first +# task_name=awq_w_a_s +log_name=${model_name}_${task_name} +rm -rf ./save_for_fake/${model_name}/awq_w_a/skip_first + 
def _find_index_file(model_dir: str) -> str:
    """Locate the shard index json inside *model_dir*, trying known names."""
    candidates = [
        "diffusion_pytorch_model.safetensors.index.json",
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    ]
    for name in candidates:
        candidate_path = os.path.join(model_dir, name)
        if os.path.isfile(candidate_path):
            return candidate_path
    raise FileNotFoundError(
        f"Cannot find an index json in {model_dir}. Tried: {', '.join(candidates)}"
    )


def _iter_safetensors_index(index_path: str):
    """Yield ``(shard_filename, [tensor_keys])`` pairs from an index json."""
    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    if "weight_map" not in index:
        raise ValueError(f"Index file missing 'weight_map': {index_path}")

    # Group tensor keys by the shard file that stores them.
    shard_to_keys = defaultdict(list)
    for key, shard_rel in index["weight_map"].items():
        shard_to_keys[shard_rel].append(key)

    yield from shard_to_keys.items()


def main():
    """Print tensor names/shapes/dtypes of a (possibly sharded) HF checkpoint."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo",
        type=str,
        default="charles2530/Wan2.2-T2V-A14B-Diffusion-AWQ-INT4",
        help="Hugging Face repo id, e.g. charles2530/Wan2.2-T2V-A14B-Diffusion-AWQ-INT4",
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default=None,
        help="If provided, read model files from this local directory instead of downloading.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="main",
        help="HF revision (branch/tag/commit). Default: main",
    )
    parser.add_argument(
        "--download",
        action="store_true",
        help="Force download snapshot (ignored if --local_dir is set).",
    )
    parser.add_argument(
        "--max_keys",
        type=int,
        default=200,
        help="Max number of parameter keys to print (across all shards). Default: 200",
    )
    parser.add_argument(
        "--print_values",
        action="store_true",
        help="Also print tensor repr (VERY large output). Default: off",
    )
    args = parser.parse_args()

    print(f"huggingface-hub : {version('huggingface-hub')}")
    print(f"safetensors : {version('safetensors')}")

    if args.local_dir is not None:
        model_dir = args.local_dir
    else:
        # Without --download, only the local HF cache is consulted.
        model_dir = snapshot_download(
            repo_id=args.repo,
            revision=args.revision,
            local_files_only=not args.download,
        )

    index_path = _find_index_file(model_dir)
    print(f"model_dir : {model_dir}")
    print(f"index : {index_path}")

    printed = 0
    for shard_rel, tensor_keys in _iter_safetensors_index(index_path):
        shard_path = os.path.join(model_dir, shard_rel)
        if not os.path.isfile(shard_path):
            raise FileNotFoundError(
                f"Shard not found: {shard_path}\n"
                "Tip: re-run with --download to fetch all shards."
            )

        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for k in tensor_keys:
                t = f.get_tensor(k)
                print(f"{k} shape={tuple(t.shape)} dtype={t.dtype}")
                if args.print_values:
                    print(t)
                printed += 1
                if args.max_keys is not None and printed >= args.max_keys:
                    print(f"Reached --max_keys={args.max_keys}, stopping.")
                    return


if __name__ == "__main__":
    main()