Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions scripts/checks/setup_implementation_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ def check_working_tree_status(git_root: Path) -> dict:
# Fallback: just strip the status chars
dirty_files.append(line[3:].strip())
return {"clean": len(dirty_files) == 0, "dirty_files": dirty_files}
except Exception:
pass

return {"clean": True, "dirty_files": []}
else:
return {"clean": False, "dirty_files": [], "error": result.stderr.strip()}
except FileNotFoundError:
return {"clean": False, "dirty_files": [], "error": "git not found"}


def detect_commit_style(git_root: Path) -> str:
Expand Down
8 changes: 4 additions & 4 deletions scripts/hooks/capture-session-id.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def main() -> int:
pass

lines_to_write = []
if f"DEEP_SESSION_ID={session_id}" not in existing_content:
lines_to_write.append(f"export DEEP_SESSION_ID={session_id}\n")
if f'DEEP_SESSION_ID="{session_id}"' not in existing_content:
lines_to_write.append(f'export DEEP_SESSION_ID="{session_id}"\n')
if (
transcript_path
and f"CLAUDE_TRANSCRIPT_PATH={transcript_path}" not in existing_content
and f'CLAUDE_TRANSCRIPT_PATH="{transcript_path}"' not in existing_content
):
lines_to_write.append(
f"export CLAUDE_TRANSCRIPT_PATH={transcript_path}\n"
f'export CLAUDE_TRANSCRIPT_PATH="{transcript_path}"\n'
)

