-
Notifications
You must be signed in to change notification settings - Fork 309
[WIP] [tx] Add deterministic mode #1125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,6 +125,12 @@ def force_exit(): | |
| app = FastAPI(title="Tinker API Mock", version="0.0.1", lifespan=lifespan) | ||
|
|
||
|
|
||
| def _is_jax_deterministic_mode(request: Request) -> bool: | ||
| """Return true when deterministic mode is enabled for the JAX backend.""" | ||
| cfg = request.app.state.engine_config | ||
| return cfg.backend == "jax" and bool(cfg.backend_config.get("deterministic", False)) | ||
|
|
||
|
|
||
| async def get_session(request: Request) -> AsyncGenerator[AsyncSession, None]: | ||
| """Dependency to get a database session.""" | ||
| async with AsyncSession(request.app.state.db_engine) as session: | ||
|
|
@@ -365,14 +371,16 @@ class SamplingParams(BaseModel): | |
| top_k: int = -1 | ||
| top_p: float = 1 | ||
|
|
||
| def to_types(self) -> types.SamplingParams: | ||
| def to_types(self, deterministic: bool = False) -> types.SamplingParams: | ||
| if self.max_tokens is None: | ||
| raise HTTPException(status_code=400, detail="max_tokens is currently required") | ||
| if self.max_tokens <= 0: | ||
| raise HTTPException(status_code=400, detail="max_tokens must be a positive number") | ||
|
|
||
| # Generate a random seed if not provided | ||
| seed = self.seed if self.seed is not None else random.randint(0, 2**31 - 1) | ||
| # In deterministic mode, default to a fixed seed instead of random fallback. | ||
| seed = self.seed if self.seed is not None else ( | ||
| 0 if deterministic else random.randint(0, 2**31 - 1) | ||
| ) | ||
|
|
||
| # Determine if stop values are token IDs (int) or strings | ||
| stop_tokens = None | ||
|
|
@@ -600,7 +608,7 @@ async def create_sampling_session(request: CreateSamplingSessionRequest, session | |
|
|
||
|
|
||
| @app.post("/api/v1/create_model", response_model=CreateModelResponse) | ||
| async def create_model(request: CreateModelRequest, session: AsyncSession = Depends(get_session)): | ||
| async def create_model(request: CreateModelRequest, req: Request, session: AsyncSession = Depends(get_session)): | ||
| """Create a new model, optionally with a LoRA adapter.""" | ||
| # Validate session exists | ||
| session_db = await session.get(SessionDB, request.session_id) | ||
|
|
@@ -610,8 +618,13 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe | |
| model_id = f"model_{uuid4().hex[:8]}" | ||
|
|
||
| # alpha = 32 seems to be the tinker default (see https://thinkingmachines.ai/blog/lora/) | ||
| # Generate a random seed if not provided | ||
| seed = request.lora_config.seed if request.lora_config.seed is not None else random.randint(0, 2**31 - 1) | ||
| # In deterministic mode, default to a fixed seed instead of random fallback. | ||
| deterministic = _is_jax_deterministic_mode(req) | ||
| seed = ( | ||
| request.lora_config.seed | ||
| if request.lora_config.seed is not None | ||
| else (0 if deterministic else random.randint(0, 2**31 - 1)) | ||
| ) | ||
|
Comment on lines
+623
to
+627
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic for determining the seed based on the To avoid code duplication and improve maintainability, consider extracting this logic into a shared helper function, for example: def _get_seed(provided_seed: int | None, deterministic: bool) -> int:
"""Return a seed, falling back to a fixed or random seed based on deterministic mode."""
if provided_seed is not None:
return provided_seed
return 0 if deterministic else random.randint(0, 2**31 - 1)You could then call this helper function here and in seed = _get_seed(request.lora_config.seed, deterministic) |
||
| lora_config = types.LoraConfig(rank=request.lora_config.rank, alpha=32.0, seed=seed) | ||
| request_id = await create_future( | ||
| session=session, | ||
|
|
@@ -635,7 +648,7 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe | |
| return CreateModelResponse( | ||
| model_id=model_id, | ||
| base_model=request.base_model, | ||
| lora_config=request.lora_config, | ||
| lora_config=LoRAConfig(rank=request.lora_config.rank, seed=seed), | ||
| status="created", | ||
| request_id=str(request_id), | ||
| ) | ||
|
|
@@ -890,6 +903,8 @@ async def asample(request: SampleRequest, req: Request, session: AsyncSession = | |
| # Validate that the checkpoint exists and is ready | ||
| await validate_checkpoint(req, model_id, checkpoint_id, types.CheckpointType.SAMPLER, session) | ||
|
|
||
| deterministic = _is_jax_deterministic_mode(req) | ||
|
|
||
| request_id = await create_future( | ||
| session=session, | ||
| request_type=( | ||
|
|
@@ -899,7 +914,7 @@ async def asample(request: SampleRequest, req: Request, session: AsyncSession = | |
| request_data=types.SampleInput( | ||
| base_model=base_model, | ||
| prompt=request.prompt.to_types(), | ||
| sampling_params=request.sampling_params.to_types(), | ||
| sampling_params=request.sampling_params.to_types(deterministic=deterministic), | ||
| num_samples=request.num_samples, | ||
| checkpoint_id=checkpoint_id, | ||
| prompt_logprobs=request.prompt_logprobs if request.prompt_logprobs is not None else False, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's significant code duplication here for running a training step and fetching the resulting checkpoint. The logic in lines 190-211 is repeated with minor differences in lines 198-215.
To improve readability and maintainability, consider extracting this logic into a helper function. This function could take the
training_client,rest_client, and anamefor the checkpoint, and return the checkpoint bytes. This would make the test's intent clearer and reduce redundancy.