From 71f058214927a5784ffdca0d3c1320e904055fdd Mon Sep 17 00:00:00 2001 From: ebembi-crdb Date: Thu, 14 May 2026 19:15:03 +0530 Subject: [PATCH 1/2] Add automated SQL command testing for docs (EDUENG-131) Add infrastructure to extract SQL code blocks from markdown documentation and execute them against a CockroachDB cluster to verify correctness. Blocks are classified as executable, expected-error, fragment, or skipped, with skip annotations supported per-block and per-page. Co-Authored-By: Claude Opus 4.6 --- .github/scripts/sql_test/__init__.py | 0 .github/scripts/sql_test/executor.py | 198 ++++++++++++++ .github/scripts/sql_test/extractor.py | 256 ++++++++++++++++++ .github/scripts/sql_test/models.py | 44 ++++ .github/scripts/sql_test/reporter.py | 151 +++++++++++ .github/scripts/sql_test_runner.py | 129 +++++++++ .github/scripts/test_sql_extractor.py | 361 ++++++++++++++++++++++++++ .github/workflows/sql-test.yml | 136 ++++++++++ src/current/Makefile | 14 + 9 files changed, 1289 insertions(+) create mode 100644 .github/scripts/sql_test/__init__.py create mode 100644 .github/scripts/sql_test/executor.py create mode 100644 .github/scripts/sql_test/extractor.py create mode 100644 .github/scripts/sql_test/models.py create mode 100644 .github/scripts/sql_test/reporter.py create mode 100644 .github/scripts/sql_test_runner.py create mode 100644 .github/scripts/test_sql_extractor.py create mode 100644 .github/workflows/sql-test.yml diff --git a/.github/scripts/sql_test/__init__.py b/.github/scripts/sql_test/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/.github/scripts/sql_test/executor.py b/.github/scripts/sql_test/executor.py new file mode 100644 index 00000000000..5c1af9c685b --- /dev/null +++ b/.github/scripts/sql_test/executor.py @@ -0,0 +1,198 @@ +"""Executes SQL blocks against a CockroachDB cluster.""" + +import re +import subprocess +import time +from pathlib import Path +from typing import List + +from .models import BlockType, SqlBlock, TestResult, PageResult +from .extractor import MOVR_TABLES + +DEFAULT_CONNECTION_URL = "postgresql://root@localhost:26257?sslmode=disable" +STATEMENT_TIMEOUT_S = 30 + + +def _sanitize_db_name(file_path: str) -> str: + """Generate a safe database name from a file path.""" + name = Path(file_path).stem + # Replace non-alphanumeric characters with underscores + name = re.sub(r'[^a-zA-Z0-9]', '_', name) + return f"sqltest_{name}"[:63] # CockroachDB identifier limit + + +def _uses_movr(blocks: List[SqlBlock]) -> bool: + """Check if any block references MovR tables.""" + for block in blocks: + content_lower = block.raw_content.lower() + for table in MOVR_TABLES: + if re.search(r'\b' + table + r'\b', content_lower): + return True + return False + + +def _run_sql(connection_url: str, sql: str, timeout: int = STATEMENT_TIMEOUT_S) -> subprocess.CompletedProcess: + """Execute SQL via cockroach sql subprocess.""" + return subprocess.run( + ["cockroach", "sql", "--url", connection_url, "--format=table", "-e", sql], + capture_output=True, + text=True, + timeout=timeout, + ) + + +def execute_page(page_result: PageResult, connection_url: str = DEFAULT_CONNECTION_URL) -> PageResult: + """Execute all SQL blocks for a single page against CockroachDB. + + Creates an isolated database per page, runs all executable blocks in + document order within that database, then cleans up. + + Args: + page_result: PageResult with extracted blocks (no results yet). + connection_url: CockroachDB connection URL. + + Returns: + The same PageResult with results populated. + """ + db_name = _sanitize_db_name(page_result.file_path) + + # Create isolated database + try: + _run_sql(connection_url, f'CREATE DATABASE IF NOT EXISTS "{db_name}";') + except Exception as e: + # If we can't create the DB, fail all blocks + for block in page_result.blocks: + if block.block_type in (BlockType.EXECUTABLE, BlockType.EXPECTED_ERROR): + page_result.results.append(TestResult( + block=block, + success=False, + error_message=f"Failed to create test database: {e}", + )) + return page_result + + # Build connection URL with the test database + if '?' in connection_url: + db_url = connection_url.replace('?', f'/"{db_name}"?', 1) + else: + db_url = f'{connection_url}/"{db_name}"' + + # Initialize MovR data if needed + executable_blocks = [b for b in page_result.blocks if b.block_type in (BlockType.EXECUTABLE, BlockType.EXPECTED_ERROR)] + if _uses_movr(page_result.blocks): + try: + subprocess.run( + ["cockroach", "workload", "init", "movr", db_url], + capture_output=True, + text=True, + timeout=60, + ) + except Exception as e: + for block in executable_blocks: + page_result.results.append(TestResult( + block=block, + success=False, + error_message=f"Failed to initialize MovR: {e}", + )) + _cleanup_db(connection_url, db_name) + return page_result + + # Execute blocks in order + for block in page_result.blocks: + if block.block_type == BlockType.FRAGMENT or block.block_type == BlockType.SKIPPED: + continue + + result = _execute_block(block, db_url) + page_result.results.append(result) + + # Cleanup + _cleanup_db(connection_url, db_name) + + return page_result + + +def _execute_block(block: SqlBlock, db_url: str) -> TestResult: + """Execute a single SQL block and return the result.""" + start_time = time.time() + combined_output = [] + combined_errors = [] + + for stmt in block.cleaned_statements: + try: + proc = _run_sql(db_url, stmt) + if proc.stdout: + combined_output.append(proc.stdout) + if proc.stderr: + combined_errors.append(proc.stderr) + + if proc.returncode != 0: + elapsed = (time.time() - start_time) * 1000 + error_text = proc.stderr.strip() if proc.stderr else "Non-zero exit code" + + if block.block_type == BlockType.EXPECTED_ERROR: + # Expected error: passing because it did error + return TestResult( + block=block, + success=True, + actual_output=error_text, + execution_time_ms=elapsed, + ) + else: + return TestResult( + block=block, + success=False, + actual_output='\n'.join(combined_output), + error_message=error_text, + execution_time_ms=elapsed, + ) + + except subprocess.TimeoutExpired: + elapsed = (time.time() - start_time) * 1000 + return TestResult( + block=block, + success=False, + error_message=f"Statement timed out after {STATEMENT_TIMEOUT_S}s: {stmt[:100]}", + execution_time_ms=elapsed, + ) + except Exception as e: + elapsed = (time.time() - start_time) * 1000 + if block.block_type == BlockType.EXPECTED_ERROR: + return TestResult( + block=block, + success=True, + actual_output=str(e), + execution_time_ms=elapsed, + ) + return TestResult( + block=block, + success=False, + error_message=str(e), + execution_time_ms=elapsed, + ) + + elapsed = (time.time() - start_time) * 1000 + + # All statements succeeded + if block.block_type == BlockType.EXPECTED_ERROR: + # Expected an error but all statements succeeded — this is a failure + return TestResult( + block=block, + success=False, + actual_output='\n'.join(combined_output), + error_message="Expected an error but all statements succeeded", + execution_time_ms=elapsed, + ) + + return TestResult( + block=block, + success=True, + actual_output='\n'.join(combined_output), + execution_time_ms=elapsed, + ) + + +def _cleanup_db(connection_url: str, db_name: str) -> None: + """Drop the test database.""" + try: + _run_sql(connection_url, f'DROP DATABASE IF EXISTS "{db_name}" CASCADE;') + except Exception: + pass # Best-effort cleanup diff --git a/.github/scripts/sql_test/extractor.py b/.github/scripts/sql_test/extractor.py new file mode 100644 index 00000000000..2c9bbda4f01 --- /dev/null +++ b/.github/scripts/sql_test/extractor.py @@ -0,0 +1,256 @@ +"""Extracts and classifies SQL code blocks from CockroachDB documentation markdown files.""" + +import re +from pathlib import Path +from typing import List, Optional + +from .models import BlockType, SqlBlock, PageResult + + +# Tables that indicate MovR dataset usage +MOVR_TABLES = frozenset({ + "users", "vehicles", "rides", "promo_codes", + "vehicle_location_histories", "user_promo_codes", +}) + +# Patterns that indicate a block is a fragment (not executable as-is) +FRAGMENT_INDICATORS = [ + re.compile(r'\.\.\.'), # Ellipsis (truncated content) + re.compile(r'<[a-zA-Z_][a-zA-Z0-9_ -]*>'), # style + re.compile(r'\{[a-zA-Z_][a-zA-Z0-9_]*\}'), # {placeholder} style + re.compile(r'{% remote_include'), # Liquid remote include +] + +# Skip annotation pattern: +SKIP_COMMENT_RE = re.compile( + r'' +) + + +def _has_page_level_skip(content: str) -> bool: + """Check if frontmatter contains sql_test: skip.""" + frontmatter_match = re.match(r'^---\s*\n(.*?)\n---', content, re.DOTALL) + if not frontmatter_match: + return False + frontmatter = frontmatter_match.group(1) + return bool(re.search(r'^\s*sql_test:\s*skip\s*$', frontmatter, re.MULTILINE)) + + +def _clean_sql_lines(raw: str) -> List[str]: + """Clean raw SQL block content into executable statements. + + Strips the leading '> ' prompt prefix from each line, then splits + on semicolons to produce individual statements. + """ + lines = [] + for line in raw.split('\n'): + # Strip the leading '> ' prompt that CockroachDB docs use + stripped = line + if stripped.startswith('> '): + stripped = stripped[2:] + elif stripped == '>': + stripped = '' + lines.append(stripped) + + joined = '\n'.join(lines).strip() + if not joined: + return [] + + # Split on semicolons, keeping each as a complete statement + statements = [] + current = [] + for line in joined.split('\n'): + current.append(line) + if line.rstrip().endswith(';'): + stmt = '\n'.join(current).strip() + if stmt: + statements.append(stmt) + current = [] + + # If there's remaining content without a trailing semicolon, + # include it as a statement (some SQL commands like \dt don't use semicolons) + if current: + stmt = '\n'.join(current).strip() + if stmt: + statements.append(stmt) + + return statements + + +def _classify_block( + raw: str, + statements: List[str], + expected_output: Optional[str], + skip_reason: Optional[str], +) -> BlockType: + """Classify a SQL block based on its content and context.""" + if skip_reason is not None: + return BlockType.SKIPPED + + # Check for fragment indicators in the raw SQL content + for pattern in FRAGMENT_INDICATORS: + if pattern.search(raw): + return BlockType.FRAGMENT + + # Check if any statement starts with $ (shell command, not SQL) + for stmt in statements: + if stmt.lstrip().startswith('$'): + return BlockType.FRAGMENT + + # Check if expected output indicates an error + if expected_output: + output_stripped = expected_output.strip() + if output_stripped.startswith('ERROR:') or output_stripped.startswith('pq:'): + return BlockType.EXPECTED_ERROR + + return BlockType.EXECUTABLE + + +def _uses_movr(blocks: List[SqlBlock]) -> bool: + """Check if any block references MovR tables.""" + for block in blocks: + content_lower = block.raw_content.lower() + for table in MOVR_TABLES: + # Match table name as a word boundary to avoid false positives + if re.search(r'\b' + table + r'\b', content_lower): + return True + return False + + +def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: + """Extract all SQL code blocks from a markdown file. + + Args: + file_path: Path to the markdown file. + content: Optional file content. If None, reads from file_path. + + Returns: + PageResult containing all extracted and classified SQL blocks. + """ + path = Path(file_path) + + if content is None: + if not path.exists(): + return PageResult(file_path=file_path) + content = path.read_text(encoding='utf-8') + + page_result = PageResult(file_path=file_path) + + # Check for page-level skip + page_skip = _has_page_level_skip(content) + + lines = content.split('\n') + i = 0 + block_index = 0 + + while i < len(lines): + line = lines[i] + + # Check for skip annotation comment + skip_match = SKIP_COMMENT_RE.search(line) + if skip_match: + skip_reason = skip_match.group(1) or "Marked with sql-test:skip" + # Look for the next SQL block immediately following + j = i + 1 + while j < len(lines) and lines[j].strip() == '': + j += 1 + + if j < len(lines) and lines[j].strip() == '~~~ sql': + # Found the SQL block after the skip comment + sql_start = j + 1 + sql_end = sql_start + while sql_end < len(lines) and lines[sql_end].strip() != '~~~': + sql_end += 1 + + raw = '\n'.join(lines[sql_start:sql_end]) + statements = _clean_sql_lines(raw) + + block = SqlBlock( + file_path=file_path, + line_number=j + 1, # 1-indexed + raw_content=raw, + cleaned_statements=statements, + block_type=BlockType.SKIPPED, + skip_reason=skip_reason, + block_index=block_index, + ) + page_result.blocks.append(block) + block_index += 1 + i = sql_end + 1 + continue + + i += 1 + continue + + # Detect ~~~ sql block + if line.strip() == '~~~ sql': + sql_line_number = i + 1 # 1-indexed + + # Collect SQL content + sql_start = i + 1 + sql_end = sql_start + while sql_end < len(lines) and lines[sql_end].strip() != '~~~': + sql_end += 1 + + raw = '\n'.join(lines[sql_start:sql_end]) + statements = _clean_sql_lines(raw) + + # Look ahead for expected output block (~~~ without a language tag) + expected_output = None + j = sql_end + 1 + # Skip blank lines and non-code-block lines between SQL and output + while j < len(lines) and lines[j].strip() == '': + j += 1 + + if j < len(lines) and lines[j].strip() == '~~~': + # This is a plain ~~~ block (output block) + out_start = j + 1 + out_end = out_start + while out_end < len(lines) and lines[out_end].strip() != '~~~': + out_end += 1 + expected_output = '\n'.join(lines[out_start:out_end]) + + # Determine skip reason + skip_reason = None + if page_skip: + skip_reason = "Page-level sql_test: skip in frontmatter" + + block_type = _classify_block(raw, statements, expected_output, skip_reason) + + block = SqlBlock( + file_path=file_path, + line_number=sql_line_number, + raw_content=raw, + cleaned_statements=statements, + block_type=block_type, + expected_output=expected_output, + skip_reason=skip_reason, + block_index=block_index, + ) + page_result.blocks.append(block) + block_index += 1 + + # Advance past the closing ~~~ + i = sql_end + 1 + continue + + i += 1 + + return page_result + + +def extract_from_files(file_paths: List[str]) -> List[PageResult]: + """Extract SQL blocks from multiple files. + + Args: + file_paths: List of markdown file paths to process. + + Returns: + List of PageResult, one per file (only files with blocks). + """ + results = [] + for fp in file_paths: + page = extract_blocks(fp) + if page.blocks: + results.append(page) + return results diff --git a/.github/scripts/sql_test/models.py b/.github/scripts/sql_test/models.py new file mode 100644 index 00000000000..c4d7a17f50b --- /dev/null +++ b/.github/scripts/sql_test/models.py @@ -0,0 +1,44 @@ +"""Data models for SQL testing infrastructure.""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + + +class BlockType(Enum): + """Classification of a SQL code block.""" + EXECUTABLE = "executable" + EXPECTED_ERROR = "expected_error" + FRAGMENT = "fragment" + SKIPPED = "skipped" + + +@dataclass +class SqlBlock: + """A single SQL code block extracted from a markdown file.""" + file_path: str + line_number: int + raw_content: str + cleaned_statements: List[str] + block_type: BlockType + expected_output: Optional[str] = None + skip_reason: Optional[str] = None + block_index: int = 0 + + +@dataclass +class TestResult: + """Result of executing a single SQL block.""" + block: SqlBlock + success: bool + actual_output: str = "" + error_message: str = "" + execution_time_ms: float = 0.0 + + +@dataclass +class PageResult: + """Aggregated results for all SQL blocks in a single file.""" + file_path: str + blocks: List[SqlBlock] = field(default_factory=list) + results: List[TestResult] = field(default_factory=list) diff --git a/.github/scripts/sql_test/reporter.py b/.github/scripts/sql_test/reporter.py new file mode 100644 index 00000000000..edaf3cf9f7b --- /dev/null +++ b/.github/scripts/sql_test/reporter.py @@ -0,0 +1,151 @@ +"""Output formatting for SQL test results.""" + +import sys +from typing import List + +from .models import BlockType, PageResult, TestResult + + +def _count_blocks(pages: List[PageResult]): + """Count blocks by type across all pages.""" + total = 0 + executable = 0 + skipped = 0 + fragments = 0 + for page in pages: + for block in page.blocks: + total += 1 + if block.block_type == BlockType.SKIPPED: + skipped += 1 + elif block.block_type == BlockType.FRAGMENT: + fragments += 1 + elif block.block_type in (BlockType.EXECUTABLE, BlockType.EXPECTED_ERROR): + executable += 1 + return total, executable, skipped, fragments + + +def print_dry_run(pages: List[PageResult], verbose: bool = False) -> None: + """Print extraction/classification summary without execution results.""" + total, executable, skipped, fragments = _count_blocks(pages) + + print(f"\n{'='*60}") + print(f"SQL Test Dry Run Summary") + print(f"{'='*60}") + print(f"Pages scanned: {len(pages)}") + print(f"Total SQL blocks: {total}") + print(f" Executable: {executable}") + print(f" Expected errors: {sum(1 for p in pages for b in p.blocks if b.block_type == BlockType.EXPECTED_ERROR)}") + print(f" Fragments: {fragments}") + print(f" Skipped: {skipped}") + print(f"{'='*60}\n") + + if verbose: + for page in pages: + print(f"\n--- {page.file_path} ({len(page.blocks)} blocks) ---") + for block in page.blocks: + status = block.block_type.value.upper() + preview = block.raw_content.split('\n')[0][:60] + print(f" [{status:15s}] line {block.line_number}: {preview}") + if block.skip_reason: + print(f" skip reason: {block.skip_reason}") + + +def print_results(pages: List[PageResult]) -> None: + """Print execution results to console.""" + total_tested = 0 + total_passed = 0 + total_failed = 0 + failures = [] + + for page in pages: + for result in page.results: + total_tested += 1 + if result.success: + total_passed += 1 + else: + total_failed += 1 + failures.append(result) + + total, executable, skipped, fragments = _count_blocks(pages) + + print(f"\n{'='*60}") + print(f"SQL Test Results") + print(f"{'='*60}") + print(f"Pages tested: {len(pages)}") + print(f"Total SQL blocks: {total}") + print(f" Tested: {total_tested}") + print(f" Passed: {total_passed}") + print(f" Failed: {total_failed}") + print(f" Fragments: {fragments}") + print(f" Skipped: {skipped}") + print(f"{'='*60}") + + if failures: + print(f"\nFailures:\n") + for result in failures: + block = result.block + print(f" FAIL: {block.file_path}:{block.line_number}") + # Show first statement as context + if block.cleaned_statements: + stmt_preview = block.cleaned_statements[0][:100] + print(f" Statement: {stmt_preview}") + print(f" Error: {result.error_message}") + print() + else: + print(f"\nAll tests passed.\n") + + +def write_github_comment(pages: List[PageResult], output_path: str = "sql-test-comment.md") -> None: + """Write a GitHub PR comment markdown file.""" + failures = [] + total_tested = 0 + total_passed = 0 + + for page in pages: + for result in page.results: + total_tested += 1 + if result.success: + total_passed += 1 + else: + failures.append(result) + + total, executable, skipped, fragments = _count_blocks(pages) + + lines = [] + if not failures: + lines.append("**SQL Test Check Passed**") + lines.append("") + lines.append(f"Tested {total_tested} SQL blocks across {len(pages)} pages. All passed.") + else: + lines.append("**SQL Test Check Failed**") + lines.append("") + lines.append(f"Found {len(failures)} failure(s) out of {total_tested} tested SQL blocks.") + lines.append("") + lines.append("| File | Line | Error |") + lines.append("|------|------|-------|") + for result in failures: + block = result.block + error_brief = result.error_message.split('\n')[0][:120] + lines.append(f"| `{block.file_path}` | {block.line_number} | {error_brief} |") + lines.append("") + lines.append("
") + lines.append("Failure details") + lines.append("") + for result in failures: + block = result.block + lines.append(f"### `{block.file_path}:{block.line_number}`") + lines.append("") + if block.cleaned_statements: + lines.append("```sql") + lines.append(block.cleaned_statements[0][:200]) + lines.append("```") + lines.append("") + lines.append(f"**Error:** {result.error_message}") + lines.append("") + lines.append("
") + + lines.append("") + lines.append(f"**Summary:** {total_tested} tested, {total_passed} passed, {len(failures)} failed, {fragments} fragments, {skipped} skipped") + + with open(output_path, 'w') as f: + f.write('\n'.join(lines)) diff --git a/.github/scripts/sql_test_runner.py b/.github/scripts/sql_test_runner.py new file mode 100644 index 00000000000..36929cd1a52 --- /dev/null +++ b/.github/scripts/sql_test_runner.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +sql_test_runner.py + +Extracts SQL code blocks from CockroachDB documentation markdown files +and optionally executes them against a CockroachDB cluster. + +Usage: + python sql_test_runner.py [file2] ... + python sql_test_runner.py --version v25.4 + python sql_test_runner.py --dry-run --version v25.4 +""" + +import argparse +import glob +import os +import sys + +# Ensure the scripts directory is on the path +sys.path.insert(0, os.path.dirname(__file__)) + +from sql_test.extractor import extract_blocks, extract_from_files +from sql_test.executor import execute_page, DEFAULT_CONNECTION_URL +from sql_test.reporter import print_dry_run, print_results, write_github_comment + + +def collect_files(file_args: list, version: str = None) -> list: + """Collect markdown files to test. + + Args: + file_args: Explicitly provided file paths. + version: If set, find all markdown files under src/current//. + + Returns: + List of file paths. + """ + files = [] + + if version: + # Find repo root by looking for src/current/ relative to this script + script_dir = os.path.dirname(os.path.abspath(__file__)) + repo_root = os.path.dirname(os.path.dirname(script_dir)) + version_dir = os.path.join(repo_root, "src", "current", version) + if not os.path.isdir(version_dir): + print(f"Error: version directory not found: {version_dir}", file=sys.stderr) + sys.exit(1) + pattern = os.path.join(version_dir, "**", "*.md") + files = sorted(glob.glob(pattern, recursive=True)) + + if file_args: + files.extend(file_args) + + return files + + +def main(): + parser = argparse.ArgumentParser( + description="Test SQL code blocks in CockroachDB documentation." + ) + parser.add_argument( + "files", nargs="*", help="Markdown files to test." + ) + parser.add_argument( + "--version", type=str, default=None, + help="Test all files in a version directory (e.g., v25.4)." + ) + parser.add_argument( + "--connection-url", type=str, default=DEFAULT_CONNECTION_URL, + help=f"CockroachDB connection URL (default: {DEFAULT_CONNECTION_URL})." + ) + parser.add_argument( + "--dry-run", action="store_true", + help="Extract and classify blocks only, no execution." + ) + parser.add_argument( + "--verbose", action="store_true", + help="Show all blocks including skipped and fragments." + ) + + args = parser.parse_args() + + # Collect files + files = collect_files(args.files, args.version) + if not files: + print("No files to test. Provide file paths or --version.", file=sys.stderr) + sys.exit(1) + + if args.verbose: + print(f"Scanning {len(files)} file(s)...") + + # Extract blocks from all files + pages = extract_from_files(files) + + if not pages: + print("No SQL blocks found in the provided files.") + sys.exit(0) + + if args.dry_run: + print_dry_run(pages, verbose=args.verbose) + sys.exit(0) + + # Execute blocks + has_failures = False + for page in pages: + if args.verbose: + executable_count = sum( + 1 for b in page.blocks + if b.block_type.value in ("executable", "expected_error") + ) + print(f"Testing {page.file_path} ({executable_count} executable blocks)...") + + execute_page(page, connection_url=args.connection_url) + + for result in page.results: + if not result.success: + has_failures = True + + # Report results + print_results(pages) + + # Write GitHub comment if in CI + if os.environ.get("GITHUB_ACTIONS"): + write_github_comment(pages) + + sys.exit(1 if has_failures else 0) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/test_sql_extractor.py b/.github/scripts/test_sql_extractor.py new file mode 100644 index 00000000000..97be00579f6 --- /dev/null +++ b/.github/scripts/test_sql_extractor.py @@ -0,0 +1,361 @@ +"""Unit tests for the SQL block extractor.""" + +import sys +import os +import unittest + +# Ensure the scripts directory is on the path +sys.path.insert(0, os.path.dirname(__file__)) + +from sql_test.extractor import extract_blocks, _clean_sql_lines, _has_page_level_skip +from sql_test.models import BlockType + + +class TestCleanSqlLines(unittest.TestCase): + """Tests for SQL line cleaning.""" + + def test_strips_prompt_prefix(self): + raw = "> SELECT 1;" + stmts = _clean_sql_lines(raw) + self.assertEqual(stmts, ["SELECT 1;"]) + + def test_strips_bare_prompt(self): + raw = ">\n> SELECT 1;" + stmts = _clean_sql_lines(raw) + self.assertEqual(stmts, ["SELECT 1;"]) + + def test_no_prefix(self): + raw = "SELECT 1;" + stmts = _clean_sql_lines(raw) + self.assertEqual(stmts, ["SELECT 1;"]) + + def test_multiline_statement(self): + raw = "> SELECT\n> id, name\n> FROM users;" + stmts = _clean_sql_lines(raw) + self.assertEqual(len(stmts), 1) + self.assertIn("SELECT", stmts[0]) + self.assertIn("FROM users;", stmts[0]) + + def test_multiple_statements(self): + raw = "> CREATE TABLE t (id INT);\n> INSERT INTO t VALUES (1);" + stmts = _clean_sql_lines(raw) + self.assertEqual(len(stmts), 2) + self.assertIn("CREATE TABLE", stmts[0]) + self.assertIn("INSERT INTO", stmts[1]) + + def test_empty_content(self): + self.assertEqual(_clean_sql_lines(""), []) + self.assertEqual(_clean_sql_lines(" \n "), []) + + +class TestHasPageLevelSkip(unittest.TestCase): + """Tests for page-level skip detection.""" + + def test_detects_skip(self): + content = "---\ntitle: Test\nsql_test: skip\n---\nBody" + self.assertTrue(_has_page_level_skip(content)) + + def test_no_skip(self): + content = "---\ntitle: Test\n---\nBody" + self.assertFalse(_has_page_level_skip(content)) + + def test_no_frontmatter(self): + content = "No frontmatter here\n~~~ sql\nSELECT 1;\n~~~" + self.assertFalse(_has_page_level_skip(content)) + + +class TestExtractBlocks(unittest.TestCase): + """Tests for block extraction and classification.""" + + def test_basic_executable_block(self): + md = """--- +title: Test +--- + +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, BlockType.EXECUTABLE) + self.assertEqual(block.cleaned_statements, ["SELECT 1;"]) + self.assertEqual(block.line_number, 5) + + def test_block_with_output(self): + md = """~~~ sql +> SELECT 1; +~~~ + +~~~ + ?column? ++----------+ + 1 +(1 row) +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, BlockType.EXECUTABLE) + self.assertIsNotNone(block.expected_output) + self.assertIn("?column?", block.expected_output) + + def test_expected_error_pq(self): + md = """~~~ sql +> INSERT INTO t VALUES (1); +~~~ + +~~~ +pq: duplicate key value violates unique constraint +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.EXPECTED_ERROR) + + def test_expected_error_ERROR(self): + md = """~~~ sql +> DROP TABLE nonexistent; +~~~ + +~~~ +ERROR: relation "nonexistent" does not exist +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.EXPECTED_ERROR) + + def test_fragment_with_ellipsis(self): + md = """~~~ sql +> SELECT ...; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_fragment_with_placeholder(self): + md = """~~~ sql +ALTER ROLE SET copy_from_retries_enabled = true; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_fragment_with_curly_placeholder(self): + md = """~~~ sql +ALTER ROLE {username} SET copy_from_retries_enabled = true; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_fragment_with_remote_include(self): + md = """~~~ sql +{% remote_include https://example.com/snippet.sql %} +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.FRAGMENT) + + def test_skip_annotation(self): + md = """ +~~~ sql +> SLEECT * FORM users; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, BlockType.SKIPPED) + self.assertEqual(block.skip_reason, "Demonstrates invalid syntax") + + def test_skip_annotation_no_reason(self): + md = """ +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.SKIPPED) + + def test_page_level_skip(self): + md = """--- +title: Test +sql_test: skip +--- + +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.SKIPPED) + + def test_multiple_blocks_preserve_order(self): + md = """~~~ sql +> CREATE TABLE t (id INT PRIMARY KEY); +~~~ + +~~~ sql +> INSERT INTO t VALUES (1); +~~~ + +~~~ sql +> SELECT * FROM t; +~~~ + +~~~ + id ++----+ + 1 +(1 row) +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 3) + self.assertEqual(result.blocks[0].block_index, 0) + self.assertEqual(result.blocks[1].block_index, 1) + self.assertEqual(result.blocks[2].block_index, 2) + # Only the last block has expected output + self.assertIsNone(result.blocks[0].expected_output) + self.assertIsNone(result.blocks[1].expected_output) + self.assertIsNotNone(result.blocks[2].expected_output) + + def test_ignores_non_sql_blocks(self): + md = """~~~ shell +$ cockroach start --insecure +~~~ + +~~~ sql +> SELECT 1; +~~~ + +~~~ json +{"key": "value"} +~~~ +""" + result = extract_blocks("test.md", content=md) + # Should only extract the sql block, not shell or json + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].cleaned_statements, ["SELECT 1;"]) + + def test_no_sql_blocks(self): + md = """--- +title: No SQL +--- + +This page has no SQL blocks. + +~~~ shell +$ echo hello +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 0) + + def test_sql_without_prompt_prefix(self): + md = """~~~ sql +SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].cleaned_statements, ["SELECT 1;"]) + + def test_mixed_executable_and_fragment(self): + md = """~~~ sql +> SELECT * FROM users; +~~~ + +~~~ sql +> SELECT ... FROM ; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 2) + self.assertEqual(result.blocks[0].block_type, BlockType.EXECUTABLE) + self.assertEqual(result.blocks[1].block_type, BlockType.FRAGMENT) + + def test_block_line_numbers(self): + md = """Line 1 +Line 2 +Line 3 +~~~ sql +> SELECT 1; +~~~ +Line 7 +Line 8 +~~~ sql +> SELECT 2; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 2) + # ~~~ sql is on line 4 (1-indexed) + self.assertEqual(result.blocks[0].line_number, 4) + # ~~~ sql is on line 9 (1-indexed) + self.assertEqual(result.blocks[1].line_number, 9) + + +class TestExtractBlocksFromRealPatterns(unittest.TestCase): + """Tests using patterns found in actual CockroachDB docs.""" + + def test_movr_select_with_output(self): + """Pattern from select-clause.md.""" + md = """{% include_cached copy-clipboard.html %} +~~~ sql +> SELECT id, city, name FROM users LIMIT 10; +~~~ + +~~~ + id | city | name ++--------------------------------------+---------------+------------------+ + 7ae147ae-147a-4000-8000-000000000018 | los angeles | Alfred Garcia +(1 row) +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertEqual(block.block_type, BlockType.EXECUTABLE) + self.assertEqual(block.cleaned_statements, ["SELECT id, city, name FROM users LIMIT 10;"]) + self.assertIn("Alfred Garcia", block.expected_output) + + def test_upsert_error_pattern(self): + """Pattern from upsert.md with pq: error output.""" + md = """~~~ sql +> UPSERT INTO unique_test VALUES (4, 1); +~~~ + +~~~ +pq: duplicate key value (b)=(1) violates unique constraint "unique_test_b_key" +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.EXPECTED_ERROR) + + def test_multiline_insert(self): + """Multi-line SQL statement.""" + md = """~~~ sql +> INSERT INTO user_promo_codes (city, user_id, code, "timestamp", usage_count) + VALUES ('new york', '147ae147-ae14-4b00-8000-000000000004', 'promo_code', now(), 1); +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(len(result.blocks[0].cleaned_statements), 1) + self.assertIn("INSERT INTO", result.blocks[0].cleaned_statements[0]) + + +if __name__ == '__main__': + unittest.main() diff --git a/.github/workflows/sql-test.yml b/.github/workflows/sql-test.yml new file mode 100644 index 00000000000..9b9c4b5263e --- /dev/null +++ b/.github/workflows/sql-test.yml @@ -0,0 +1,136 @@ +name: SQL Test Check + +on: + pull_request: + types: [opened, synchronize, reopened] + paths: + - 'src/current/v25.4/**/*.md' + schedule: + # Run nightly at 6am UTC + - cron: '0 6 * * *' + workflow_dispatch: + +jobs: + sql-test: + name: Test SQL code blocks + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install CockroachDB + run: | + curl https://binaries.cockroachdb.com/cockroach-v25.1.2.linux-amd64.tgz | tar -xz + sudo cp cockroach-v25.1.2.linux-amd64/cockroach /usr/local/bin/ + + - name: Start CockroachDB + run: | + cockroach start-single-node --insecure --background --store=type=mem,size=1GiB --listen-addr=localhost:26257 + sleep 5 + cockroach sql --insecure -e "SELECT 1;" + + - name: Get changed files + if: github.event_name == 'pull_request' + id: changed-files + uses: tj-actions/changed-files@cc08e170f4447237bcaf8acaacfa615b9cb86612 # v35 + with: + files: | + src/current/v25.4/**/*.md + separator: ' ' + + - name: Run SQL tests (PR - changed files only) + if: github.event_name == 'pull_request' && steps.changed-files.outputs.any_changed == 'true' + id: sql-test-pr + run: | + echo "Testing changed files..." + python .github/scripts/sql_test_runner.py ${{ steps.changed-files.outputs.all_changed_files }} + continue-on-error: true + + - name: Run SQL tests (scheduled/manual - all v25.4 files) + if: github.event_name != 'pull_request' + id: sql-test-full + run: | + echo "Testing all v25.4 files..." + python .github/scripts/sql_test_runner.py --version v25.4 + continue-on-error: true + + - name: Post PR comment with failures + if: github.event_name == 'pull_request' && steps.changed-files.outputs.any_changed == 'true' && steps.sql-test-pr.outcome == 'failure' + uses: actions/github-script@v6 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + + let comment = ''; + try { + comment = fs.readFileSync('sql-test-comment.md', 'utf8'); + } catch (error) { + comment = '**SQL Test Check Failed**\n\nSQL test failures were detected, but the detailed report could not be generated.'; + } + + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + const botComment = comments.find(comment => + comment.user.type === 'Bot' && + comment.body.includes('SQL Test Check') + ); + + if (botComment) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: comment + }); + } else { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: comment + }); + } + + - name: Post success comment if previously failed + if: github.event_name == 'pull_request' && steps.changed-files.outputs.any_changed == 'true' && steps.sql-test-pr.outcome == 'success' + uses: actions/github-script@v6 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + const botComment = comments.find(comment => + comment.user.type === 'Bot' && + comment.body.includes('SQL Test Check Failed') + ); + + if (botComment) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: '**SQL Test Check Passed**\n\nAll SQL test issues have been resolved.' + }); + } + + - name: Stop CockroachDB + if: always() + run: | + cockroach quit --insecure --host=localhost:26257 || true diff --git a/src/current/Makefile b/src/current/Makefile index f9ee8cddc49..6c3375f60f2 100644 --- a/src/current/Makefile +++ b/src/current/Makefile @@ -85,6 +85,20 @@ linkcheck: cockroachdb-build vale: vale $(subst $(\n), $( ), $(shell git status --porcelain | cut -c 4- | egrep "\.md")) +.PHONY: sql-test +sql-test: + cockroach start-single-node --insecure --background --store=type=mem,size=1GiB --listen-addr=localhost:26257 + sleep 5 + cockroach workload init movr 'postgresql://root@localhost:26257?sslmode=disable' || true + python3 ../../.github/scripts/sql_test_runner.py --version v25.4; \ + EXIT_CODE=$$?; \ + cockroach quit --insecure --host=localhost:26257 || true; \ + exit $$EXIT_CODE + +.PHONY: sql-test-dry-run +sql-test-dry-run: + python3 ../../.github/scripts/sql_test_runner.py --dry-run --verbose --version v25.4 + .PHONY: vendor vendor: gem install bundler From b1f09b940fdd308298b3ee207f658849e2eac84a Mon Sep 17 00:00:00 2001 From: ebembi-crdb Date: Wed, 20 May 2026 15:50:43 +0530 Subject: [PATCH 2/2] Add automated SQL response generation for docs (EDUENG-225) Extends the SQL test infrastructure to write captured output back into markdown files, keeping response blocks accurate automatically. Co-Authored-By: Claude Opus 4.6 --- .github/scripts/sql_response_runner.py | 244 +++++++++ .github/scripts/sql_test/extractor.py | 126 ++++- .github/scripts/sql_test/generator.py | 175 +++++++ .github/scripts/sql_test/models.py | 11 +- .github/scripts/test_sql_generator.py | 530 ++++++++++++++++++++ .github/workflows/sql-response-generate.yml | 95 ++++ src/current/Makefile | 18 + 7 files changed, 1195 insertions(+), 4 deletions(-) create mode 100644 .github/scripts/sql_response_runner.py create mode 100644 .github/scripts/sql_test/generator.py create mode 100644 .github/scripts/test_sql_generator.py create mode 100644 .github/workflows/sql-response-generate.yml diff --git a/.github/scripts/sql_response_runner.py b/.github/scripts/sql_response_runner.py new file mode 100644 index 00000000000..93fa9fe0828 --- /dev/null +++ b/.github/scripts/sql_response_runner.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +""" +sql_response_runner.py + +Generates SQL response output blocks in CockroachDB documentation markdown +files by executing SQL blocks and writing their output back into the docs. + +Usage: + python sql_response_runner.py --version v26.2 --update # write changes + python sql_response_runner.py --version v26.2 --diff # show diffs only + python sql_response_runner.py --version v26.2 --check # exit non-zero if stale + python sql_response_runner.py src/current/v26.2/show-tables.md --update +""" + +import argparse +import difflib +import os +import sys + +# Ensure the scripts directory is on the path +sys.path.insert(0, os.path.dirname(__file__)) + +from sql_test.extractor import extract_from_files +from sql_test.executor import execute_page, DEFAULT_CONNECTION_URL +from sql_test.generator import update_file_responses, format_output +from sql_test.models import ResponseMode + + +def collect_files(file_args: list, version: str = None) -> list: + """Collect markdown files to process. + + Args: + file_args: Explicitly provided file paths. + version: If set, find all markdown files under src/current//. + + Returns: + List of file paths. + """ + import glob as globmod + + files = [] + + if version: + script_dir = os.path.dirname(os.path.abspath(__file__)) + repo_root = os.path.dirname(os.path.dirname(script_dir)) + version_dir = os.path.join(repo_root, "src", "current", version) + if not os.path.isdir(version_dir): + print(f"Error: version directory not found: {version_dir}", file=sys.stderr) + sys.exit(1) + pattern = os.path.join(version_dir, "**", "*.md") + files = sorted(globmod.glob(pattern, recursive=True)) + + if file_args: + files.extend(file_args) + + return files + + +def _has_generate_blocks(page) -> bool: + """Check if a page has any blocks with response_mode == GENERATE.""" + return any( + b.response_mode == ResponseMode.GENERATE + for b in page.blocks + ) + + +def _build_results_map(page): + """Build a dict mapping block_index -> TestResult for a page.""" + return {r.block.block_index: r for r in page.results} + + +def _compute_diff(file_path: str, page, results_map: dict) -> str: + """Compute unified diff of what would change in a file.""" + from pathlib import Path + + original = Path(file_path).read_text(encoding='utf-8') + original_lines = original.split('\n') + + # Perform the update on a copy + file_result = update_file_responses(file_path, page, results_map, dry_run=True) + + if not file_result.modified: + return "" + + # Actually compute updated content by replaying the logic + lines = list(original_lines) + + generate_blocks = [] + for block in page.blocks: + if block.response_mode != ResponseMode.GENERATE: + continue + if block.block_type.value == "fragment": + continue + test_result = results_map.get(block.block_index) + if test_result is None or not test_result.success: + continue + generate_blocks.append((block, test_result)) + + generate_blocks.sort(key=lambda x: x[0].line_number, reverse=True) + + for block, test_result in generate_blocks: + formatted = format_output(test_result.actual_output) + + if block.output_block_range is not None: + out_open, out_close = block.output_block_range + existing_content = '\n'.join(lines[out_open + 1:out_close]) + if existing_content != formatted: + lines = lines[:out_open + 1] + formatted.split('\n') + lines[out_close:] + else: + sql_open_idx = block.line_number - 1 + sql_close_idx = sql_open_idx + 1 + while sql_close_idx < len(lines) and lines[sql_close_idx].strip() != '~~~': + sql_close_idx += 1 + insert_idx = sql_close_idx + 1 + new_block_lines = ['', '~~~'] + formatted.split('\n') + ['~~~'] + lines = lines[:insert_idx] + new_block_lines + lines[insert_idx:] + + updated = '\n'.join(lines) + if original == updated: + return "" + + diff = difflib.unified_diff( + original.splitlines(keepends=True), + updated.splitlines(keepends=True), + fromfile=f"a/{file_path}", + tofile=f"b/{file_path}", + ) + return ''.join(diff) + + +def main(): + parser = argparse.ArgumentParser( + description="Generate SQL response output blocks in CockroachDB documentation." + ) + parser.add_argument( + "files", nargs="*", help="Markdown files to process." + ) + parser.add_argument( + "--version", type=str, default=None, + help="Process all files in a version directory (e.g., v26.2)." + ) + parser.add_argument( + "--connection-url", type=str, default=DEFAULT_CONNECTION_URL, + help=f"CockroachDB connection URL (default: {DEFAULT_CONNECTION_URL})." + ) + + mode = parser.add_mutually_exclusive_group(required=True) + mode.add_argument( + "--update", action="store_true", + help="Write updated output blocks back to files." + ) + mode.add_argument( + "--diff", action="store_true", + help="Show unified diff of what would change (no writes)." + ) + mode.add_argument( + "--check", action="store_true", + help="Exit non-zero if any output is stale (for CI)." + ) + + parser.add_argument( + "--verbose", action="store_true", + help="Show detailed progress." + ) + + args = parser.parse_args() + + # Collect files + files = collect_files(args.files, args.version) + if not files: + print("No files to process. Provide file paths or --version.", file=sys.stderr) + sys.exit(1) + + # Extract blocks + pages = extract_from_files(files) + + # Filter to pages that have at least one GENERATE block + pages_with_generate = [p for p in pages if _has_generate_blocks(p)] + + if not pages_with_generate: + print("No blocks with sql-response:generate found.") + sys.exit(0) + + if args.verbose: + total_generate = sum( + 1 for p in pages_with_generate for b in p.blocks + if b.response_mode == ResponseMode.GENERATE + ) + print(f"Found {total_generate} generate block(s) across {len(pages_with_generate)} file(s).") + + # Execute pages + stale_count = 0 + total_updated = 0 + total_inserted = 0 + total_unchanged = 0 + + for page in pages_with_generate: + if args.verbose: + print(f"Executing {page.file_path}...") + + execute_page(page, connection_url=args.connection_url) + results_map = _build_results_map(page) + + if args.update: + file_result = update_file_responses( + page.file_path, page, results_map, dry_run=False + ) + total_updated += file_result.blocks_updated + total_inserted += file_result.blocks_inserted + total_unchanged += file_result.blocks_unchanged + + if args.verbose and file_result.modified: + print(f" Updated: {file_result.blocks_updated} replaced, {file_result.blocks_inserted} inserted") + + elif args.diff: + diff_output = _compute_diff(page.file_path, page, results_map) + if diff_output: + print(diff_output) + stale_count += 1 + + elif args.check: + file_result = update_file_responses( + page.file_path, page, results_map, dry_run=True + ) + if file_result.modified: + stale_count += 1 + print(f"STALE: {page.file_path} ({file_result.blocks_updated} to replace, {file_result.blocks_inserted} to insert)") + + # Summary + if args.update: + print(f"\nDone. {total_updated} replaced, {total_inserted} inserted, {total_unchanged} unchanged.") + elif args.diff: + if stale_count == 0: + print("All output blocks are up to date.") + elif args.check: + if stale_count > 0: + print(f"\n{stale_count} file(s) have stale output blocks.", file=sys.stderr) + sys.exit(1) + else: + print("All output blocks are up to date.") + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/sql_test/extractor.py b/.github/scripts/sql_test/extractor.py index 2c9bbda4f01..b7c3cbdf9cb 100644 --- a/.github/scripts/sql_test/extractor.py +++ b/.github/scripts/sql_test/extractor.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List, Optional -from .models import BlockType, SqlBlock, PageResult +from .models import BlockType, ResponseMode, SqlBlock, PageResult # Tables that indicate MovR dataset usage @@ -26,6 +26,14 @@ r'' ) +# Response generation annotation patterns +RESPONSE_GENERATE_RE = re.compile( + r'' +) +RESPONSE_SKIP_RE = re.compile( + r'' +) + def _has_page_level_skip(content: str) -> bool: """Check if frontmatter contains sql_test: skip.""" @@ -36,6 +44,15 @@ def _has_page_level_skip(content: str) -> bool: return bool(re.search(r'^\s*sql_test:\s*skip\s*$', frontmatter, re.MULTILINE)) +def _has_page_level_response_generate(content: str) -> bool: + """Check if frontmatter contains sql_response: generate.""" + frontmatter_match = re.match(r'^---\s*\n(.*?)\n---', content, re.DOTALL) + if not frontmatter_match: + return False + frontmatter = frontmatter_match.group(1) + return bool(re.search(r'^\s*sql_response:\s*generate\s*$', frontmatter, re.MULTILINE)) + + def _clean_sql_lines(raw: str) -> List[str]: """Clean raw SQL block content into executable statements. @@ -136,8 +153,9 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: page_result = PageResult(file_path=file_path) - # Check for page-level skip + # Check for page-level flags page_skip = _has_page_level_skip(content) + page_response_generate = _has_page_level_response_generate(content) lines = content.split('\n') i = 0 @@ -165,6 +183,24 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: raw = '\n'.join(lines[sql_start:sql_end]) statements = _clean_sql_lines(raw) + # Check for response annotations on the same comment line + # or look back for a preceding response annotation + response_mode = _determine_response_mode( + lines, i, page_response_generate, is_skip_annotated=True + ) + + # Look ahead for output block range + output_block_range = None + k = sql_end + 1 + while k < len(lines) and lines[k].strip() == '': + k += 1 + if k < len(lines) and lines[k].strip() == '~~~': + out_open = k + out_close = out_open + 1 + while out_close < len(lines) and lines[out_close].strip() != '~~~': + out_close += 1 + output_block_range = (out_open, out_close) + block = SqlBlock( file_path=file_path, line_number=j + 1, # 1-indexed @@ -173,6 +209,8 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: block_type=BlockType.SKIPPED, skip_reason=skip_reason, block_index=block_index, + response_mode=response_mode, + output_block_range=output_block_range, ) page_result.blocks.append(block) block_index += 1 @@ -182,6 +220,25 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: i += 1 continue + # Check for response generate/skip annotation (not tied to sql-test:skip) + response_generate_match = RESPONSE_GENERATE_RE.search(line) + response_skip_match = RESPONSE_SKIP_RE.search(line) + + if response_generate_match or response_skip_match: + # Look for the next SQL block immediately following + j = i + 1 + while j < len(lines) and lines[j].strip() == '': + j += 1 + + if j < len(lines) and lines[j].strip() == '~~~ sql': + # The SQL block will be processed when we reach it. + # Store the annotation info to be picked up below. + # We don't advance i; we let the normal ~~~ sql handler pick it up. + pass + + i += 1 + continue + # Detect ~~~ sql block if line.strip() == '~~~ sql': sql_line_number = i + 1 # 1-indexed @@ -197,8 +254,9 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: # Look ahead for expected output block (~~~ without a language tag) expected_output = None + output_block_range = None j = sql_end + 1 - # Skip blank lines and non-code-block lines between SQL and output + # Skip blank lines between SQL and output while j < len(lines) and lines[j].strip() == '': j += 1 @@ -209,6 +267,7 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: while out_end < len(lines) and lines[out_end].strip() != '~~~': out_end += 1 expected_output = '\n'.join(lines[out_start:out_end]) + output_block_range = (j, out_end) # 0-indexed, inclusive of ~~~ delimiters # Determine skip reason skip_reason = None @@ -217,6 +276,11 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: block_type = _classify_block(raw, statements, expected_output, skip_reason) + # Determine response mode by looking back for annotations + response_mode = _determine_response_mode( + lines, i, page_response_generate + ) + block = SqlBlock( file_path=file_path, line_number=sql_line_number, @@ -226,6 +290,8 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: expected_output=expected_output, skip_reason=skip_reason, block_index=block_index, + response_mode=response_mode, + output_block_range=output_block_range, ) page_result.blocks.append(block) block_index += 1 @@ -239,6 +305,60 @@ def extract_blocks(file_path: str, content: Optional[str] = None) -> PageResult: return page_result +def _determine_response_mode( + lines: List[str], + sql_block_line_idx: int, + page_response_generate: bool, + is_skip_annotated: bool = False, +) -> ResponseMode: + """Determine the response mode for a SQL block. + + Looks backward from the SQL block's opening ~~~ sql line (or from + the sql-test:skip comment line) for a sql-response annotation. + Block-level annotations override page-level settings. + + Args: + lines: All lines from the file. + sql_block_line_idx: 0-indexed line of the ~~~ sql opener or skip comment. + page_response_generate: Whether the page has sql_response: generate in frontmatter. + is_skip_annotated: Whether this block has a sql-test:skip annotation. + + Returns: + The ResponseMode for this block. + """ + # Look backward through preceding blank lines and comments + j = sql_block_line_idx - 1 + while j >= 0 and lines[j].strip() == '': + j -= 1 + + if j >= 0: + # Check if the line immediately before (skipping blanks) is a response annotation + if RESPONSE_SKIP_RE.search(lines[j]): + return ResponseMode.SKIP + if RESPONSE_GENERATE_RE.search(lines[j]): + return ResponseMode.GENERATE + # Also check the case where sql-test:skip is on j, and response annotation is above that + if is_skip_annotated: + # sql_block_line_idx points to the skip comment, look further back + pass + else: + # Check one more line back (annotations may be above a sql-test:skip comment) + k = j - 1 + while k >= 0 and lines[k].strip() == '': + k -= 1 + if k >= 0: + if RESPONSE_SKIP_RE.search(lines[k]): + return ResponseMode.SKIP + if RESPONSE_GENERATE_RE.search(lines[k]): + return ResponseMode.GENERATE + + # Fall back to page-level setting + if page_response_generate: + return ResponseMode.GENERATE + + return ResponseMode.MANUAL + + def extract_from_files(file_paths: List[str]) -> List[PageResult]: """Extract SQL blocks from multiple files. diff --git a/.github/scripts/sql_test/generator.py b/.github/scripts/sql_test/generator.py new file mode 100644 index 00000000000..efde0cd73d0 --- /dev/null +++ b/.github/scripts/sql_test/generator.py @@ -0,0 +1,175 @@ +"""Replaces output blocks in markdown files with actual SQL execution results.""" + +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional + +from .models import ResponseMode, SqlBlock, TestResult, PageResult + + +@dataclass +class BlockUpdateResult: + """Result of updating a single output block.""" + block_index: int + line_number: int + action: str # "replaced", "inserted", "unchanged", "skipped" + reason: str = "" + + +@dataclass +class FileUpdateResult: + """Summary of all updates applied to a single file.""" + file_path: str + blocks_updated: int = 0 + blocks_inserted: int = 0 + blocks_unchanged: int = 0 + blocks_skipped: int = 0 + block_details: List[BlockUpdateResult] = field(default_factory=list) + + @property + def modified(self) -> bool: + return self.blocks_updated > 0 or self.blocks_inserted > 0 + + +# Pattern to strip cockroach sql Time: lines +TIME_LINE_RE = re.compile(r'^Time:.*$', re.MULTILINE) + + +def format_output(raw_output: str) -> str: + """Format raw SQL execution output for insertion into a markdown file. + + Strips: + - Time: lines appended by cockroach sql (e.g., "Time: 1ms total (execution 1ms / network 0ms)") + - Leading/trailing blank lines + - Trailing whitespace on each line + + Preserves table formatting as-is. + """ + # Strip Time: lines + cleaned = TIME_LINE_RE.sub('', raw_output) + + # Strip trailing whitespace on each line + lines = [line.rstrip() for line in cleaned.split('\n')] + + # Strip leading/trailing blank lines + while lines and lines[0] == '': + lines.pop(0) + while lines and lines[-1] == '': + lines.pop() + + return '\n'.join(lines) + + +def update_file_responses( + file_path: str, + page_result: PageResult, + results: Dict[int, TestResult], + dry_run: bool = False, +) -> FileUpdateResult: + """Update output blocks in a markdown file with actual SQL execution results. + + Args: + file_path: Path to the markdown file. + page_result: PageResult containing extracted blocks. + results: Map from block_index to TestResult for executed blocks. + dry_run: If True, compute changes but don't write the file. + + Returns: + FileUpdateResult summarizing what changed. + """ + file_result = FileUpdateResult(file_path=file_path) + path = Path(file_path) + lines = path.read_text(encoding='utf-8').split('\n') + + # Collect blocks eligible for generation, sorted by block_index + generate_blocks = [] + for block in page_result.blocks: + if block.response_mode != ResponseMode.GENERATE: + file_result.blocks_skipped += 1 + file_result.block_details.append(BlockUpdateResult( + block_index=block.block_index, + line_number=block.line_number, + action="skipped", + reason="response_mode is not GENERATE", + )) + continue + + # Fragments are never eligible + if block.block_type.value == "fragment": + file_result.blocks_skipped += 1 + file_result.block_details.append(BlockUpdateResult( + block_index=block.block_index, + line_number=block.line_number, + action="skipped", + reason="block is a fragment", + )) + continue + + test_result = results.get(block.block_index) + if test_result is None or not test_result.success: + file_result.blocks_skipped += 1 + reason = "no test result" if test_result is None else f"execution failed: {test_result.error_message}" + file_result.block_details.append(BlockUpdateResult( + block_index=block.block_index, + line_number=block.line_number, + action="skipped", + reason=reason, + )) + continue + + generate_blocks.append((block, test_result)) + + # Process in reverse document order to avoid line-number shifting + generate_blocks.sort(key=lambda x: x[0].line_number, reverse=True) + + for block, test_result in generate_blocks: + formatted = format_output(test_result.actual_output) + + if block.output_block_range is not None: + # Replace existing output block + out_open, out_close = block.output_block_range + existing_content = '\n'.join(lines[out_open + 1:out_close]) + + if existing_content == formatted: + file_result.blocks_unchanged += 1 + file_result.block_details.append(BlockUpdateResult( + block_index=block.block_index, + line_number=block.line_number, + action="unchanged", + )) + else: + # Replace lines between the ~~~ delimiters (keep the delimiters) + new_lines = lines[:out_open + 1] + formatted.split('\n') + lines[out_close:] + lines = new_lines + file_result.blocks_updated += 1 + file_result.block_details.append(BlockUpdateResult( + block_index=block.block_index, + line_number=block.line_number, + action="replaced", + )) + else: + # Insert new output block after the SQL block's closing ~~~ + # Find the closing ~~~ of the SQL block + # line_number is 1-indexed and points to the ~~~ sql line + sql_open_idx = block.line_number - 1 # 0-indexed + sql_close_idx = sql_open_idx + 1 + while sql_close_idx < len(lines) and lines[sql_close_idx].strip() != '~~~': + sql_close_idx += 1 + + # Insert after the closing ~~~ + insert_idx = sql_close_idx + 1 + new_block_lines = ['', '~~~'] + formatted.split('\n') + ['~~~'] + lines = lines[:insert_idx] + new_block_lines + lines[insert_idx:] + file_result.blocks_inserted += 1 + file_result.block_details.append(BlockUpdateResult( + block_index=block.block_index, + line_number=block.line_number, + action="inserted", + )) + + # Write updated file + if file_result.modified and not dry_run: + path.write_text('\n'.join(lines), encoding='utf-8') + + return file_result diff --git a/.github/scripts/sql_test/models.py b/.github/scripts/sql_test/models.py index c4d7a17f50b..91872e39090 100644 --- a/.github/scripts/sql_test/models.py +++ b/.github/scripts/sql_test/models.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional +from typing import List, Optional, Tuple class BlockType(Enum): @@ -13,6 +13,13 @@ class BlockType(Enum): SKIPPED = "skipped" +class ResponseMode(Enum): + """Whether a block's output should be auto-generated.""" + MANUAL = "manual" + GENERATE = "generate" + SKIP = "skip" + + @dataclass class SqlBlock: """A single SQL code block extracted from a markdown file.""" @@ -24,6 +31,8 @@ class SqlBlock: expected_output: Optional[str] = None skip_reason: Optional[str] = None block_index: int = 0 + response_mode: ResponseMode = ResponseMode.MANUAL + output_block_range: Optional[Tuple[int, int]] = None # (start_line, end_line) 0-indexed, inclusive of ~~~ delimiters @dataclass diff --git a/.github/scripts/test_sql_generator.py b/.github/scripts/test_sql_generator.py new file mode 100644 index 00000000000..483d2d07c41 --- /dev/null +++ b/.github/scripts/test_sql_generator.py @@ -0,0 +1,530 @@ +"""Unit tests for the SQL response generator.""" + +import os +import sys +import tempfile +import unittest + +# Ensure the scripts directory is on the path +sys.path.insert(0, os.path.dirname(__file__)) + +from sql_test.extractor import extract_blocks, _has_page_level_response_generate +from sql_test.generator import format_output, update_file_responses +from sql_test.models import BlockType, ResponseMode, SqlBlock, TestResult, PageResult + + +class TestFormatOutput(unittest.TestCase): + """Tests for output formatting.""" + + def test_strips_time_line(self): + raw = " id\n+----+\n 1\n(1 row)\n\nTime: 1ms total (execution 1ms / network 0ms)" + result = format_output(raw) + self.assertNotIn("Time:", result) + self.assertIn("(1 row)", result) + + def test_strips_multiple_time_lines(self): + raw = "result1\nTime: 2ms total\nresult2\nTime: 3ms total" + result = format_output(raw) + self.assertNotIn("Time:", result) + self.assertIn("result1", result) + self.assertIn("result2", result) + + def test_strips_leading_trailing_blanks(self): + raw = "\n\n id\n+----+\n 1\n\n\n" + result = format_output(raw) + self.assertTrue(result.startswith(" id")) + self.assertTrue(result.endswith("1")) + + def test_strips_trailing_whitespace_per_line(self): + raw = " id \n+----+ \n 1 " + result = format_output(raw) + for line in result.split('\n'): + self.assertEqual(line, line.rstrip()) + + def test_preserves_table_formatting(self): + raw = " column_name | data_type\n+-------------+-----------+\n id | INT8\n(1 row)" + result = format_output(raw) + self.assertIn("+-------------+-----------+", result) + + def test_empty_output(self): + result = format_output("") + self.assertEqual(result, "") + + def test_only_time_line(self): + raw = "Time: 1ms total (execution 1ms / network 0ms)" + result = format_output(raw) + self.assertEqual(result, "") + + def test_whitespace_only(self): + raw = " \n\n \n" + result = format_output(raw) + self.assertEqual(result, "") + + +class TestAnnotationDetection(unittest.TestCase): + """Tests for sql-response annotation detection.""" + + def test_block_level_generate(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + +~~~ +old output +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.GENERATE) + + def test_block_level_skip(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + +~~~ +curated output +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.SKIP) + + def test_block_level_skip_no_reason(self): + md = """ +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.SKIP) + + def test_page_level_generate(self): + md = """--- +title: Test +sql_response: generate +--- + +~~~ sql +> SELECT 1; +~~~ + +~~~ +old output +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.GENERATE) + + def test_page_level_generate_detected(self): + content = "---\ntitle: Test\nsql_response: generate\n---\nBody" + self.assertTrue(_has_page_level_response_generate(content)) + + def test_page_level_generate_not_detected(self): + content = "---\ntitle: Test\n---\nBody" + self.assertFalse(_has_page_level_response_generate(content)) + + def test_block_skip_overrides_page_generate(self): + md = """--- +title: Test +sql_response: generate +--- + + +~~~ sql +> SELECT 1; +~~~ + +~~~ +curated output +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.SKIP) + + def test_no_annotation_defaults_to_manual(self): + md = """~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.MANUAL) + + def test_generate_with_blank_line_before_sql(self): + md = """ + +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.GENERATE) + + def test_sql_test_skip_and_response_generate(self): + """A block can be skipped for testing but still have response generation.""" + md = """ + +~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertEqual(result.blocks[0].block_type, BlockType.SKIPPED) + self.assertEqual(result.blocks[0].response_mode, ResponseMode.GENERATE) + + +class TestOutputBlockRange(unittest.TestCase): + """Tests for output block range tracking.""" + + def test_tracks_existing_output_block(self): + md = """~~~ sql +> SELECT 1; +~~~ + +~~~ +old output +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + block = result.blocks[0] + self.assertIsNotNone(block.output_block_range) + # The output block ~~~ opens on line index 4, closes on line index 6 + out_open, out_close = block.output_block_range + lines = md.split('\n') + self.assertEqual(lines[out_open].strip(), '~~~') + self.assertEqual(lines[out_close].strip(), '~~~') + + def test_no_output_block_range_when_none(self): + md = """~~~ sql +> SELECT 1; +~~~ +""" + result = extract_blocks("test.md", content=md) + self.assertEqual(len(result.blocks), 1) + self.assertIsNone(result.blocks[0].output_block_range) + + +class TestOutputBlockReplacement(unittest.TestCase): + """Tests for replacing output blocks in files.""" + + def _write_temp(self, content): + """Write content to a temp file and return the path.""" + fd, path = tempfile.mkstemp(suffix='.md') + with os.fdopen(fd, 'w') as f: + f.write(content) + return path + + def test_replace_existing_output(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + +~~~ +old output +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + block = page.blocks[0] + + test_result = TestResult( + block=block, + success=True, + actual_output=" ?column?\n+----------+\n 1\n(1 row)\n\nTime: 1ms total", + ) + results = {block.block_index: test_result} + + file_result = update_file_responses(path, page, results, dry_run=False) + self.assertTrue(file_result.modified) + self.assertEqual(file_result.blocks_updated, 1) + + updated = open(path).read() + self.assertIn("?column?", updated) + self.assertNotIn("old output", updated) + self.assertNotIn("Time:", updated) + finally: + os.unlink(path) + + def test_insert_new_output(self): + md = """ +~~~ sql +> SELECT 1; +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + block = page.blocks[0] + + test_result = TestResult( + block=block, + success=True, + actual_output=" ?column?\n+----------+\n 1\n(1 row)", + ) + results = {block.block_index: test_result} + + file_result = update_file_responses(path, page, results, dry_run=False) + self.assertTrue(file_result.modified) + self.assertEqual(file_result.blocks_inserted, 1) + + updated = open(path).read() + self.assertIn("?column?", updated) + self.assertIn("~~~\n ?column?", updated) + finally: + os.unlink(path) + + def test_unchanged_output(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + +~~~ + ?column? ++----------+ + 1 +(1 row) +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + block = page.blocks[0] + + test_result = TestResult( + block=block, + success=True, + actual_output=" ?column?\n+----------+\n 1\n(1 row)", + ) + results = {block.block_index: test_result} + + file_result = update_file_responses(path, page, results, dry_run=False) + self.assertFalse(file_result.modified) + self.assertEqual(file_result.blocks_unchanged, 1) + finally: + os.unlink(path) + + def test_dry_run_does_not_write(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + +~~~ +old output +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + block = page.blocks[0] + + test_result = TestResult( + block=block, + success=True, + actual_output="new output", + ) + results = {block.block_index: test_result} + + file_result = update_file_responses(path, page, results, dry_run=True) + self.assertTrue(file_result.modified) + + # File should be unchanged + unchanged = open(path).read() + self.assertIn("old output", unchanged) + self.assertNotIn("new output", unchanged) + finally: + os.unlink(path) + + def test_skips_manual_blocks(self): + md = """~~~ sql +> SELECT 1; +~~~ + +~~~ +old output +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + block = page.blocks[0] + self.assertEqual(block.response_mode, ResponseMode.MANUAL) + + test_result = TestResult( + block=block, + success=True, + actual_output="new output", + ) + results = {block.block_index: test_result} + + file_result = update_file_responses(path, page, results, dry_run=False) + self.assertFalse(file_result.modified) + self.assertEqual(file_result.blocks_skipped, 1) + finally: + os.unlink(path) + + def test_skips_fragments(self): + md = """ +~~~ sql +> SELECT ... FROM
; +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + block = page.blocks[0] + self.assertEqual(block.block_type, BlockType.FRAGMENT) + + file_result = update_file_responses(path, page, {}, dry_run=False) + self.assertFalse(file_result.modified) + self.assertEqual(file_result.blocks_skipped, 1) + finally: + os.unlink(path) + + def test_skips_failed_execution(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + +~~~ +old output +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + block = page.blocks[0] + + test_result = TestResult( + block=block, + success=False, + error_message="connection refused", + ) + results = {block.block_index: test_result} + + file_result = update_file_responses(path, page, results, dry_run=False) + self.assertFalse(file_result.modified) + self.assertEqual(file_result.blocks_skipped, 1) + + unchanged = open(path).read() + self.assertIn("old output", unchanged) + finally: + os.unlink(path) + + +class TestReverseOrderProcessing(unittest.TestCase): + """Tests that reverse-order processing preserves line numbers.""" + + def _write_temp(self, content): + fd, path = tempfile.mkstemp(suffix='.md') + with os.fdopen(fd, 'w') as f: + f.write(content) + return path + + def test_multiple_blocks_replaced_correctly(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + +~~~ +old1 +~~~ + + +~~~ sql +> SELECT 2; +~~~ + +~~~ +old2 +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + self.assertEqual(len(page.blocks), 2) + + results = {} + for block in page.blocks: + if "SELECT 1" in block.raw_content: + results[block.block_index] = TestResult( + block=block, success=True, actual_output="new1", + ) + else: + results[block.block_index] = TestResult( + block=block, success=True, actual_output="new2", + ) + + file_result = update_file_responses(path, page, results, dry_run=False) + self.assertEqual(file_result.blocks_updated, 2) + + updated = open(path).read() + self.assertNotIn("old1", updated) + self.assertNotIn("old2", updated) + self.assertIn("new1", updated) + self.assertIn("new2", updated) + + # Verify ordering: new1 should appear before new2 + self.assertLess(updated.index("new1"), updated.index("new2")) + finally: + os.unlink(path) + + def test_mixed_insert_and_replace(self): + md = """ +~~~ sql +> SELECT 1; +~~~ + + +~~~ sql +> SELECT 2; +~~~ + +~~~ +old2 +~~~ +""" + path = self._write_temp(md) + try: + page = extract_blocks(path) + self.assertEqual(len(page.blocks), 2) + + results = {} + for block in page.blocks: + if "SELECT 1" in block.raw_content: + results[block.block_index] = TestResult( + block=block, success=True, actual_output="new1", + ) + else: + results[block.block_index] = TestResult( + block=block, success=True, actual_output="new2", + ) + + file_result = update_file_responses(path, page, results, dry_run=False) + self.assertEqual(file_result.blocks_inserted, 1) + self.assertEqual(file_result.blocks_updated, 1) + + updated = open(path).read() + self.assertIn("new1", updated) + self.assertIn("new2", updated) + self.assertNotIn("old2", updated) + finally: + os.unlink(path) + + +if __name__ == '__main__': + unittest.main() diff --git a/.github/workflows/sql-response-generate.yml b/.github/workflows/sql-response-generate.yml new file mode 100644 index 00000000000..48773ffb18f --- /dev/null +++ b/.github/workflows/sql-response-generate.yml @@ -0,0 +1,95 @@ +name: SQL Response Generate + +on: + schedule: + - cron: '0 8 * * 1' # Weekly Monday 8am UTC + workflow_dispatch: + inputs: + version: + description: 'Version to generate for (e.g., v26.2)' + required: true + default: 'v26.2' + +jobs: + generate-responses: + name: Generate SQL response blocks + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Determine version + id: version + run: | + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + echo "version=${{ github.event.inputs.version }}" >> "$GITHUB_OUTPUT" + else + echo "version=v26.2" >> "$GITHUB_OUTPUT" + fi + + - name: Install CockroachDB + run: | + curl https://binaries.cockroachdb.com/cockroach-v26.2.0.linux-amd64.tgz | tar -xz + sudo cp cockroach-v26.2.0.linux-amd64/cockroach /usr/local/bin/ + + - name: Start CockroachDB + run: | + cockroach start-single-node --insecure --background --store=type=mem,size=1GiB --listen-addr=localhost:26257 + sleep 5 + cockroach sql --insecure -e "SELECT 1;" + + - name: Run SQL response generator + run: | + python .github/scripts/sql_response_runner.py \ + --version ${{ steps.version.outputs.version }} \ + --update --verbose + + - name: Check for changes + id: changes + run: | + if git diff --quiet; then + echo "changed=false" >> "$GITHUB_OUTPUT" + else + echo "changed=true" >> "$GITHUB_OUTPUT" + fi + + - name: Create PR with updated responses + if: steps.changes.outputs.changed == 'true' + run: | + DATE=$(date +%Y%m%d) + BRANCH="auto/sql-responses-${DATE}" + git checkout -b "$BRANCH" + git add -A + git commit -m "Auto-update SQL response output blocks + + Generated by sql_response_runner.py for ${{ steps.version.outputs.version }}." + git push origin "$BRANCH" + gh pr create \ + --title "Auto-update SQL response blocks (${{ steps.version.outputs.version }})" \ + --body "$(cat <<'EOF' + ## Summary + - Automatically regenerated SQL response output blocks for ${{ steps.version.outputs.version }} + - Generated by the weekly \`sql-response-generate\` workflow + + ## Review checklist + - [ ] Verify output changes look correct + - [ ] Check that no non-deterministic output was accidentally included + + 🤖 Generated automatically by [sql-response-generate](../actions/workflows/sql-response-generate.yml) + EOF + )" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Stop CockroachDB + if: always() + run: | + cockroach quit --insecure --host=localhost:26257 || true diff --git a/src/current/Makefile b/src/current/Makefile index 6c3375f60f2..20c0ad07750 100644 --- a/src/current/Makefile +++ b/src/current/Makefile @@ -99,6 +99,24 @@ sql-test: sql-test-dry-run: python3 ../../.github/scripts/sql_test_runner.py --dry-run --verbose --version v25.4 +.PHONY: sql-response-generate +sql-response-generate: + cockroach start-single-node --insecure --background --store=type=mem,size=1GiB --listen-addr=localhost:26257 + sleep 5 + python3 ../../.github/scripts/sql_response_runner.py --version v26.2 --update --verbose; \ + EXIT_CODE=$$?; \ + cockroach quit --insecure --host=localhost:26257 || true; \ + exit $$EXIT_CODE + +.PHONY: sql-response-diff +sql-response-diff: + cockroach start-single-node --insecure --background --store=type=mem,size=1GiB --listen-addr=localhost:26257 + sleep 5 + python3 ../../.github/scripts/sql_response_runner.py --version v26.2 --diff --verbose; \ + EXIT_CODE=$$?; \ + cockroach quit --insecure --host=localhost:26257 || true; \ + exit $$EXIT_CODE + .PHONY: vendor vendor: gem install bundler