if lines_to_write:
Expand Down
9 changes: 1 addition & 8 deletions scripts/lib/impl_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,8 @@
"""

from dataclasses import dataclass
from enum import StrEnum


class TaskStatus(StrEnum):
"""Status values for tasks."""

PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
from scripts.lib.task_storage import TaskStatus


@dataclass(frozen=True, slots=True, kw_only=True)
Expand Down
28 changes: 23 additions & 5 deletions tests/hooks/test_capture_session_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_writes_to_claude_env_file_when_available(self, tmp_path):

# Should also write to env file
content = env_file.read_text()
assert "export DEEP_SESSION_ID=test-session-456" in content
assert 'export DEEP_SESSION_ID="test-session-456"' in content

def test_writes_transcript_path_to_env_file(self, tmp_path):
"""Should write transcript_path to CLAUDE_ENV_FILE."""
Expand All @@ -138,13 +138,13 @@ def test_writes_transcript_path_to_env_file(self, tmp_path):

assert returncode == 0
content = env_file.read_text()
assert "export DEEP_SESSION_ID=test-session-789" in content
assert "export CLAUDE_TRANSCRIPT_PATH=/path/to/transcript.jsonl" in content
assert 'export DEEP_SESSION_ID="test-session-789"' in content
assert 'export CLAUDE_TRANSCRIPT_PATH="/path/to/transcript.jsonl"' in content

def test_does_not_duplicate_existing_entries(self, tmp_path):
"""Should not duplicate entries already in CLAUDE_ENV_FILE."""
env_file = tmp_path / "env_file.sh"
env_file.write_text("export DEEP_SESSION_ID=test-session-111\n")
env_file.write_text('export DEEP_SESSION_ID="test-session-111"\n')

payload = {"session_id": "test-session-111"}
returncode, _, _ = run_hook(
Expand All @@ -154,7 +154,7 @@ def test_does_not_duplicate_existing_entries(self, tmp_path):
assert returncode == 0
content = env_file.read_text()
# Should only appear once
assert content.count("DEEP_SESSION_ID=test-session-111") == 1
assert content.count('DEEP_SESSION_ID="test-session-111"') == 1

def test_succeeds_when_claude_env_file_not_set(self):
"""Should succeed even when CLAUDE_ENV_FILE is not set."""
Expand Down Expand Up @@ -241,3 +241,21 @@ def test_outputs_when_deep_session_id_not_set(self):
assert returncode == 0
output = json.loads(stdout)
assert output["hookSpecificOutput"]["additionalContext"] == "DEEP_SESSION_ID=test-session-789"


class TestShellQuoting:
"""Tests for shell quoting: export values should be double-quoted."""

def test_values_with_spaces_are_correctly_quoted(self, tmp_path):
"""Values containing spaces should be correctly quoted."""
env_file = tmp_path / "env_file.sh"
env_file.touch()

payload = {
"session_id": "test-session-456",
"transcript_path": "/path/with spaces/transcript.jsonl",
}
run_hook(json.dumps(payload), env={"CLAUDE_ENV_FILE": str(env_file)})

content = env_file.read_text()
assert 'export CLAUDE_TRANSCRIPT_PATH="/path/with spaces/transcript.jsonl"' in content
46 changes: 46 additions & 0 deletions tests/test_setup_implementation_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,52 @@ def test_modified_tracked_file(self, mock_git_repo):
assert result["clean"] is False
assert "README.md" in result["dirty_files"]

def test_file_not_found_error_returns_error_dict(self, mock_git_repo, mocker):
"""FileNotFoundError (git not installed) returns error dict, not clean=True."""
mocker.patch(
"scripts.checks.setup_implementation_session.subprocess.run",
side_effect=FileNotFoundError("git not found"),
)
result = check_working_tree_status(mock_git_repo)

assert result["clean"] is False
assert "git not found" in result.get("error", "")
assert result["dirty_files"] == []

def test_nonzero_returncode_returns_error_dict(self, mock_git_repo, mocker):
"""Non-zero returncode returns error dict with stderr."""
mocker.patch(
"scripts.checks.setup_implementation_session.subprocess.run",
return_value=subprocess.CompletedProcess(
args=["git", "status"],
returncode=128,
stdout="",
stderr="fatal: not a git repository",
),
)
result = check_working_tree_status(mock_git_repo)

assert result["clean"] is False
assert "fatal: not a git repository" in result.get("error", "")

def test_permission_error_propagates(self, mock_git_repo, mocker):
"""PermissionError should NOT be caught (propagates to caller)."""
mocker.patch(
"scripts.checks.setup_implementation_session.subprocess.run",
side_effect=PermissionError("Permission denied"),
)
with pytest.raises(PermissionError):
check_working_tree_status(mock_git_repo)

def test_no_bare_except_in_source(self):
"""check_working_tree_status should not contain 'except Exception'."""
import inspect
source = inspect.getsource(check_working_tree_status)
assert "except Exception" not in source, (
"check_working_tree_status still has bare 'except Exception' — "
"only FileNotFoundError should be caught"
)


class TestDetectCommitStyle:
"""Tests for detect_commit_style function."""
Expand Down
37 changes: 37 additions & 0 deletions tests/test_task_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,40 @@ def test_returns_correct_path(self, monkeypatch):
result = get_tasks_dir("my-session-id")

assert result == Path("/Users/test/.claude/tasks/my-session-id")


class TestTaskStatusConsolidation:
"""Tests for TaskStatus consolidation — should be defined only in task_storage."""

def test_task_storage_exports_task_status(self):
"""TaskStatus should be importable from task_storage."""
from scripts.lib.task_storage import TaskStatus
assert hasattr(TaskStatus, "PENDING")
assert hasattr(TaskStatus, "IN_PROGRESS")
assert hasattr(TaskStatus, "COMPLETED")

def test_impl_tasks_imports_task_status(self):
"""TaskStatus should be importable from impl_tasks (via re-export or import)."""
from scripts.lib.impl_tasks import TaskStatus
assert hasattr(TaskStatus, "PENDING")

def test_impl_tasks_no_local_task_status_class(self):
"""impl_tasks module should NOT define TaskStatus class body (only imports it)."""
source_path = Path(__file__).parent.parent / "scripts" / "lib" / "impl_tasks.py"
source = source_path.read_text()
assert "class TaskStatus" not in source, (
"impl_tasks.py still defines TaskStatus locally — "
"it should import from task_storage instead"
)

def test_no_circular_import(self):
"""task_storage should NOT import from impl_tasks (no circular dependency)."""
source_path = Path(__file__).parent.parent / "scripts" / "lib" / "task_storage.py"
source = source_path.read_text()
for line in source.splitlines():
stripped = line.strip()
if stripped.startswith("#"):
continue
assert "import" not in stripped or "impl_tasks" not in stripped, (
f"task_storage.py imports from impl_tasks: {stripped}"
)