Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions app.log
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
2026-03-24 14:58:30,184 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
2026-03-24 15:01:29,011 - my_logger - INFO - Start training round 1
2026-03-24 22:44:03,514 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
2026-03-24 22:45:14,428 - my_logger - INFO - Start training round 1
2026-03-25 09:13:23,571 - my_logger - INFO - Round 1 complete. Waiting for READY.
2026-03-25 09:13:29,550 - my_logger - INFO - Round 1 fully complete.
2026-03-25 09:13:32,331 - my_logger - INFO - Start training round 2
2026-03-25 19:44:00,127 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
2026-03-25 19:45:03,466 - my_logger - INFO - Start training round 1
2026-03-25 19:47:07,712 - my_logger - INFO - Start training round 1
2026-03-25 19:47:19,011 - my_logger - INFO - Start training round 1
2026-03-25 19:49:41,646 - my_logger - INFO - Start training round 1
2026-03-25 19:52:49,376 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
2026-03-25 19:54:50,514 - my_logger - INFO - Start training round 1
2026-03-25 19:56:16,092 - my_logger - INFO - Start training round 1
2026-03-25 19:57:18,773 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
2026-03-25 19:57:49,168 - my_logger - INFO - Start training round 1
2026-03-25 20:05:50,761 - my_logger - INFO - Application start. Server waiting for [1, 1] clients.
2026-03-25 20:06:20,483 - my_logger - INFO - Start training round 1
2026-03-26 04:53:12,204 - my_logger - INFO - Round 1 complete. Waiting for READY.
2026-03-26 04:53:17,911 - my_logger - INFO - Round 1 fully complete.
2026-03-26 04:53:20,860 - my_logger - INFO - Start training round 2
50 changes: 50 additions & 0 deletions client - Copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Client entry point for the split-learning framework (duplicate of client.py).

