Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/sqlite_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
//
// Stores all non-embedding columns in a local SQLite database (bundled libsqlite3).
// Scalar columns map to INTEGER/TEXT/REAL; list columns are serialised as JSON TEXT.
// Lookups use `WHERE row_idx IN (?, ...)` against the INTEGER PRIMARY KEY B-tree.
// Lookups use `WHERE <key_col> IN (?, ...)` against the INTEGER PRIMARY KEY B-tree.
//
// Schema: row_idx INTEGER PRIMARY KEY, <col> TEXT/INTEGER/REAL, ...
// Schema: <key_col> INTEGER PRIMARY KEY, <col> TEXT/INTEGER/REAL, ...
//
// The key column name is caller-provided (e.g. "_key") and must match the first
// field in the schema passed to `open_or_build`.
//
// Persistence: the database is written once to the given path and reused on
// subsequent runs. The first build reads all parquet files and inserts rows
Expand Down Expand Up @@ -42,6 +45,7 @@ use crate::lookup::PointLookupProvider;
/// Point-lookup provider backed by a local SQLite database (bundled libsqlite3).
pub struct SqliteLookupProvider {
    /// Arrow schema of the stored table; the first field is the key column
    /// (INTEGER PRIMARY KEY) whose name is caller-provided.
    schema: SchemaRef,
    /// Name of the SQLite table queried by `fetch_by_keys`.
    table_name: String,
    /// Name of the key column, captured from `schema.field(0)` at build time.
    key_col: String,
    /// Pool of reusable SQLite connections; a connection is checked out per query.
    pool: Arc<Mutex<Vec<Connection>>>,
    /// Bounds concurrent queries to the pool size so a checkout never blocks empty.
    sem: Arc<Semaphore>,
}
Expand Down Expand Up @@ -117,6 +121,8 @@ impl SqliteLookupProvider {
schema: SchemaRef,
parquet_col_indices: &[usize],
) -> DFResult<Self> {
// The first field in the schema is the key column (INTEGER PRIMARY KEY).
let key_col = schema.field(0).name().clone();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: all existing tests happen to use row_idx as the key column name, so the only coverage here is "dynamic name produces the same SQL as the old hardcoded string." Consider adding a test in sqlite_provider_test.rs that builds a provider with a schema whose first field is named something other than row_idx (e.g. _key) and confirms that fetch_by_keys returns the correct rows. That would give you a direct regression guard for the actual bug scenario.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — added test_custom_key_column_name that uses _key as the key column name.

if pool_size == 0 {
return Err(DataFusionError::Execution(
"pool_size must be at least 1".into(),
Expand Down Expand Up @@ -167,6 +173,7 @@ impl SqliteLookupProvider {
Ok(Self {
schema,
table_name: table_name.to_string(),
key_col,
pool: Arc::new(Mutex::new(conns)),
sem: Arc::new(Semaphore::new(pool_size)),
})
Expand Down Expand Up @@ -202,6 +209,7 @@ impl PointLookupProvider for SqliteLookupProvider {
let keys_vec = keys.to_vec();
let pool = self.pool.clone();
let table_name = self.table_name.clone();
let key_col = self.key_col.clone();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking suggestion: fetch_by_keys takes a _key_col: &str parameter (from the PointLookupProvider trait) that is silently ignored here in favour of self.key_col. This was already the pre-existing behaviour in HashKeyProvider, so this PR is consistent, but it means a caller who passes the wrong column name gets no feedback.

Consider a debug_assert_eq!(_key_col, self.key_col) guard, or removing the parameter from the trait entirely in a follow-up.


// Acquire a semaphore permit to bound concurrency to the pool size,
// then run the synchronous SQLite query on a blocking thread.
Expand All @@ -227,6 +235,7 @@ impl PointLookupProvider for SqliteLookupProvider {
&keys_vec,
&out_schema,
&table_name,
&key_col,
);
drop(guard); // explicit but not required — Drop handles it
res
Expand All @@ -243,6 +252,7 @@ fn execute_query_sync(
keys: &[u64],
out_schema: &SchemaRef,
table_name: &str,
key_col: &str,
) -> DFResult<Vec<RecordBatch>> {
let placeholders = keys.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
// Select only the columns in out_schema (already projection-applied by the
Expand All @@ -253,8 +263,9 @@ fn execute_query_sync(
.map(|f| quote_ident(f.name()))
.collect::<Vec<_>>()
.join(", ");
let qk = quote_ident(key_col);
let sql = format!(
"SELECT {col_list} FROM {tn} WHERE row_idx IN ({placeholders}) ORDER BY row_idx",
"SELECT {col_list} FROM {tn} WHERE {qk} IN ({placeholders}) ORDER BY {qk}",
tn = quote_ident(table_name)
);

Expand Down Expand Up @@ -586,14 +597,16 @@ fn build_table(
schema: &SchemaRef,
parquet_col_indices: &[usize],
) -> DFResult<()> {
// The first field is the key column (INTEGER PRIMARY KEY).
let key_col_name = schema.field(0).name();
let col_defs = schema
.fields()
.iter()
.map(|f| {
let sql_type = arrow_type_to_sql(f.data_type());
if f.name() == "row_idx" {
"row_idx INTEGER PRIMARY KEY".to_string()
if f.name() == key_col_name {
format!("{} INTEGER PRIMARY KEY", quote_ident(f.name()))
} else {
let sql_type = arrow_type_to_sql(f.data_type());
format!("{} {}", quote_ident(f.name()), sql_type)
}
})
Expand Down
75 changes: 75 additions & 0 deletions tests/sqlite_provider_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,78 @@ async fn test_table_name_with_spaces() {
let batches = provider.fetch_by_keys(&[0], "row_idx", None).await.unwrap();
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
}

/// Verify that a non-default key column name (e.g. "_key") works correctly.
/// This is the scenario used by runtimedb where Parquet files have a `_key` column.
#[tokio::test]
async fn test_custom_key_column_name() {
    let tmp = tempdir().unwrap();

    // The parquet file carries only the data column; keys come from row order.
    let file_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));

    // The provider schema names its key column "_key" rather than the
    // default "row_idx" — the exact regression scenario under test.
    let lookup_schema = Arc::new(Schema::new(vec![
        Field::new("_key", DataType::UInt64, false),
        Field::new("name", DataType::Utf8, true),
    ]));

    let name_values = StringArray::from(vec![Some("alice"), Some("bob"), Some("carol")]);
    let batch = RecordBatch::try_new(file_schema.clone(), vec![Arc::new(name_values)]).unwrap();

    let parquet_path = tmp.path().join("test.parquet");
    {
        // Scope the writer so the file is fully flushed and closed before the build.
        let file = std::fs::File::create(&parquet_path).unwrap();
        let mut writer = ArrowWriter::try_new(file, file_schema, None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
    }

    let db_path = tmp.path().join("test_key.db");
    let provider = SqliteLookupProvider::open_or_build(
        db_path.to_str().unwrap(),
        "vectors",
        2,
        &[parquet_path.to_str().unwrap().to_string()],
        lookup_schema,
        &[0],
    )
    .unwrap();

    // fetch_by_keys should work with the custom key column.
    let batches = provider.fetch_by_keys(&[0, 2], "_key", None).await.unwrap();
    let row_count: usize = batches.iter().map(|b| b.num_rows()).sum();
    assert_eq!(row_count, 2);

    // Collect the "name" values across all returned batches.
    let mut fetched: Vec<String> = Vec::new();
    for b in &batches {
        let col = b
            .column_by_name("name")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        fetched.extend(col.iter().map(|v| v.unwrap().to_string()));
    }
    assert_eq!(fetched, vec!["alice", "carol"]);

    // Projection down to only the key column should also work.
    let projected = provider
        .fetch_by_keys(&[1], "_key", Some(&[0]))
        .await
        .unwrap();
    assert_eq!(projected[0].schema().field(0).name(), "_key");
    let keys = projected[0]
        .column(0)
        .as_any()
        .downcast_ref::<UInt64Array>()
        .unwrap();
    assert_eq!(keys.value(0), 1);
}
Loading