Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 211 additions & 18 deletions code-rs/core/src/codex/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,43 @@ fn truncate_payload(text: &str, limit: usize) -> String {
}
}

fn trimmed_non_empty(text: &str) -> Option<String> {
let trimmed = text.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
}

fn join_text_chunks(chunks: Vec<String>) -> Option<String> {
if chunks.is_empty() {
None
} else {
Some(chunks.join("\n\n"))
}
}

#[derive(Debug, Clone)]
pub(super) struct ProjectHookCommandResult {
pub stdout: String,
pub stderr: String,
pub exit_code: Option<i32>,
}

#[derive(Debug, Default, Clone)]
pub(super) struct UserPromptSubmitHookOutcome {
pub blocked: bool,
pub block_reason: Option<String>,
pub additional_contexts: Vec<String>,
}

#[derive(Debug, Default, Clone)]
pub(super) struct StopHookOutcome {
pub blocked: bool,
pub continuation_prompt: Option<String>,
}

fn build_exec_hook_payload(
event: ProjectHookEvent,
ctx: &ExecCommandContext,
Expand Down Expand Up @@ -688,7 +725,16 @@ impl Session {
let payload = build_exec_hook_payload(event, exec_ctx, params, output);
for (idx, hook) in hooks.into_iter().enumerate() {
self
.run_hook_command(turn_diff_tracker, &hook, event, &payload, Some(exec_ctx), attempt_req, idx)
.run_hook_command(
turn_diff_tracker,
&hook,
event,
&payload,
Some(exec_ctx),
None,
attempt_req,
idx,
)
.await;
}
}
Expand All @@ -709,22 +755,159 @@ impl Session {
let attempt_req = self.current_request_ordinal();
for (idx, hook) in hooks.into_iter().enumerate() {
self
.run_hook_command(&mut tracker, &hook, event, &payload, None, attempt_req, idx)
.run_hook_command(&mut tracker, &hook, event, &payload, None, None, attempt_req, idx)
.await;
}
}

pub(super) async fn run_user_prompt_submit_hooks(
&self,
sub_id: &str,
items: &[InputItem],
_final_output_json_schema: Option<&Value>,
attempt_req: u64,
) -> UserPromptSubmitHookOutcome {
let transcript_path = self
.clone_rollout_recorder()
.map(|rec| rec.rollout_path.to_string_lossy().to_string());
let prompt = items
.iter()
.filter_map(|item| match item {
InputItem::Text { text } => Some(text.trim()),
_ => None,
})
.filter(|text| !text.is_empty())
.collect::<Vec<_>>()
.join("\n\n");
let payload = json!({
"event": ProjectHookEvent::UserPromptSubmit.as_str(),
"session_id": self.id,
"turn_id": sub_id,
"transcript_path": transcript_path,
"cwd": self.cwd.to_string_lossy(),
"model": self.client.get_model(),
"prompt": prompt,
});
let results = self
.run_project_hooks_for_payload(
ProjectHookEvent::UserPromptSubmit,
&payload,
sub_id,
attempt_req,
)
.await;

let additional_contexts = results
.iter()
.filter_map(|result| trimmed_non_empty(&result.stdout))
.collect::<Vec<_>>();
let block_reasons = results
.iter()
.filter(|result| result.exit_code == Some(2))
.filter_map(|result| trimmed_non_empty(&result.stderr))
.collect::<Vec<_>>();
let block_reason = join_text_chunks(block_reasons);

UserPromptSubmitHookOutcome {
blocked: block_reason.is_some(),
block_reason,
additional_contexts,
}
}

pub(super) async fn run_stop_hooks(
&self,
sub_id: &str,
last_assistant_message: Option<&str>,
stop_hook_active: bool,
attempt_req: u64,
) -> StopHookOutcome {
let transcript_path = self
.clone_rollout_recorder()
.map(|rec| rec.rollout_path.to_string_lossy().to_string());
let payload = json!({
"event": ProjectHookEvent::Stop.as_str(),
"session_id": self.id,
"turn_id": sub_id,
"transcript_path": transcript_path,
"cwd": self.cwd.to_string_lossy(),
"model": self.client.get_model(),
"stop_hook_active": stop_hook_active,
"last_assistant_message": last_assistant_message,
});
let results = self
.run_project_hooks_for_payload(ProjectHookEvent::Stop, &payload, sub_id, attempt_req)
.await;
let prompts = results
.into_iter()
.filter(|result| result.exit_code == Some(2))
.filter_map(|result| trimmed_non_empty(&result.stderr))
.collect::<Vec<_>>();
let continuation_prompt = join_text_chunks(prompts);

StopHookOutcome {
blocked: continuation_prompt.is_some(),
continuation_prompt,
}
}

async fn run_project_hooks_for_payload(
&self,
event: ProjectHookEvent,
payload: &Value,
sub_id: &str,
attempt_req: u64,
) -> Vec<ProjectHookCommandResult> {
if self.project_hooks.is_empty() {
return Vec::new();
}
let hooks: Vec<ProjectHook> = self.project_hooks.hooks_for(event).cloned().collect();
if hooks.is_empty() {
return Vec::new();
}
let Some(_guard) = HookGuard::try_acquire(&self.hook_guard) else {
return Vec::new();
};
let mut tracker = TurnDiffTracker::new();
let mut results = Vec::with_capacity(hooks.len());
for (idx, hook) in hooks.into_iter().enumerate() {
let result = self
.run_hook_command(
&mut tracker,
&hook,
event,
payload,
None,
Some(sub_id),
attempt_req,
idx,
)
.await;
results.push(result);
}
results
}

fn build_session_payload(&self, event: ProjectHookEvent) -> Value {
let transcript_path = self
.clone_rollout_recorder()
.map(|rec| rec.rollout_path.to_string_lossy().to_string());
match event {
ProjectHookEvent::SessionStart => json!({
"event": event.as_str(),
"session_id": self.id,
"transcript_path": transcript_path,
"cwd": self.cwd.to_string_lossy(),
"model": self.client.get_model(),
"sandbox_policy": format!("{}", self.sandbox_policy),
"approval_policy": format!("{}", self.approval_policy),
}),
ProjectHookEvent::SessionEnd => json!({
"event": event.as_str(),
"session_id": self.id,
"transcript_path": transcript_path,
"cwd": self.cwd.to_string_lossy(),
"model": self.client.get_model(),
"sandbox_policy": format!("{}", self.sandbox_policy),
"approval_policy": format!("{}", self.approval_policy),
}),
Expand All @@ -739,11 +922,13 @@ impl Session {
event: ProjectHookEvent,
payload: &Value,
base_ctx: Option<&ExecCommandContext>,
fallback_sub_id: Option<&str>,
attempt_req: u64,
index: usize,
) {
) -> ProjectHookCommandResult {
let sub_id = base_ctx
.map(|ctx| ctx.sub_id.clone())
.or_else(|| fallback_sub_id.map(ToOwned::to_owned))
.unwrap_or_else(|| INITIAL_SUBMIT_ID.to_string());
let base_slug = base_ctx
.map(|ctx| sanitize_identifier(&ctx.call_id))
Expand Down Expand Up @@ -798,7 +983,7 @@ impl Session {
stdout_stream: None,
};

if let Err(err) = Box::pin(self.run_exec_with_events_inner(
match Box::pin(self.run_exec_with_events_inner(
turn_diff_tracker,
exec_ctx,
exec_args,
Expand All @@ -807,20 +992,28 @@ impl Session {
attempt_req,
false,
))
.await
{
let hook_label = hook
.name
.as_deref()
.unwrap_or_else(|| hook.command.first().map(String::as_str).unwrap_or("hook"));
let order = self.next_background_order(&sub_id, attempt_req, None);
self
.notify_background_event_with_order(
&sub_id,
order,
format!("Hook `{}` failed: {}", hook_label, get_error_message_ui(&err)),
)
.await;
.await {
Ok(output) => ProjectHookCommandResult {
stdout: output.stdout.text,
stderr: output.stderr.text,
exit_code: Some(output.exit_code),
},
Err(err) => {
let hook_label = hook
.name
.as_deref()
.unwrap_or_else(|| hook.command.first().map(String::as_str).unwrap_or("hook"));
let order = self.next_background_order(&sub_id, attempt_req, None);
let message = format!("Hook `{}` failed: {}", hook_label, get_error_message_ui(&err));
self
.notify_background_event_with_order(&sub_id, order, message.clone())
.await;
ProjectHookCommandResult {
stdout: String::new(),
stderr: message,
exit_code: None,
}
}
}
}

Expand Down
Loading