Parses CLI arguments, loads RabbitMQ settings from config.yaml, selects a
torch device, opens a blocking AMQP connection, and registers this client
with the server via RpcClient.
"""
import pika
import uuid
import argparse
import yaml
import os

import torch

import src.Log
from src.RpcClient import RpcClient

# CLI: --layer_id is the 1-based index of the model partition this client owns;
# --device optionally forces a torch device string (e.g. "cuda", "cpu").
parser = argparse.ArgumentParser(description="Split learning framework")
parser.add_argument('--layer_id', type=int, required=True, help='ID of layer, start from 1')
parser.add_argument('--device', type=str, required=False, help='Device of client')

args = parser.parse_args()

# RabbitMQ connection settings are read from config.yaml in the working directory.
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

client_id = uuid.uuid4()  # unique identifier for this client instance (uuid.UUID object)
address = config["rabbit"]["address"]
username = config["rabbit"]["username"]
password = config["rabbit"]["password"]
virtual_host = config["rabbit"]["virtual-host"]

# Device selection: honor an explicit --device; otherwise auto-detect CUDA.
device = None
if args.device is None:
    if torch.cuda.is_available():
        device = "cuda"
        print(f"Using device: {torch.cuda.get_device_name(device)}")
    else:
        device = "cpu"
        print(f"Using device: CPU")
else:
    device = args.device
    print(f"Using device: {device}")

# NOTE(review): the AMQP connection is opened at import time (before the
# __main__ guard) and with pika's default heartbeat — long training runs may
# hit a heartbeat timeout; client.py in this same change adds heartbeat=0.
credentials = pika.PlainCredentials(username, password)
connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials))
channel = connection.channel()

if __name__ == "__main__":
    src.Log.print_with_color("[>>>] Client sending registration message to server...", "red")

    # Register with the server; wait_response() then blocks for instructions.
    data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id,"message": "Hello from Client!"}
    client = RpcClient(client_id, args.layer_id, channel, device)
    client.send_to_server(data)
    client.wait_response()

18 changes: 13 additions & 5 deletions client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

args = parser.parse_args()

with open('config.yaml', 'r') as file:
with open("config.yaml", "r", encoding="utf-8") as file:
config = yaml.safe_load(file)

client_id = uuid.uuid4()
address = config["rabbit"]["address"]
username = config["rabbit"]["username"]
Expand All @@ -37,7 +36,17 @@
print(f"Using device: {device}")

credentials = pika.PlainCredentials(username, password)
connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials))
# FIX: heartbeat=0 disables the heartbeat timeout, avoiding StreamLostError during long training runs
connection = pika.BlockingConnection(
pika.ConnectionParameters(
host=address,
port=5672,
virtual_host=f'{virtual_host}',
credentials=credentials,
heartbeat=0,
blocked_connection_timeout=None,
)
)
channel = connection.channel()

if __name__ == "__main__":
Expand All @@ -46,5 +55,4 @@
data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id,"message": "Hello from Client!"}
client = RpcClient(client_id, args.layer_id, channel, device)
client.send_to_server(data)
client.wait_response()

client.wait_response()
37 changes: 25 additions & 12 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
name: SplitFedLLM

server:
global-round: 1
global-round: 10
clients:
- 1
- 1
cut-layers: 4
model-name: Bert # GPT2/Llama/Bert
data-name: EMOTION # EMOTION/GSM8K
model-name: GPT2 # GPT2 / Llama / Bert
data-name: E2E # E2E / EMOTION / AG_NEWS / GSM8K
pretrained_path: GPT2.pt
model:
GPT2:
n_block: 12
Expand All @@ -17,15 +19,15 @@ server:
parameters:
load: True
save: True
validation: True
validation: False
data-distribution:
non-iid: False
num-sample: 500
num-label: 10
num-sample: 10000
num-label: 1
dirichlet:
alpha: 1
refresh-each-round: True
random-seed: 1
random-seed: 42

rabbit:
address: 127.0.0.1
Expand All @@ -34,18 +36,29 @@ rabbit:
virtual-host: /

log_path: .
debug_mode: True
debug_mode: False

learning:
learning-rate: 0.00001
learning-rate: 0.00005
weight-decay: 0.01
batch-size: 2
control-count: 1
clip-grad-norm: 0.0
batch-size: 8
control-count: 2
clip-grad-norm: 1.0

fine-tune:
enable: True
name: LoRA
LoRA:
r: 8
alpha: 16
QLoRA:
r: 8
alpha: 16
bits: 4
double_quant: True

optimization:
flash_attention: False
precision: fp32
quantize_hidden: False
gradient_checkpointing: True
23 changes: 23 additions & 0 deletions convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Convert a Hugging Face GPT-2 checkpoint into this project's GPT2.pt format.

Loads a fine-tuned GPT2LMHeadModel, strips the "transformer." prefix from
state-dict keys so they match the project's model layout, and saves the
remapped weights to GPT2.pt.
"""
import torch
from transformers import GPT2LMHeadModel

# load the fine-tuned HF model from a local directory
hf_model = GPT2LMHeadModel.from_pretrained("./gpt2_e2e_finetuned")

hf_sd = hf_model.state_dict()
new_sd = {}

for k, v in hf_sd.items():
    new_k = k

    # 🔥 remove the "transformer." prefix so keys match the project's model
    # NOTE(review): .replace() removes every occurrence, not only a leading
    # prefix — presumably fine for GPT-2 key names, but verify no key contains
    # "transformer." mid-string.
    if k.startswith("transformer."):
        new_k = k.replace("transformer.", "")

    # 🔥 lm_head keys are kept unchanged
    new_sd[new_k] = v

# save the remapped state dict in the format expected by the training server
torch.save(new_sd, "GPT2.pt")

print("Converted → GPT2.pt")
15 changes: 13 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
torch
torch>=2.0.0
pika~=1.3.2
transformers>=4.36.2
datasets
peft
peft>=0.9.0
numpy
nltk
rouge_score

# ── Optimizations ─────────────────────────────────────────────
# QLoRA + INT8: quantize the base model down to 4-bit / 8-bit
bitsandbytes>=0.43.0

# Flash Attention 2: requires an Ampere+ GPU (RTX 30xx, A100, H100, ...)
# Install manually if needed: pip install flash-attn --no-build-isolation
flash-attn>=2.5.0; sys_platform != "win32"

# Accelerate: used together with bitsandbytes & gradient checkpointing
accelerate>=0.27.0
2 changes: 1 addition & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

args = parser.parse_args()

with open('config.yaml') as file:
with open("config.yaml", "r", encoding="utf-8") as file:
config = yaml.safe_load(file)
address = config["rabbit"]["address"]
username = config["rabbit"]["username"]
Expand Down
Loading