Fix mm server #125
Merged · +230 −40
10 commits by tastelikefeet:

913f0f3  wip
1f7a23b  wip
76c0770  fix
772659c  fix
86969f4  lint code
e88e620  fix
c31c651  fix
06e9b2e  fix
a07fc14  fix
0d5da72  bump model to qwen3.5-4b
    # Twinkle Client - Transformers LoRA Training Example
    #
    # This script demonstrates how to fine-tune a language model using LoRA
    # (Low-Rank Adaptation) through the Twinkle client-server architecture.
    # The server must be running first (see server.py and server_config.yaml).

    # Step 1: Load environment variables from a .env file (e.g., API tokens)
    import dotenv

    import os
    from twinkle.data_format import Trajectory, Message
    from twinkle.preprocessor import Preprocessor

    dotenv.load_dotenv('.env')

    import numpy as np
    import torch
    from peft import LoraConfig

    from twinkle import get_logger
    from twinkle.dataset import DatasetMeta
    from twinkle_client import init_twinkle_client
    from twinkle.dataloader import DataLoader
    from twinkle.dataset import LazyDataset
    from twinkle_client.model import MultiLoraTransformersModel

    logger = get_logger()

    base_model = 'Qwen/Qwen3.5-4B'
    base_url = 'http://www.modelscope.cn/twinkle'

    # Step 2: Initialize the Twinkle client to communicate with the remote server.
    # - base_url: the address of the running Twinkle server
    # - api_key: authentication token (loaded from an environment variable)
    client = init_twinkle_client(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN'))

    # Step 3: Query the server for existing training runs and their checkpoints.
    # This is useful for resuming a previous training session.
    runs = client.list_training_runs()

    resume_path = None
    for run in runs:
        logger.info(run.model_dump_json(indent=2))
        # List all saved checkpoints for this training run
        checkpoints = client.list_checkpoints(run.training_run_id)

        for checkpoint in checkpoints:
            logger.info(checkpoint.model_dump_json(indent=2))
            # Uncomment the line below to resume from a specific checkpoint:
            # resume_path = checkpoint.twinkle_path


    class LatexOCRProcessor(Preprocessor):

        def __call__(self, rows):
            rows = self.map_col_to_row(rows)
            rows = [self.preprocess(row) for row in rows]
            rows = self.map_row_to_col(rows)
            return rows

        def preprocess(self, row) -> Trajectory:
            return Trajectory(
                messages=[
                    Message(role='user', content='<image>Using LaTeX to perform OCR on the image.', images=[row['image']]),
                    Message(role='assistant', content=row['text']),
                ]
            )


    def train():
        # Step 4: Prepare the dataset

        # Load the LaTeX-OCR dataset from ModelScope
        dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(500)))

        # Apply a chat template so the data matches the model's expected input format
        dataset.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=512)

        # Convert raw rows into chat-format trajectories
        # (user image prompt + assistant LaTeX answer)
        dataset.map(LatexOCRProcessor)

        # Tokenize and encode the dataset into model-ready input features
        dataset.encode(batched=True)

        # Wrap the dataset in a DataLoader that yields batches of size 4
        dataloader = DataLoader(dataset=dataset, batch_size=4)

        # Step 5: Configure the model

        # Create a multi-LoRA Transformers model pointing to the base model on ModelScope
        model = MultiLoraTransformersModel(model_id=f'ms://{base_model}')

        # Define the LoRA configuration: apply low-rank adapters to all linear layers
        lora_config = LoraConfig(target_modules='all-linear')

        # Attach the LoRA adapter named 'default' to the model.
        # gradient_accumulation_steps=2 means gradients are accumulated over 2 micro-batches
        # before an optimizer step, effectively doubling the batch size.
        model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)

        # Set the same chat template used during data preprocessing
        model.set_template('Qwen3_5Template')

        # Set the input processor (pads sequences on the right side)
        model.set_processor('InputProcessor', padding_side='right')

        # Use cross-entropy loss for language modeling
        model.set_loss('CrossEntropyLoss')

        # Use the Adam optimizer with a learning rate of 1e-4
        # (only the Adam optimizer is supported when the server uses Megatron)
        model.set_optimizer('Adam', lr=1e-4)

        # Use a linear learning rate scheduler
        # (LR schedulers are not supported when the server uses Megatron)
        # model.set_lr_scheduler('LinearLR')

        # Step 6: Optionally resume from a previous checkpoint
        if resume_path:
            logger.info(f'Resuming training from {resume_path}')
            model.load(resume_path, load_optimizer=True)

        # Step 7: Run the training loop
        logger.info(model.get_train_configs().model_dump())

        for epoch in range(3):
            logger.info(f'Starting epoch {epoch}')
            for step, batch in enumerate(dataloader):
                # Convert array-valued fields to plain lists so the batch
                # can be serialized and sent to the server
                for sample in batch:
                    for key in sample:
                        if isinstance(sample[key], np.ndarray):
                            sample[key] = sample[key].tolist()
                        elif isinstance(sample[key], torch.Tensor):
                            sample[key] = sample[key].cpu().numpy().tolist()
                # Forward pass + backward pass (computes gradients)
                model.forward_backward(inputs=batch)

                # Clip gradients and perform one optimizer step
                model.clip_grad_and_step()
                # Equivalent to the following sequence:
                # # Clip gradients to prevent exploding gradients (max norm = 1.0)
                # model.clip_grad_norm(1.0)
                # # Perform one optimizer step (update model weights)
                # model.step()
                # # Reset gradients to zero for the next iteration
                # model.zero_grad()
                # # Advance the learning rate scheduler by one step
                # model.lr_step()

                # Log the loss every 2 steps (aligned with gradient accumulation)
                if step % 2 == 0:
                    metric = model.calculate_metric(is_training=True)
                    logger.info(f'Step {step} of {len(dataloader)}, metric: {metric.result}')

            # Step 8: Save the trained checkpoint
            twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True)
            logger.info(f'Saved checkpoint: {twinkle_path}')

            # Step 9: Upload the checkpoint to the ModelScope Hub
            # YOUR_USER_NAME = "your_username"
            # hub_model_id = f'{YOUR_USER_NAME}/twinkle-multi-modal'
            # model.upload_to_hub(
            #     checkpoint_dir=twinkle_path,
            #     hub_model_id=hub_model_id,
            #     async_upload=False
            # )
            # logger.info(f"Uploaded checkpoint to hub: {hub_model_id}")


    if __name__ == '__main__':
        train()
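The LoRA configuration in Step 5 attaches rank-r adapters to all linear layers instead of fine-tuning the full weights. A minimal back-of-the-envelope sketch of why that is cheap, using plain Python with hypothetical layer shapes (not part of the Twinkle or PEFT API):

```python
# For one linear layer of shape (d_out, d_in), full fine-tuning trains the
# whole weight matrix, while a rank-r LoRA adapter trains only two small
# matrices: A of shape (r, d_in) and B of shape (d_out, r).
def full_params(d_out: int, d_in: int) -> int:
    return d_out * d_in

def lora_params(d_out: int, d_in: int, r: int) -> int:
    return r * d_in + d_out * r

# Hypothetical 2048 x 2048 projection with rank 8:
full = full_params(2048, 2048)       # 4,194,304 trainable weights
lora = lora_params(2048, 2048, r=8)  # 32,768 trainable weights
print(f'trainable fraction: {lora / full:.4%}')
```

With these assumed shapes the adapter trains well under 1% of the layer's parameters, which is what makes per-adapter multi-LoRA serving practical on one base model.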
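The inner loop of Step 7 converts NumPy arrays and tensors in each sample to plain Python lists before sending the batch to the server. That conversion can be sketched as a standalone helper; this is an illustrative rewrite (the helper name and the NumPy-only handling are assumptions, with the tensor branch analogous via `.cpu().numpy().tolist()`):

```python
import numpy as np

def to_serializable(batch):
    """Replace array-valued fields in each sample with plain Python lists,
    so the batch can be JSON-serialized for the client-server call."""
    for sample in batch:
        for key, value in sample.items():
            if isinstance(value, np.ndarray):
                sample[key] = value.tolist()
            # A torch.Tensor would be handled the same way via
            # value.cpu().numpy().tolist() when torch is available.
    return batch

# Hypothetical mini-batch with one encoded sample
batch = [{'input_ids': np.array([1, 2, 3]), 'text': 'x^2'}]
batch = to_serializable(batch)
print(type(batch[0]['input_ids']))  # plain list, ready for JSON transport
```

Keeping this conversion client-side means the wire format stays framework-agnostic: the server never needs to deserialize NumPy or Torch objects.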