diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 18629b11..65079cc7 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -21,13 +21,14 @@ def __init__( mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16', ddp_config: Dict[str, Any] = None, fsdp_config: Dict[str, Any] = None, + memory_efficient: bool = True, ): from accelerate import Accelerator self.device_mesh = device_mesh self.mixed_precision = mixed_precision parallelism_config = self._parallelism_config_from_device_mesh(device_mesh) - fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config) + fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient) kwargs_handlers = [] if ddp_config is not None: @@ -69,7 +70,8 @@ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh): return parallelism_config - def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any]): + def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any], + memory_efficient: bool): from accelerate import FullyShardedDataParallelPlugin from torch.distributed.fsdp import BackwardPrefetch from torch.distributed.fsdp import ShardingStrategy as FSDPShardingStrategy @@ -107,11 +109,9 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di activation_checkpointing=fsdp_config.pop('activation_checkpointing', False), auto_wrap_policy=fsdp_config.pop('auto_wrap_policy', 'transformer_based_wrap'), # noqa reshard_after_forward=fsdp_config.pop('reshard_after_forward', True), + cpu_ram_efficient_loading=fsdp_config.pop('cpu_ram_efficient_loading', memory_efficient), **fsdp_config, ) - # Enable memory efficient model loading in transformers(see `is_fsdp_enabled` in transformers) - # os.environ['ACCELERATE_USE_FSDP'] = '1' - # os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = '1' return fsdp_plugin def wrap_model(self, model, *args): diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index ce938eef..b35a7a94 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -42,12 +42,27 @@ def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[ ) return ep_mesh.to_torch_device_mesh() - def wrap_model(self, model, optimizer=None): + def wrap_model(self, model, optimizer=None, memory_efficient=True): if self.device_mesh is None: return model, optimizer fsdp_mesh = _build_fsdp_mesh(self.device_mesh) if fsdp_mesh is not None: ep_enabled = (self.enable_ep and self.ep_fsdp_device_mesh is not None) + + # EP path requires experts on a real device, incompatible with meta-device flow. 
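+ # When taken, this path snapshots the full state dict and the non-persistent buffers,
+ # drops the optimizer's parameter references, and moves the model to the meta device;
+ # real weights are re-materialised from rank 0 after fully_shard
+ # (see _broadcast_sharded_state_dict / _restore_non_persistent_buffers below).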
+ use_meta = memory_efficient and not ep_enabled + + original_sd = None + saved_buffers = None + if use_meta: + original_sd = model.state_dict() + saved_buffers = _get_non_persistent_buffers(model) + if optimizer is not None: + _unbind_optimizer_params(optimizer) + model = model.to(torch.device('meta')) + if hasattr(model, 'tie_weights'): + model.tie_weights() + if ep_enabled: _ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh) _place_ep_experts_on_local_device(model, self.ep_fsdp_device_mesh) @@ -57,11 +72,9 @@ def wrap_model(self, model, optimizer=None): if ep_enabled: _ensure_ep_fsdp_supported(model) - # Collect experts map and expert params experts_map = _collect_ep_experts_map(model) if ep_enabled else {} expert_params = _collect_expert_params(model) if self.enable_ep else None - # Build layer_pairs: [(layer_mod, experts_mod_or_None)] layers = _get_decoder_layers(model) layer_pairs = [] if layers is not None: @@ -69,7 +82,6 @@ def wrap_model(self, model, optimizer=None): experts_mod = _find_experts_in_layer(layer_mod, experts_map) layer_pairs.append((layer_mod, experts_mod)) - # FSDP2 wrapping per layer world_size = self.device_mesh.world_size ep_fsdp_mesh_1d = self.ep_fsdp_device_mesh['ep_fsdp'] if ep_enabled else None @@ -79,9 +91,6 @@ def wrap_model(self, model, optimizer=None): if experts_mod is not None and ep_fsdp_mesh_1d is not None: from torch.distributed.tensor import Shard - # PreMulSum (used by set_gradient_divide_factor) only supports - # float16/float32/float64; override reduce_dtype to float32 - # when the base policy uses bfloat16. ep_mp_policy = _build_ep_mp_policy(mp_policy) fully_shard( experts_mod, @@ -90,7 +99,6 @@ def wrap_model(self, model, optimizer=None): mp_policy=ep_mp_policy, shard_placement_fn=lambda param: Shard(1), ) - # gradient_divide_factor = world_size experts_mod.set_gradient_divide_factor(world_size) layer_mod._fsdp_modules.append(experts_mod) @@ -103,7 +111,6 @@ def wrap_model(self, model, optimizer=None): ) layer_mod._fsdp_modules.append(layer_mod) - # Root model fully_shard( model, mesh=fsdp_mesh, @@ -112,11 +119,22 @@ def wrap_model(self, model, optimizer=None): ignored_params=expert_params, ) - # Manual prefetch + if use_meta: + device_type = self.device_mesh.device_type or 'cuda' + is_rank0 = (dist.get_rank() == 0) + _broadcast_sharded_state_dict( + model, + original_sd if is_rank0 else {}, + device_type=device_type, + ) + target_device = torch.device(device_type) + _restore_non_persistent_buffers(model, saved_buffers, device=target_device) + if hasattr(model, 'tie_weights'): + model.tie_weights() + if ep_enabled and layer_pairs: _setup_manual_prefetch([lp[0] for lp in layer_pairs]) - # Rebuild groups after wrapping so grad clip sees the live Parameter objects. 
if ep_enabled: _rebuild_ep_param_groups(model) @@ -398,3 +416,74 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor return optimizer optimizer.param_groups[0]['params'] = list(model.parameters()) return optimizer + + +def _broadcast_sharded_state_dict( + model: nn.Module, + full_sd: dict, + device_type: str = 'cuda', +) -> None: + """Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor.""" + from torch.distributed.tensor import DTensor, distribute_tensor + + meta_sharded_sd = model.state_dict() + sharded_sd = {} + is_rank0 = (dist.get_rank() == 0) + + for param_name, sharded_param in meta_sharded_sd.items(): + shape = sharded_param.size() + dtype = sharded_param.dtype + + if is_rank0: + full_param = full_sd[param_name] + full_tensor = full_param.detach().to(device_type) + if isinstance(full_tensor, DTensor): + full_tensor = full_tensor.to_local() + else: + full_tensor = torch.empty(shape, device=device_type, dtype=dtype) + + dist.broadcast(full_tensor, src=0) + torch_util.synchronize() + + device_mesh = sharded_param.device_mesh + placements = sharded_param.placements + sharded_tensor = distribute_tensor(full_tensor, device_mesh, placements) + + sharded_sd[param_name] = sharded_tensor + + model.load_state_dict(sharded_sd, assign=True) + + +def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]: + """Return {fqn: tensor} for non-persistent buffers (lost on to('meta')).""" + non_persistent_fqns: Set[str] = set() + for fqn, module in model.named_modules(): + for buf_name in getattr(module, '_non_persistent_buffers_set', set()): + full_fqn = f'{fqn}.{buf_name}' if fqn else buf_name + non_persistent_fqns.add(full_fqn) + + return {k: v.clone() for k, v in model.named_buffers() if k in non_persistent_fqns} + + +def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None: + """Drop optimizer param refs so model.to('meta') can free memory.""" + for group in optimizer.param_groups: + for i in range(len(group['params'])): + group['params'][i] = torch.empty(1) + + +def _restore_non_persistent_buffers( + model: nn.Module, + saved_buffers: Dict[str, torch.Tensor], + device: torch.device, +) -> None: + """Re-register non-persistent buffers saved before to('meta').""" + for fqn, buf_tensor in saved_buffers.items(): + buf_tensor = buf_tensor.to(device) + if '.' 
in fqn: + parent_fqn, local_name = fqn.rsplit('.', 1) + parent = model.get_submodule(parent_fqn) + else: + local_name = fqn + parent = model + parent.register_buffer(local_name, buf_tensor, persistent=False) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index a35d0211..65c3c1ba 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -183,6 +183,7 @@ def __init__( ddp_config: Dict[str, Any] = None, fsdp_config: Dict[str, Any] = None, grad_scaler_config: Dict[str, Any] = None, + memory_efficient_init: bool = True, **kwargs): os.environ['TOKENIZERS_PARALLELISM'] = 'true' self._try_init_process_group() @@ -195,6 +196,7 @@ def __init__( self.mixed_precision = mixed_precision self._fsdp_config = dict(fsdp_config or {}) self._ddp_config = ddp_config or {} + self._memory_efficient_init = memory_efficient_init self._decide_strategy(strategy) self.grad_scaler_config = grad_scaler_config if isinstance(model_cls, str): @@ -203,8 +205,23 @@ def __init__( self.model = model_cls.from_config(config, **kwargs) else: model_id = HubOperation.download_model(model_id) - self.model = model_cls.from_pretrained(model_id, config=config, **kwargs) - # Construct sequence-parallel strategy lazily during wrapping to reduce init-time side effects. + # Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load. + use_efficient_loading = (memory_efficient_init and self.device_mesh is not None) + _saved_env = {} + if use_efficient_loading: + _saved_env['ACCELERATE_USE_FSDP'] = os.environ.get('ACCELERATE_USE_FSDP') + _saved_env['FSDP_CPU_RAM_EFFICIENT_LOADING'] = os.environ.get('FSDP_CPU_RAM_EFFICIENT_LOADING') + os.environ['ACCELERATE_USE_FSDP'] = 'true' + os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = 'true' + try: + self.model = model_cls.from_pretrained(model_id, config=config, **kwargs) + finally: + if use_efficient_loading: + for key, old_val in _saved_env.items(): + if old_val is None: + os.environ.pop(key, None) + else: + os.environ[key] = old_val self.model.gradient_checkpointing_enable() self.sp_strategy = None self._model_wrapped = False @@ -237,7 +254,8 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): mixed_precision=self.mixed_precision, ddp_config=self._ddp_config, fsdp_config=self._fsdp_config, - device_mesh=self.device_mesh) + device_mesh=self.device_mesh, + memory_efficient=self._memory_efficient_init) # Sequence parallel ("ulysses") is derived from dp/fsdp ranks; it does not change world size. # We construct `sp_strategy` after the underlying HF model is initialized (see __init__). 
@@ -284,16 +302,25 @@ def _lazy_wrap_model(self): self._ensure_sp_strategy() if self.sp_strategy is not None: self.sp_strategy.initialize() + + extra_kwargs = {} + if isinstance(self.strategy, NativeFSDPStrategy): + extra_kwargs['memory_efficient'] = getattr(self, '_memory_efficient_init', True) + if len(optimizer_groups) == 1: optimizer_group = optimizer_groups[0] optimizer = optimizer_group.optimizer assert optimizer is not None - self.model, optimizer = self.strategy.wrap_model(self.model, optimizer) + self.model, optimizer = self.strategy.wrap_model(self.model, optimizer, **extra_kwargs) optimizer_group.optimizer = optimizer self.register_mm_forward_hook(optimizer_group) else: # maybe forward_only, no optimizer_group available - self.model = self.strategy.wrap_model(self.model) + result = self.strategy.wrap_model(self.model, **extra_kwargs) + if isinstance(result, tuple): + self.model = result[0] + else: + self.model = result self._model_wrapped = True def register_mm_forward_hook(self, optimizer_group: OptimizerGroup): diff --git a/tests/strategy/__init__.py b/tests/strategy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/strategy/test_fsdp2_memory_efficient_init.py b/tests/strategy/test_fsdp2_memory_efficient_init.py new file mode 100644 index 00000000..6fe0ca51 --- /dev/null +++ b/tests/strategy/test_fsdp2_memory_efficient_init.py @@ -0,0 +1,599 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import numpy as np +import os +import socket +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import unittest +from torch.distributed.fsdp import fully_shard + +from twinkle.utils import Platform + +_PLATFORM = Platform.get_platform() +_DEVICE_TYPE = _PLATFORM.device_prefix() # 'cuda' or 'npu' +_DIST_BACKEND = _PLATFORM.device_backend() # 'nccl' or 'hccl' + + +def _device_count() -> int: + if _DEVICE_TYPE == 'npu': + import torch_npu # noqa: F401 + return torch.npu.device_count() + return torch.cuda.device_count() + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + +def _init_dist(rank, world_size, port): + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + os.environ['RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + if _DEVICE_TYPE == 'npu': + from twinkle.utils.platforms.npu import ensure_hccl_socket_env + ensure_hccl_socket_env(port) + dist.init_process_group(_DIST_BACKEND, rank=rank, world_size=world_size) + device = torch.device(_PLATFORM.get_local_device(rank)) + torch.device(device) # set current device + if _DEVICE_TYPE == 'npu': + import torch_npu # noqa: F401 + torch.npu.set_device(device) + else: + torch.cuda.set_device(rank) + + +class TinyModel(nn.Module): + """2-layer MLP for testing. 
Small enough to fit on any GPU.""" + + def __init__(self, dim=32): + super().__init__() + self.layer1 = nn.Linear(dim, dim, bias=False) + self.layer2 = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + return self.layer2(self.layer1(x)) + + +class TinyDecoderLayer(nn.Module): + """Minimal decoder layer for per-layer sharding tests.""" + + def __init__(self, dim=32): + super().__init__() + self.fc1 = nn.Linear(dim, dim, bias=False) + self.fc2 = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + return self.fc2(self.fc1(x)) + + +class TinyTransformerModel(nn.Module): + """Model with model.model.layers structure matching _get_decoder_layers expectations.""" + + def __init__(self, dim=32, num_layers=2): + super().__init__() + self.model = nn.Module() + self.model.layers = nn.ModuleList([TinyDecoderLayer(dim) for _ in range(num_layers)]) + self.lm_head = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + for layer in self.model.layers: + x = layer(x) + return self.lm_head(x) + + +def _worker_broadcast_sharded(rank, world_size, port, ref_sd): + """Worker function: shard on meta, broadcast, verify values.""" + _init_dist(rank, world_size, port) + from twinkle.model.transformers.strategy.native_fsdp import _broadcast_sharded_state_dict + + model = TinyModel(dim=32) + + # Only rank 0 has the real weights; others get empty shell + if rank == 0: + full_sd = ref_sd # passed via mp args (shared memory) + else: + full_sd = {} + + # Save original full state dict for later verification + original_sd = ref_sd + + # Move to meta and shard + model = model.to('meta') + fully_shard(model.layer1) + fully_shard(model.layer2) + fully_shard(model) + + # Broadcast + _broadcast_sharded_state_dict(model, full_sd, device_type=_DEVICE_TYPE) + + # Verify: gather full state dict back and compare to original + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + gathered = get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + if rank == 0: + for key in original_sd: + assert torch.allclose(gathered[key], original_sd[key], atol=1e-6), \ + f"Mismatch on {key} after broadcast" + + dist.destroy_process_group() + + +@unittest.skipIf(_device_count() < 2, f'Need >= 2 {_DEVICE_TYPE.upper()}s') +class TestBroadcastShardedStateDict(unittest.TestCase): + + def test_broadcast_restores_weights(self): + port = _find_free_port() + world_size = 2 + # Create reference weights on CPU + ref_model = TinyModel(dim=32) + ref_sd = {k: v.clone() for k, v in ref_model.state_dict().items()} + mp.spawn( + _worker_broadcast_sharded, + args=(world_size, port, ref_sd), + nprocs=world_size, + join=True, + ) + + +# --------------------------------------------------------------------------- +# Task 2: _get_non_persistent_buffers +# --------------------------------------------------------------------------- + + +class ModelWithNonPersistentBuffer(nn.Module): + + def __init__(self, dim=32): + super().__init__() + self.linear = nn.Linear(dim, dim) + # Non-persistent buffer — will NOT appear in state_dict() + self.register_buffer('mask', torch.ones(dim), persistent=False) + + +class TestGetNonPersistentBuffers(unittest.TestCase): + + def test_finds_non_persistent_buffers(self): + from twinkle.model.transformers.strategy.native_fsdp import _get_non_persistent_buffers + model = ModelWithNonPersistentBuffer() + result = _get_non_persistent_buffers(model) + assert 'mask' in result + assert torch.equal(result['mask'], torch.ones(32)) + # 
Persistent params/buffers should NOT be in the result + assert 'linear.weight' not in result + + def test_empty_when_no_non_persistent(self): + from twinkle.model.transformers.strategy.native_fsdp import _get_non_persistent_buffers + model = TinyModel() + result = _get_non_persistent_buffers(model) + assert len(result) == 0 + + +# --------------------------------------------------------------------------- +# Task 3: _restore_non_persistent_buffers +# --------------------------------------------------------------------------- + + +class TestRestoreNonPersistentBuffers(unittest.TestCase): + + def test_restores_buffers_after_meta(self): + from twinkle.model.transformers.strategy.native_fsdp import (_get_non_persistent_buffers, + _restore_non_persistent_buffers) + model = ModelWithNonPersistentBuffer() + saved = _get_non_persistent_buffers(model) + # Move to meta — buffer becomes meta tensor + model = model.to('meta') + assert model.mask.device.type == 'meta' + # Restore + _restore_non_persistent_buffers(model, saved, device=torch.device('cpu')) + assert model.mask.device.type == 'cpu' + assert torch.equal(model.mask, torch.ones(32)) + + +# --------------------------------------------------------------------------- +# Task 4: wrap_model with memory_efficient=True +# --------------------------------------------------------------------------- + + +def _worker_wrap_model_memory_efficient(rank, world_size, port, ref_sd): + """Test that wrap_model with memory_efficient=True produces correct sharded model.""" + _init_dist(rank, world_size, port) + from twinkle.model.transformers.strategy.native_fsdp import NativeFSDPStrategy + from twinkle.utils import DeviceMesh as TwinkleMesh + + mesh = TwinkleMesh( + mesh=np.arange(world_size), + mesh_dim_names=('fsdp', ), + device_type=_DEVICE_TYPE, + ) + strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no') + + model = TinyModel(dim=32).to(_DEVICE_TYPE) + if rank == 0: + model.load_state_dict(ref_sd) + + model, _ = strategy.wrap_model(model, optimizer=None, memory_efficient=True) + + # Verify: model should be on device, not meta + for name, param in model.named_parameters(): + assert param.device.type == _DEVICE_TYPE, f"{name} still on {param.device}" + + # Verify: gathered full state dict matches original + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + gathered = get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + if rank == 0: + for key in ref_sd: + assert torch.allclose(gathered[key], ref_sd[key], atol=1e-6), \ + f"Mismatch on {key}" + + dist.destroy_process_group() + + +@unittest.skipIf(_device_count() < 2, f'Need >= 2 {_DEVICE_TYPE.upper()}s') +class TestWrapModelMemoryEfficient(unittest.TestCase): + + def test_wrap_model_memory_efficient(self): + port = _find_free_port() + world_size = 2 + ref_model = TinyModel(dim=32) + ref_sd = {k: v.clone() for k, v in ref_model.state_dict().items()} + mp.spawn( + _worker_wrap_model_memory_efficient, + args=(world_size, port, ref_sd), + nprocs=world_size, + join=True, + ) + + +# --------------------------------------------------------------------------- +# Task 5: wrap_model with memory_efficient=False (legacy path) +# --------------------------------------------------------------------------- + + +def _worker_wrap_model_legacy(rank, world_size, port, ref_sd): + """Test that wrap_model with memory_efficient=False still works (old path).""" + _init_dist(rank, world_size, port) + from 
twinkle.model.transformers.strategy.native_fsdp import NativeFSDPStrategy + from twinkle.utils import DeviceMesh as TwinkleMesh + + mesh = TwinkleMesh( + mesh=np.arange(world_size), + mesh_dim_names=('fsdp', ), + device_type=_DEVICE_TYPE, + ) + strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no') + + model = TinyModel(dim=32).to(_DEVICE_TYPE) + model.load_state_dict(ref_sd) + + model, _ = strategy.wrap_model(model, optimizer=None, memory_efficient=False) + + for name, param in model.named_parameters(): + assert param.device.type == _DEVICE_TYPE, f"{name} still on {param.device}" + + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + gathered = get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + if rank == 0: + for key in ref_sd: + assert torch.allclose(gathered[key], ref_sd[key], atol=1e-6), \ + f"Mismatch on {key}" + + dist.destroy_process_group() + + +@unittest.skipIf(_device_count() < 2, f'Need >= 2 {_DEVICE_TYPE.upper()}s') +class TestWrapModelLegacy(unittest.TestCase): + + def test_wrap_model_legacy_path(self): + port = _find_free_port() + world_size = 2 + ref_model = TinyModel(dim=32) + ref_sd = {k: v.clone() for k, v in ref_model.state_dict().items()} + mp.spawn( + _worker_wrap_model_legacy, + args=(world_size, port, ref_sd), + nprocs=world_size, + join=True, + ) + + +# --------------------------------------------------------------------------- +# Task 5b: wrap_model memory_efficient with per-layer sharding +# --------------------------------------------------------------------------- + + +def _worker_wrap_model_per_layer(rank, world_size, port, ref_sd): + """Test that wrap_model with memory_efficient=True correctly shards per decoder layer.""" + _init_dist(rank, world_size, port) + from twinkle.model.transformers.strategy.native_fsdp import NativeFSDPStrategy, _get_decoder_layers + from twinkle.utils import DeviceMesh as TwinkleMesh + + mesh = TwinkleMesh( + mesh=np.arange(world_size), + mesh_dim_names=('fsdp', ), + device_type=_DEVICE_TYPE, + ) + strategy = NativeFSDPStrategy(device_mesh=mesh, mixed_precision='no') + + model = TinyTransformerModel(dim=32, num_layers=2).to(_DEVICE_TYPE) + if rank == 0: + model.load_state_dict(ref_sd) + + # Verify the model has the expected structure before wrapping + assert _get_decoder_layers(model) is not None, \ + 'TinyTransformerModel should have model.model.layers' + + model, _ = strategy.wrap_model(model, optimizer=None, memory_efficient=True) + + # Verify: all parameters on device, not meta + for name, param in model.named_parameters(): + assert param.device.type == _DEVICE_TYPE, f"{name} still on {param.device}" + + # Verify: gathered full state dict matches original + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + gathered = get_model_state_dict( + model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + if rank == 0: + for key in ref_sd: + assert torch.allclose(gathered[key], ref_sd[key], atol=1e-6), \ + f"Mismatch on {key}" + + dist.destroy_process_group() + + +@unittest.skipIf(_device_count() < 2, f'Need >= 2 {_DEVICE_TYPE.upper()}s') +class TestWrapModelPerLayerSharding(unittest.TestCase): + + def test_wrap_model_per_layer_memory_efficient(self): + """Verify per-layer fully_shard works with the meta→broadcast path.""" + port = _find_free_port() + world_size = 2 + ref_model = TinyTransformerModel(dim=32, num_layers=2) + ref_sd = {k: v.clone() for k, v in 
ref_model.state_dict().items()} + mp.spawn( + _worker_wrap_model_per_layer, + args=(world_size, port, ref_sd), + nprocs=world_size, + join=True, + ) + + +# --------------------------------------------------------------------------- +# Task 6: env var / memory_efficient_init parameter in TransformersModel +# --------------------------------------------------------------------------- +class TestEnvVarRamEfficientLoading(unittest.TestCase): + """Test that __init__ sets FSDP env vars during from_pretrained.""" + + def test_env_vars_set_during_from_pretrained(self): + """Verify env vars are set when memory_efficient_init=True and restored after.""" + import inspect + from unittest.mock import MagicMock, patch + + from twinkle.model.transformers.transformers import TransformersModel + + # Verify the parameter exists in __init__ signature + sig = inspect.signature(TransformersModel.__init__) + assert 'memory_efficient_init' in sig.parameters, \ + 'memory_efficient_init parameter should exist in __init__' + + # Verify env vars are actually set during from_pretrained call + captured_env = {} + + # Build a fake model_cls whose from_pretrained captures env vars. + # We pass it explicitly as model_cls to avoid default-argument binding issues. + fake_model_cls = MagicMock() + + def fake_from_pretrained(*args, **kwargs): + captured_env['ACCELERATE_USE_FSDP'] = os.environ.get('ACCELERATE_USE_FSDP') + captured_env['FSDP_CPU_RAM_EFFICIENT_LOADING'] = os.environ.get('FSDP_CPU_RAM_EFFICIENT_LOADING') + mock_model = MagicMock() + mock_model.gradient_checkpointing_enable = MagicMock() + mock_model.named_parameters = MagicMock(return_value=iter([])) + return mock_model + + fake_model_cls.from_pretrained = fake_from_pretrained + + from twinkle.utils import DeviceMesh as TwinkleMesh + + mesh = TwinkleMesh( + mesh=np.arange(1), + mesh_dim_names=('fsdp', ), + device_type='cpu', + ) + + # Clean env before test + saved = {} + for key in ('ACCELERATE_USE_FSDP', 'FSDP_CPU_RAM_EFFICIENT_LOADING'): + saved[key] = os.environ.pop(key, None) + + try: + with patch('twinkle.model.transformers.transformers.HubOperation') as mock_hub, \ + patch.object(TransformersModel, '_try_init_process_group'), \ + patch.object(TransformersModel, '_decide_strategy'), \ + patch.object(TransformersModel, '_construct_default_optimizer_group', return_value=MagicMock()): + mock_hub.download_model.return_value = '/fake/path' + + # memory_efficient_init=True with device_mesh → env vars should be set + try: + TransformersModel( + model_cls=fake_model_cls, + model_id='/fake/model', + device_mesh=mesh, + memory_efficient_init=True, + ) + except Exception: + pass # We only care about the env var capture + + assert captured_env.get('ACCELERATE_USE_FSDP') == 'true', \ + f"ACCELERATE_USE_FSDP should be 'true' during from_pretrained, got " \ + f"{captured_env.get('ACCELERATE_USE_FSDP')}" + assert captured_env.get('FSDP_CPU_RAM_EFFICIENT_LOADING') == 'true', \ + f"FSDP_CPU_RAM_EFFICIENT_LOADING should be 'true' during " \ + f"from_pretrained, got {captured_env.get('FSDP_CPU_RAM_EFFICIENT_LOADING')}" + + # Verify env vars are restored after __init__ + assert os.environ.get('ACCELERATE_USE_FSDP') is None, \ + 'ACCELERATE_USE_FSDP should be restored (removed) after __init__' + assert os.environ.get('FSDP_CPU_RAM_EFFICIENT_LOADING') is None, \ + 'FSDP_CPU_RAM_EFFICIENT_LOADING should be restored (removed) after __init__' + finally: + # Restore original env + for key, val in saved.items(): + if val is None: + os.environ.pop(key, None) + else: + os.environ[key] = 
val + + def test_env_vars_not_set_when_disabled(self): + """Verify env vars are NOT set when memory_efficient_init=False.""" + from unittest.mock import MagicMock, patch + + from twinkle.model.transformers.transformers import TransformersModel + from twinkle.utils import DeviceMesh as TwinkleMesh + + captured_env = {} + + fake_model_cls = MagicMock() + + def fake_from_pretrained(*args, **kwargs): + captured_env['ACCELERATE_USE_FSDP'] = os.environ.get('ACCELERATE_USE_FSDP') + captured_env['FSDP_CPU_RAM_EFFICIENT_LOADING'] = os.environ.get('FSDP_CPU_RAM_EFFICIENT_LOADING') + mock_model = MagicMock() + mock_model.gradient_checkpointing_enable = MagicMock() + mock_model.named_parameters = MagicMock(return_value=iter([])) + return mock_model + + fake_model_cls.from_pretrained = fake_from_pretrained + + mesh = TwinkleMesh( + mesh=np.arange(1), + mesh_dim_names=('fsdp', ), + device_type='cpu', + ) + + saved = {} + for key in ('ACCELERATE_USE_FSDP', 'FSDP_CPU_RAM_EFFICIENT_LOADING'): + saved[key] = os.environ.pop(key, None) + + try: + with patch('twinkle.model.transformers.transformers.HubOperation') as mock_hub, \ + patch.object(TransformersModel, '_try_init_process_group'), \ + patch.object(TransformersModel, '_decide_strategy'), \ + patch.object(TransformersModel, '_construct_default_optimizer_group', return_value=MagicMock()): + mock_hub.download_model.return_value = '/fake/path' + + try: + TransformersModel( + model_cls=fake_model_cls, + model_id='/fake/model', + device_mesh=mesh, + memory_efficient_init=False, + ) + except Exception: + pass + + assert captured_env.get('ACCELERATE_USE_FSDP') is None, \ + 'ACCELERATE_USE_FSDP should NOT be set when memory_efficient_init=False' + assert captured_env.get('FSDP_CPU_RAM_EFFICIENT_LOADING') is None, \ + 'FSDP_CPU_RAM_EFFICIENT_LOADING should NOT be set when memory_efficient_init=False' + finally: + for key, val in saved.items(): + if val is None: + os.environ.pop(key, None) + else: + os.environ[key] = val + + +# --------------------------------------------------------------------------- +# Task 8: End-to-end integration test +# --------------------------------------------------------------------------- + + +def _worker_e2e_memory_efficient(rank, world_size, port, model_path): + """End-to-end: init → set_optimizer → trigger _lazy_wrap_model with memory_efficient.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + + import twinkle + from twinkle.model import TransformersModel + from twinkle.utils import DeviceMesh as TwinkleMesh + + mesh = TwinkleMesh( + mesh=np.arange(world_size), + mesh_dim_names=('fsdp', ), + device_type=_DEVICE_TYPE, + ) + + # Must initialize twinkle before creating TransformersModel, otherwise + # the @remote_class decorator will strip the device_mesh parameter. + twinkle.initialize(mode='local', global_device_mesh=mesh) + + model = TransformersModel( + model_id=model_path, + device_mesh=mesh, + strategy='native_fsdp', + mixed_precision='bf16', + memory_efficient_init=True, + ) + + model.set_optimizer('AdamW', lr=1e-4) + + # Trigger _lazy_wrap_model by calling the internal method directly. 
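+ # (this path passes memory_efficient=True through to NativeFSDPStrategy.wrap_model)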
+ model._lazy_wrap_model() + + # Verify: model should be wrapped and parameters on device + assert model._model_wrapped, 'Model should be wrapped after _lazy_wrap_model' + for name, param in model.model.named_parameters(): + assert param.device.type == _DEVICE_TYPE, f"{name} on {param.device}, expected {_DEVICE_TYPE}" + + # Verify: gathered full state dict matches (weights were broadcast correctly) + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + gathered = get_model_state_dict( + model.model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + # full_state_dict=True gathers shards to rank 0 only; other ranks get {}. + if rank == 0: + assert len(gathered) > 0, 'Should have gathered state dict' + + if dist.is_initialized(): + dist.destroy_process_group() + + +@unittest.skipIf(_device_count() < 2, f'Need >= 2 {_DEVICE_TYPE.upper()}s') +class TestE2EMemoryEfficientInit(unittest.TestCase): + + def test_e2e_wrap_model(self): + """Full pipeline test: TransformersModel init → set_optimizer → _lazy_wrap_model.""" + model_id = os.environ.get('TEST_SMALL_MODEL_ID') + if not model_id: + self.skipTest('Set TEST_SMALL_MODEL_ID env var to a small HF model path') + + port = _find_free_port() + world_size = 2 + mp.spawn( + _worker_e2e_memory_efficient, + args=(world_size, port, model_id), + nprocs=world_size, + join=True, + ) + + +if __name__ == '__main__': + unittest.main()
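A minimal usage sketch of memory_efficient_init with the native_fsdp strategy, mirroring the e2e test above; the checkpoint id and device count below are placeholders:

import numpy as np

import twinkle
from twinkle.model import TransformersModel
from twinkle.utils import DeviceMesh

# One 'fsdp' mesh dimension spanning the local devices. With memory_efficient_init=True
# (the default), rank 0 loads the checkpoint while the other ranks build the model on the
# meta device and receive their shards via broadcast when the strategy wraps the model.
mesh = DeviceMesh(mesh=np.arange(8), mesh_dim_names=('fsdp', ), device_type='cuda')
twinkle.initialize(mode='local', global_device_mesh=mesh)

model = TransformersModel(
    model_id='Qwen/Qwen2.5-0.5B',   # placeholder checkpoint
    device_mesh=mesh,
    strategy='native_fsdp',
    mixed_precision='bf16',
    memory_efficient_init=True,     # pass False to keep the legacy full-load path
)
model.set_optimizer('AdamW', lr=1e-4)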